Skip to content

Commit 98dff5e

Browse files
committed
[RISCV] Move SHFLI matching to DAG combine. Add 32-bit support for RV64
We previously used isel patterns for this, but that used quite a bit of space in the isel table due to OR being associative and commutative. It also wouldn't handle shifts/ands being in reversed order. This generalizes the shift/and matching from GREVI to take the expected mask table as input so we can reuse it for SHFLI. There is no SHFLIW instruction, but we can promote a 32-bit SHFLI to i64 on RV64. As long as bit 4 of the control bit isn't set, a 64-bit SHFLI will preserve 33 sign bits if the input had at least 33 sign bits. ComputeNumSignBits has been updated to account for that to avoid sext.w in the tests. Reviewed By: frasercrmck Differential Revision: https://reviews.llvm.org/D96661
1 parent 4ffad1f commit 98dff5e

File tree

4 files changed

+191
-178
lines changed

4 files changed

+191
-178
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 180 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,6 +2545,20 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
25452545
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
25462546
break;
25472547
}
2548+
case RISCVISD::SHFLI: {
2549+
// There is no SHFLIW instruction, but we can just promote the operation.
2550+
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
2551+
"Unexpected custom legalisation");
2552+
SDLoc DL(N);
2553+
SDValue NewOp0 =
2554+
DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
2555+
SDValue NewRes =
2556+
DAG.getNode(RISCVISD::SHFLI, DL, MVT::i64, NewOp0, N->getOperand(1));
2557+
// ReplaceNodeResults requires we maintain the same type for the return
2558+
// value.
2559+
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes));
2560+
break;
2561+
}
25482562
case ISD::BSWAP:
25492563
case ISD::BITREVERSE: {
25502564
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
@@ -2674,19 +2688,21 @@ struct RISCVBitmanipPat {
26742688
}
26752689
};
26762690

2677-
// Matches any of the following bit-manipulation patterns:
2678-
// (and (shl x, 1), (0x55555555 << 1))
2679-
// (and (srl x, 1), 0x55555555)
2680-
// (shl (and x, 0x55555555), 1)
2681-
// (srl (and x, (0x55555555 << 1)), 1)
2682-
// where the shift amount and mask may vary thus:
2683-
// [1] = 0x55555555 / 0xAAAAAAAA
2684-
// [2] = 0x33333333 / 0xCCCCCCCC
2685-
// [4] = 0x0F0F0F0F / 0xF0F0F0F0
2686-
// [8] = 0x00FF00FF / 0xFF00FF00
2687-
// [16] = 0x0000FFFF / 0xFFFFFFFF
2688-
// [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
2689-
static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
2691+
// Matches patterns of the form
2692+
// (and (shl x, C2), (C1 << C2))
2693+
// (and (srl x, C2), C1)
2694+
// (shl (and x, C1), C2)
2695+
// (srl (and x, (C1 << C2)), C2)
2696+
// Where C2 is a power of 2 and C1 has at least that many leading zeroes.
2697+
// The expected masks for each shift amount are specified in BitmanipMasks where
2698+
// BitmanipMasks[log2(C2)] specifies the expected C1 value.
2699+
// The max allowed shift amount is either XLen/2 or XLen/4 determined by whether
2700+
// BitmanipMasks contains 6 or 5 entries assuming that the maximum possible
2701+
// XLen is 64.
2702+
static Optional<RISCVBitmanipPat>
2703+
matchRISCVBitmanipPat(SDValue Op, ArrayRef<uint64_t> BitmanipMasks) {
2704+
assert((BitmanipMasks.size() == 5 || BitmanipMasks.size() == 6) &&
2705+
"Unexpected number of masks");
26902706
Optional<uint64_t> Mask;
26912707
// Optionally consume a mask around the shift operation.
26922708
if (Op.getOpcode() == ISD::AND && isa<ConstantSDNode>(Op.getOperand(1))) {
@@ -2699,26 +2715,17 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
26992715

27002716
if (!isa<ConstantSDNode>(Op.getOperand(1)))
27012717
return None;
2702-
auto ShAmt = Op.getConstantOperandVal(1);
2718+
uint64_t ShAmt = Op.getConstantOperandVal(1);
27032719

2704-
if (!isPowerOf2_64(ShAmt))
2720+
unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
2721+
if (ShAmt >= Width && !isPowerOf2_64(ShAmt))
27052722
return None;
2706-
2707-
// These are the unshifted masks which we use to match bit-manipulation
2708-
// patterns. They may be shifted left in certain circumstances.
2709-
static const uint64_t BitmanipMasks[] = {
2710-
0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
2711-
0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL,
2712-
};
2713-
2714-
unsigned MaskIdx = Log2_64(ShAmt);
2715-
if (MaskIdx >= array_lengthof(BitmanipMasks))
2723+
// If we don't have enough masks for 64 bit, then we must be trying to
2724+
// match SHFL so we're only allowed to shift 1/4 of the width.
2725+
if (BitmanipMasks.size() == 5 && ShAmt >= (Width / 2))
27162726
return None;
27172727

2718-
auto Src = Op.getOperand(0);
2719-
2720-
unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
2721-
auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);
2728+
SDValue Src = Op.getOperand(0);
27222729

27232730
// The expected mask is shifted left when the AND is found around SHL
27242731
// patterns.
@@ -2745,6 +2752,9 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
27452752
}
27462753
}
27472754

2755+
unsigned MaskIdx = Log2_32(ShAmt);
2756+
uint64_t ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);
2757+
27482758
if (SHLExpMask)
27492759
ExpMask <<= ShAmt;
27502760

@@ -2754,15 +2764,38 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
27542764
return RISCVBitmanipPat{Src, (unsigned)ShAmt, IsSHL};
27552765
}
27562766

2767+
// Matches any of the following bit-manipulation patterns:
2768+
// (and (shl x, 1), (0x55555555 << 1))
2769+
// (and (srl x, 1), 0x55555555)
2770+
// (shl (and x, 0x55555555), 1)
2771+
// (srl (and x, (0x55555555 << 1)), 1)
2772+
// where the shift amount and mask may vary thus:
2773+
// [1] = 0x55555555 / 0xAAAAAAAA
2774+
// [2] = 0x33333333 / 0xCCCCCCCC
2775+
// [4] = 0x0F0F0F0F / 0xF0F0F0F0
2776+
// [8] = 0x00FF00FF / 0xFF00FF00
2777+
// [16] = 0x0000FFFF / 0xFFFF0000
2778+
// [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
2779+
static Optional<RISCVBitmanipPat> matchGREVIPat(SDValue Op) {
2780+
// These are the unshifted masks which we use to match bit-manipulation
2781+
// patterns. They may be shifted left in certain circumstances.
2782+
static const uint64_t BitmanipMasks[] = {
2783+
0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL,
2784+
0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL};
2785+
2786+
return matchRISCVBitmanipPat(Op, BitmanipMasks);
2787+
}
2788+
27572789
// Match the following pattern as a GREVI(W) operation
27582790
// (or (BITMANIP_SHL x), (BITMANIP_SRL x))
27592791
static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
27602792
const RISCVSubtarget &Subtarget) {
2793+
assert(Subtarget.hasStdExtZbp() && "Expected Zbp extenson");
27612794
EVT VT = Op.getValueType();
27622795

27632796
if (VT == Subtarget.getXLenVT() || (Subtarget.is64Bit() && VT == MVT::i32)) {
2764-
auto LHS = matchRISCVBitmanipPat(Op.getOperand(0));
2765-
auto RHS = matchRISCVBitmanipPat(Op.getOperand(1));
2797+
auto LHS = matchGREVIPat(Op.getOperand(0));
2798+
auto RHS = matchGREVIPat(Op.getOperand(1));
27662799
if (LHS && RHS && LHS->formsPairWith(*RHS)) {
27672800
SDLoc DL(Op);
27682801
return DAG.getNode(
@@ -2784,6 +2817,7 @@ static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
27842817
// 4. (or (rotl/rotr x, bitwidth/2), x)
27852818
static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
27862819
const RISCVSubtarget &Subtarget) {
2820+
assert(Subtarget.hasStdExtZbp() && "Expected Zbp extenson");
27872821
EVT VT = Op.getValueType();
27882822

27892823
if (VT == Subtarget.getXLenVT() || (Subtarget.is64Bit() && VT == MVT::i32)) {
@@ -2822,14 +2856,14 @@ static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
28222856
return SDValue();
28232857
SDValue OrOp0 = Op0.getOperand(0);
28242858
SDValue OrOp1 = Op0.getOperand(1);
2825-
auto LHS = matchRISCVBitmanipPat(OrOp0);
2859+
auto LHS = matchGREVIPat(OrOp0);
28262860
// OR is commutable so swap the operands and try again: x might have been
28272861
// on the left
28282862
if (!LHS) {
28292863
std::swap(OrOp0, OrOp1);
2830-
LHS = matchRISCVBitmanipPat(OrOp0);
2864+
LHS = matchGREVIPat(OrOp0);
28312865
}
2832-
auto RHS = matchRISCVBitmanipPat(Op1);
2866+
auto RHS = matchGREVIPat(Op1);
28332867
if (LHS && RHS && LHS->formsPairWith(*RHS) && LHS->Op == OrOp1) {
28342868
return DAG.getNode(
28352869
RISCVISD::GORCI, DL, VT, LHS->Op,
@@ -2839,6 +2873,102 @@ static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
28392873
return SDValue();
28402874
}
28412875

2876+
// Matches any of the following bit-manipulation patterns:
2877+
// (and (shl x, 1), (0x22222222 << 1))
2878+
// (and (srl x, 1), 0x22222222)
2879+
// (shl (and x, 0x22222222), 1)
2880+
// (srl (and x, (0x22222222 << 1)), 1)
2881+
// where the shift amount and mask may vary thus:
2882+
// [1] = 0x22222222 / 0x44444444
2883+
// [2] = 0x0C0C0C0C / 0x3C3C3C3C
2884+
// [4] = 0x00F000F0 / 0x0F000F00
2885+
// [8] = 0x0000FF00 / 0x00FF0000
2886+
// [16] = 0x00000000FFFF0000 / 0x0000FFFF00000000 (for RV64)
2887+
static Optional<RISCVBitmanipPat> matchSHFLPat(SDValue Op) {
2888+
// These are the unshifted masks which we use to match bit-manipulation
2889+
// patterns. They may be shifted left in certain circumstances.
2890+
static const uint64_t BitmanipMasks[] = {
2891+
0x2222222222222222ULL, 0x0C0C0C0C0C0C0C0CULL, 0x00F000F000F000F0ULL,
2892+
0x0000FF000000FF00ULL, 0x00000000FFFF0000ULL};
2893+
2894+
return matchRISCVBitmanipPat(Op, BitmanipMasks);
2895+
}
2896+
2897+
// Match (or (or (SHFL_SHL x), (SHFL_SHR x)), (SHFL_AND x)
2898+
static SDValue combineORToSHFL(SDValue Op, SelectionDAG &DAG,
2899+
const RISCVSubtarget &Subtarget) {
2900+
assert(Subtarget.hasStdExtZbp() && "Expected Zbp extenson");
2901+
EVT VT = Op.getValueType();
2902+
2903+
if (VT != MVT::i32 && VT != Subtarget.getXLenVT())
2904+
return SDValue();
2905+
2906+
SDValue Op0 = Op.getOperand(0);
2907+
SDValue Op1 = Op.getOperand(1);
2908+
2909+
// Or is commutable so canonicalize the second OR to the LHS.
2910+
if (Op0.getOpcode() != ISD::OR)
2911+
std::swap(Op0, Op1);
2912+
if (Op0.getOpcode() != ISD::OR)
2913+
return SDValue();
2914+
2915+
// We found an inner OR, so our operands are the operands of the inner OR
2916+
// and the other operand of the outer OR.
2917+
SDValue A = Op0.getOperand(0);
2918+
SDValue B = Op0.getOperand(1);
2919+
SDValue C = Op1;
2920+
2921+
auto Match1 = matchSHFLPat(A);
2922+
auto Match2 = matchSHFLPat(B);
2923+
2924+
// If neither matched, we failed.
2925+
if (!Match1 && !Match2)
2926+
return SDValue();
2927+
2928+
// We had at least one match. if one failed, try the remaining C operand.
2929+
if (!Match1) {
2930+
std::swap(A, C);
2931+
Match1 = matchSHFLPat(A);
2932+
if (!Match1)
2933+
return SDValue();
2934+
} else if (!Match2) {
2935+
std::swap(B, C);
2936+
Match2 = matchSHFLPat(B);
2937+
if (!Match2)
2938+
return SDValue();
2939+
}
2940+
assert(Match1 && Match2);
2941+
2942+
// Make sure our matches pair up.
2943+
if (!Match1->formsPairWith(*Match2))
2944+
return SDValue();
2945+
2946+
// All that remains is to make sure C is an AND with the same input, that masks
2947+
// out the bits that are being shuffled.
2948+
if (C.getOpcode() != ISD::AND || !isa<ConstantSDNode>(C.getOperand(1)) ||
2949+
C.getOperand(0) != Match1->Op)
2950+
return SDValue();
2951+
2952+
uint64_t Mask = C.getConstantOperandVal(1);
2953+
2954+
static const uint64_t BitmanipMasks[] = {
2955+
0x9999999999999999ULL, 0xC3C3C3C3C3C3C3C3ULL, 0xF00FF00FF00FF00FULL,
2956+
0xFF0000FFFF0000FFULL, 0xFFFF00000000FFFFULL,
2957+
};
2958+
2959+
unsigned Width = Op.getValueType() == MVT::i64 ? 64 : 32;
2960+
unsigned MaskIdx = Log2_32(Match1->ShAmt);
2961+
uint64_t ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width);
2962+
2963+
if (Mask != ExpMask)
2964+
return SDValue();
2965+
2966+
SDLoc DL(Op);
2967+
return DAG.getNode(
2968+
RISCVISD::SHFLI, DL, VT, Match1->Op,
2969+
DAG.getTargetConstant(Match1->ShAmt, DL, Subtarget.getXLenVT()));
2970+
}
2971+
28422972
// Combine (GREVI (GREVI x, C2), C1) -> (GREVI x, C1^C2) when C1^C2 is
28432973
// non-zero, and to x when it is. Any repeated GREVI stage undoes itself.
28442974
// Combine (GORCI (GORCI x, C2), C1) -> (GORCI x, C1|C2). Repeated stage does
@@ -3018,6 +3148,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
30183148
return GREV;
30193149
if (auto GORC = combineORToGORC(SDValue(N, 0), DCI.DAG, Subtarget))
30203150
return GORC;
3151+
if (auto SHFL = combineORToSHFL(SDValue(N, 0), DCI.DAG, Subtarget))
3152+
return SHFL;
30213153
break;
30223154
case RISCVISD::SELECT_CC: {
30233155
// Transform
@@ -3265,6 +3397,19 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
32653397
// more precise answer could be calculated for SRAW depending on known
32663398
// bits in the shift amount.
32673399
return 33;
3400+
case RISCVISD::SHFLI: {
3401+
// There is no SHFLIW, but a i64 SHFLI with bit 4 of the control word
3402+
// cleared doesn't affect bit 31. The upper 32 bits will be shuffled, but
3403+
// will stay within the upper 32 bits. If there were more than 32 sign bits
3404+
// before there will be at least 33 sign bits after.
3405+
if (Op.getValueType() == MVT::i64 &&
3406+
(Op.getConstantOperandVal(1) & 0x10) == 0) {
3407+
unsigned Tmp = DAG.ComputeNumSignBits(Op.getOperand(0), Depth + 1);
3408+
if (Tmp > 32)
3409+
return 33;
3410+
}
3411+
break;
3412+
}
32683413
case RISCVISD::VMV_X_S:
32693414
// The number of sign bits of the scalar result is computed by obtaining the
32703415
// element type of the input vector operand, subtracting its width from the
@@ -4928,6 +5073,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
49285073
NODE_NAME_CASE(GREVIW)
49295074
NODE_NAME_CASE(GORCI)
49305075
NODE_NAME_CASE(GORCIW)
5076+
NODE_NAME_CASE(SHFLI)
49315077
NODE_NAME_CASE(VMV_V_X_VL)
49325078
NODE_NAME_CASE(VFMV_V_F_VL)
49335079
NODE_NAME_CASE(VMV_X_S)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ enum NodeType : unsigned {
8888
GREVIW,
8989
GORCI,
9090
GORCIW,
91+
SHFLI,
9192
// Vector Extension
9293
// VMV_V_X_VL matches the semantics of vmv.v.x but includes an extra operand
9394
// for the VL value to be used for the operation.

llvm/lib/Target/RISCV/RISCVInstrInfoB.td

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -671,8 +671,10 @@ def riscv_grevi : SDNode<"RISCVISD::GREVI", SDTIntBinOp, []>;
671671
def riscv_greviw : SDNode<"RISCVISD::GREVIW", SDTIntBinOp, []>;
672672
def riscv_gorci : SDNode<"RISCVISD::GORCI", SDTIntBinOp, []>;
673673
def riscv_gorciw : SDNode<"RISCVISD::GORCIW", SDTIntBinOp, []>;
674+
def riscv_shfli : SDNode<"RISCVISD::SHFLI", SDTIntBinOp, []>;
674675

675676
let Predicates = [HasStdExtZbp] in {
677+
def : Pat<(riscv_shfli GPR:$rs1, timm:$shamt), (SHFLI GPR:$rs1, timm:$shamt)>;
676678
def : Pat<(riscv_grevi GPR:$rs1, timm:$shamt), (GREVI GPR:$rs1, timm:$shamt)>;
677679
def : Pat<(riscv_gorci GPR:$rs1, timm:$shamt), (GORCI GPR:$rs1, timm:$shamt)>;
678680

@@ -789,48 +791,6 @@ let Predicates = [HasStdExtZbbOrZbp, IsRV64] in {
789791
def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXTH_RV64 GPR:$rs)>;
790792
}
791793

792-
let Predicates = [HasStdExtZbp, IsRV32] in {
793-
def : Pat<(or (or (and (shl GPR:$rs1, (i32 8)), (i32 0x00FF0000)),
794-
(and GPR:$rs1, (i32 0xFF0000FF))),
795-
(and (srl GPR:$rs1, (i32 8)), (i32 0x0000FF00))),
796-
(SHFLI GPR:$rs1, (i32 8))>;
797-
def : Pat<(or (or (and (shl GPR:$rs1, (i32 4)), (i32 0x0F000F00)),
798-
(and GPR:$rs1, (i32 0xF00FF00F))),
799-
(and (srl GPR:$rs1, (i32 4)), (i32 0x00F000F0))),
800-
(SHFLI GPR:$rs1, (i32 4))>;
801-
def : Pat<(or (or (and (shl GPR:$rs1, (i32 2)), (i32 0x30303030)),
802-
(and GPR:$rs1, (i32 0xC3C3C3C3))),
803-
(and (srl GPR:$rs1, (i32 2)), (i32 0x0C0C0C0C))),
804-
(SHFLI GPR:$rs1, (i32 2))>;
805-
def : Pat<(or (or (and (shl GPR:$rs1, (i32 1)), (i32 0x44444444)),
806-
(and GPR:$rs1, (i32 0x99999999))),
807-
(and (srl GPR:$rs1, (i32 1)), (i32 0x22222222))),
808-
(SHFLI GPR:$rs1, (i32 1))>;
809-
} // Predicates = [HasStdExtZbp, IsRV32]
810-
811-
let Predicates = [HasStdExtZbp, IsRV64] in {
812-
def : Pat<(or (or (and (shl GPR:$rs1, (i64 16)), (i64 0x0000FFFF00000000)),
813-
(and GPR:$rs1, (i64 0xFFFF00000000FFFF))),
814-
(and (srl GPR:$rs1, (i64 16)), (i64 0x00000000FFFF0000))),
815-
(SHFLI GPR:$rs1, (i64 16))>;
816-
def : Pat<(or (or (and (shl GPR:$rs1, (i64 8)), (i64 0x00FF000000FF0000)),
817-
(and GPR:$rs1, (i64 0xFF0000FFFF0000FF))),
818-
(and (srl GPR:$rs1, (i64 8)), (i64 0x0000FF000000FF00))),
819-
(SHFLI GPR:$rs1, (i64 8))>;
820-
def : Pat<(or (or (and (shl GPR:$rs1, (i64 4)), (i64 0x0F000F000F000F00)),
821-
(and GPR:$rs1, (i64 0xF00FF00FF00FF00F))),
822-
(and (srl GPR:$rs1, (i64 4)), (i64 0x00F000F000F000F0))),
823-
(SHFLI GPR:$rs1, (i64 4))>;
824-
def : Pat<(or (or (and (shl GPR:$rs1, (i64 2)), (i64 0x3030303030303030)),
825-
(and GPR:$rs1, (i64 0xC3C3C3C3C3C3C3C3))),
826-
(and (srl GPR:$rs1, (i64 2)), (i64 0x0C0C0C0C0C0C0C0C))),
827-
(SHFLI GPR:$rs1, (i64 2))>;
828-
def : Pat<(or (or (and (shl GPR:$rs1, (i64 1)), (i64 0x4444444444444444)),
829-
(and GPR:$rs1, (i64 0x9999999999999999))),
830-
(and (srl GPR:$rs1, (i64 1)), (i64 0x2222222222222222))),
831-
(SHFLI GPR:$rs1, (i64 1))>;
832-
} // Predicates = [HasStdExtZbp, IsRV64]
833-
834794
let Predicates = [HasStdExtZba] in {
835795
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), GPR:$rs2),
836796
(SH1ADD GPR:$rs1, GPR:$rs2)>;

0 commit comments

Comments (0)