Skip to content

Commit 3c8c291

Browse files
authored
[NVPTX] Improve 64bit FSH/ROT lowering when shift amount is constant (#131371)
When the shift amount of a 64-bit funnel-shift or rotate is constant, it may be decomposed into two 32-bit funnel-shifts. This ensures that we recover any possible performance losses associated with the correctness fix in a131fbf. In order to efficiently represent the expansion with Selection DAG nodes, NVPTXISD::BUILD_VECTOR and NVPTXISD::UNPACK_VECTOR are added which allow the vector output/input to be represented as a scalar. In the future, if we add support for the v2i32 type to the NVPTX backend, these nodes may be removed.
1 parent 6cc23fa commit 3c8c291

File tree

7 files changed

+449
-47
lines changed

7 files changed

+449
-47
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
113113
if (tryFence(N))
114114
return;
115115
break;
116+
case NVPTXISD::UNPACK_VECTOR:
117+
tryUNPACK_VECTOR(N);
118+
return;
116119
case ISD::EXTRACT_VECTOR_ELT:
117120
if (tryEXTRACT_VECTOR_ELEMENT(N))
118121
return;
@@ -445,6 +448,17 @@ bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
445448
return true;
446449
}
447450

451+
bool NVPTXDAGToDAGISel::tryUNPACK_VECTOR(SDNode *N) {
452+
SDValue Vector = N->getOperand(0);
453+
MVT EltVT = N->getSimpleValueType(0);
454+
455+
MachineSDNode *N2 =
456+
CurDAG->getMachineNode(NVPTX::I64toV2I32, SDLoc(N), EltVT, EltVT, Vector);
457+
458+
ReplaceNode(N, N2);
459+
return true;
460+
}
461+
448462
// Find all instances of extract_vector_elt that use this v2f16 vector
449463
// and coalesce them into a scattering move instruction.
450464
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
8888
bool tryConstantFP(SDNode *N);
8989
bool SelectSETP_F16X2(SDNode *N);
9090
bool SelectSETP_BF16X2(SDNode *N);
91+
bool tryUNPACK_VECTOR(SDNode *N);
9192
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
9293
void SelectV2I64toI128(SDNode *N);
9394
void SelectI128toV2I64(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
#include <iterator>
6767
#include <optional>
6868
#include <string>
69+
#include <tuple>
6970
#include <utility>
7071
#include <vector>
7172

@@ -668,8 +669,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
668669
{MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64},
669670
Expand);
670671

671-
if (STI.hasHWROT32())
672+
if (STI.hasHWROT32()) {
672673
setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal);
674+
setOperationAction({ISD::ROTL, ISD::ROTR, ISD::FSHL, ISD::FSHR}, MVT::i64,
675+
Custom);
676+
}
673677

674678
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
675679

@@ -1056,6 +1060,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10561060
MAKE_CASE(NVPTXISD::StoreRetvalV2)
10571061
MAKE_CASE(NVPTXISD::StoreRetvalV4)
10581062
MAKE_CASE(NVPTXISD::PseudoUseParam)
1063+
MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
1064+
MAKE_CASE(NVPTXISD::BUILD_VECTOR)
10591065
MAKE_CASE(NVPTXISD::RETURN)
10601066
MAKE_CASE(NVPTXISD::CallSeqBegin)
10611067
MAKE_CASE(NVPTXISD::CallSeqEnd)
@@ -2758,6 +2764,61 @@ static SDValue lowerCTLZCTPOP(SDValue Op, SelectionDAG &DAG) {
27582764
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CT, SDNodeFlags::NonNeg);
27592765
}
27602766

2767+
static SDValue expandFSH64(SDValue A, SDValue B, SDValue ShiftAmount, SDLoc DL,
2768+
unsigned Opcode, SelectionDAG &DAG) {
2769+
assert(A.getValueType() == MVT::i64 && B.getValueType() == MVT::i64);
2770+
2771+
const auto *AmtConst = dyn_cast<ConstantSDNode>(ShiftAmount);
2772+
if (!AmtConst)
2773+
return SDValue();
2774+
const auto Amt = AmtConst->getZExtValue() & 63;
2775+
2776+
SDValue UnpackA =
2777+
DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, A);
2778+
SDValue UnpackB =
2779+
DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, B);
2780+
2781+
// Arch is little-endian: 0 = low bits, 1 = high bits
2782+
SDValue ALo = UnpackA.getValue(0);
2783+
SDValue AHi = UnpackA.getValue(1);
2784+
SDValue BLo = UnpackB.getValue(0);
2785+
SDValue BHi = UnpackB.getValue(1);
2786+
2787+
// The bitfield consists of { AHi : ALo : BHi : BLo }
2788+
//
2789+
// * FSHL, Amt < 32 - The window will contain { AHi : ALo : BHi }
2790+
// * FSHL, Amt >= 32 - The window will contain { ALo : BHi : BLo }
2791+
// * FSHR, Amt < 32 - The window will contain { ALo : BHi : BLo }
2792+
// * FSHR, Amt >= 32 - The window will contain { AHi : ALo : BHi }
2793+
//
2794+
// Note that Amt = 0 and Amt = 32 are special cases where 32-bit funnel shifts
2795+
// are not needed at all. Amt = 0 is a no-op producing either A or B depending
2796+
// on the direction. Amt = 32 can be implemented by a packing and unpacking
2797+
// move to select and arrange the 32bit values. For simplicity, these cases
2798+
// are not handled here explicitly and instead we rely on DAGCombiner to
2799+
// remove the no-op funnel shifts we insert.
2800+
auto [High, Mid, Low] = ((Opcode == ISD::FSHL) == (Amt < 32))
2801+
? std::make_tuple(AHi, ALo, BHi)
2802+
: std::make_tuple(ALo, BHi, BLo);
2803+
2804+
SDValue NewAmt = DAG.getConstant(Amt & 31, DL, MVT::i32);
2805+
SDValue RHi = DAG.getNode(Opcode, DL, MVT::i32, {High, Mid, NewAmt});
2806+
SDValue RLo = DAG.getNode(Opcode, DL, MVT::i32, {Mid, Low, NewAmt});
2807+
2808+
return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64, {RLo, RHi});
2809+
}
2810+
2811+
static SDValue lowerFSH(SDValue Op, SelectionDAG &DAG) {
2812+
return expandFSH64(Op->getOperand(0), Op->getOperand(1), Op->getOperand(2),
2813+
SDLoc(Op), Op->getOpcode(), DAG);
2814+
}
2815+
2816+
static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
2817+
unsigned Opcode = Op->getOpcode() == ISD::ROTL ? ISD::FSHL : ISD::FSHR;
2818+
return expandFSH64(Op->getOperand(0), Op->getOperand(0), Op->getOperand(1),
2819+
SDLoc(Op), Opcode, DAG);
2820+
}
2821+
27612822
SDValue
27622823
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
27632824
switch (Op.getOpcode()) {
@@ -2818,6 +2879,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28182879
return LowerVAARG(Op, DAG);
28192880
case ISD::VASTART:
28202881
return LowerVASTART(Op, DAG);
2882+
case ISD::FSHL:
2883+
case ISD::FSHR:
2884+
return lowerFSH(Op, DAG);
2885+
case ISD::ROTL:
2886+
case ISD::ROTR:
2887+
return lowerROT(Op, DAG);
28212888
case ISD::ABS:
28222889
case ISD::SMIN:
28232890
case ISD::SMAX:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ enum NodeType : unsigned {
6161
BFE,
6262
BFI,
6363
PRMT,
64+
65+
/// This node is similar to ISD::BUILD_VECTOR except that the output may be
66+
/// implicitly bitcast to a scalar. This allows for the representation of
67+
/// packing move instructions for vector types which are not legal i.e. v2i32
68+
BUILD_VECTOR,
69+
70+
/// This node is the inverse of NVPTXISD::BUILD_VECTOR. It takes a single value
71+
/// which may be a scalar and unpacks it into multiple values by implicitly
72+
/// converting it to a vector.
73+
UNPACK_VECTOR,
74+
6475
FCOPYSIGN,
6576
DYNAMIC_STACKALLOC,
6677
STACKRESTORE,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3222,6 +3222,12 @@ def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
32223222
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
32233223
(CVT_u32_u16 $a, CvtNONE)>;
32243224

3225+
3226+
def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;
3227+
3228+
def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),
3229+
(V2I32toI64 $a, $b)>;
3230+
32253231
//
32263232
// Funnel-Shift
32273233
//

0 commit comments

Comments
 (0)