@@ -2545,6 +2545,20 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
2545
2545
Results.push_back (DAG.getNode (ISD::TRUNCATE, DL, MVT::i32 , NewRes));
2546
2546
break ;
2547
2547
}
2548
+ case RISCVISD::SHFLI: {
2549
+ // There is no SHFLIW instruction, but we can just promote the operation.
2550
+ assert (N->getValueType (0 ) == MVT::i32 && Subtarget.is64Bit () &&
2551
+ " Unexpected custom legalisation" );
2552
+ SDLoc DL (N);
2553
+ SDValue NewOp0 =
2554
+ DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i64 , N->getOperand (0 ));
2555
+ SDValue NewRes =
2556
+ DAG.getNode (RISCVISD::SHFLI, DL, MVT::i64 , NewOp0, N->getOperand (1 ));
2557
+ // ReplaceNodeResults requires we maintain the same type for the return
2558
+ // value.
2559
+ Results.push_back (DAG.getNode (ISD::TRUNCATE, DL, MVT::i32 , NewRes));
2560
+ break ;
2561
+ }
2548
2562
case ISD::BSWAP:
2549
2563
case ISD::BITREVERSE: {
2550
2564
assert (N->getValueType (0 ) == MVT::i32 && Subtarget.is64Bit () &&
@@ -2674,19 +2688,21 @@ struct RISCVBitmanipPat {
2674
2688
}
2675
2689
};
2676
2690
2677
- // Matches any of the following bit-manipulation patterns:
2678
- // (and (shl x, 1), (0x55555555 << 1))
2679
- // (and (srl x, 1), 0x55555555)
2680
- // (shl (and x, 0x55555555), 1)
2681
- // (srl (and x, (0x55555555 << 1)), 1)
2682
- // where the shift amount and mask may vary thus:
2683
- // [1] = 0x55555555 / 0xAAAAAAAA
2684
- // [2] = 0x33333333 / 0xCCCCCCCC
2685
- // [4] = 0x0F0F0F0F / 0xF0F0F0F0
2686
- // [8] = 0x00FF00FF / 0xFF00FF00
2687
- // [16] = 0x0000FFFF / 0xFFFFFFFF
2688
- // [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
2689
- static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat (SDValue Op) {
2691
+ // Matches patterns of the form
2692
+ // (and (shl x, C2), (C1 << C2))
2693
+ // (and (srl x, C2), C1)
2694
+ // (shl (and x, C1), C2)
2695
+ // (srl (and x, (C1 << C2)), C2)
2696
+ // Where C2 is a power of 2 and C1 has at least that many leading zeroes.
2697
+ // The expected masks for each shift amount are specified in BitmanipMasks where
2698
+ // BitmanipMasks[log2(C2)] specifies the expected C1 value.
2699
+ // The max allowed shift amount is either XLen/2 or XLen/4 determined by whether
2700
+ // BitmanipMasks contains 6 or 5 entries assuming that the maximum possible
2701
+ // XLen is 64.
2702
+ static Optional<RISCVBitmanipPat>
2703
+ matchRISCVBitmanipPat (SDValue Op, ArrayRef<uint64_t > BitmanipMasks) {
2704
+ assert ((BitmanipMasks.size () == 5 || BitmanipMasks.size () == 6 ) &&
2705
+ " Unexpected number of masks" );
2690
2706
Optional<uint64_t > Mask;
2691
2707
// Optionally consume a mask around the shift operation.
2692
2708
if (Op.getOpcode () == ISD::AND && isa<ConstantSDNode>(Op.getOperand (1 ))) {
@@ -2699,26 +2715,17 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
2699
2715
2700
2716
if (!isa<ConstantSDNode>(Op.getOperand (1 )))
2701
2717
return None;
2702
- auto ShAmt = Op.getConstantOperandVal (1 );
2718
+ uint64_t ShAmt = Op.getConstantOperandVal (1 );
2703
2719
2704
- if (!isPowerOf2_64 (ShAmt))
2720
+ unsigned Width = Op.getValueType () == MVT::i64 ? 64 : 32 ;
2721
+ if (ShAmt >= Width && !isPowerOf2_64 (ShAmt))
2705
2722
return None;
2706
-
2707
- // These are the unshifted masks which we use to match bit-manipulation
2708
- // patterns. They may be shifted left in certain circumstances.
2709
- static const uint64_t BitmanipMasks[] = {
2710
- 0x5555555555555555ULL , 0x3333333333333333ULL , 0x0F0F0F0F0F0F0F0FULL ,
2711
- 0x00FF00FF00FF00FFULL , 0x0000FFFF0000FFFFULL , 0x00000000FFFFFFFFULL ,
2712
- };
2713
-
2714
- unsigned MaskIdx = Log2_64 (ShAmt);
2715
- if (MaskIdx >= array_lengthof (BitmanipMasks))
2723
+ // If we don't have enough masks for 64 bit, then we must be trying to
2724
+ // match SHFL so we're only allowed to shift 1/4 of the width.
2725
+ if (BitmanipMasks.size () == 5 && ShAmt >= (Width / 2 ))
2716
2726
return None;
2717
2727
2718
- auto Src = Op.getOperand (0 );
2719
-
2720
- unsigned Width = Op.getValueType () == MVT::i64 ? 64 : 32 ;
2721
- auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t >(Width);
2728
+ SDValue Src = Op.getOperand (0 );
2722
2729
2723
2730
// The expected mask is shifted left when the AND is found around SHL
2724
2731
// patterns.
@@ -2745,6 +2752,9 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
2745
2752
}
2746
2753
}
2747
2754
2755
+ unsigned MaskIdx = Log2_32 (ShAmt);
2756
+ uint64_t ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t >(Width);
2757
+
2748
2758
if (SHLExpMask)
2749
2759
ExpMask <<= ShAmt;
2750
2760
@@ -2754,15 +2764,38 @@ static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op) {
2754
2764
return RISCVBitmanipPat{Src, (unsigned )ShAmt, IsSHL};
2755
2765
}
2756
2766
2767
+ // Matches any of the following bit-manipulation patterns:
2768
+ // (and (shl x, 1), (0x55555555 << 1))
2769
+ // (and (srl x, 1), 0x55555555)
2770
+ // (shl (and x, 0x55555555), 1)
2771
+ // (srl (and x, (0x55555555 << 1)), 1)
2772
+ // where the shift amount and mask may vary thus:
2773
+ // [1] = 0x55555555 / 0xAAAAAAAA
2774
+ // [2] = 0x33333333 / 0xCCCCCCCC
2775
+ // [4] = 0x0F0F0F0F / 0xF0F0F0F0
2776
+ // [8] = 0x00FF00FF / 0xFF00FF00
2777
+ // [16] = 0x0000FFFF / 0xFFFFFFFF
2778
+ // [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64)
2779
+ static Optional<RISCVBitmanipPat> matchGREVIPat (SDValue Op) {
2780
+ // These are the unshifted masks which we use to match bit-manipulation
2781
+ // patterns. They may be shifted left in certain circumstances.
2782
+ static const uint64_t BitmanipMasks[] = {
2783
+ 0x5555555555555555ULL , 0x3333333333333333ULL , 0x0F0F0F0F0F0F0F0FULL ,
2784
+ 0x00FF00FF00FF00FFULL , 0x0000FFFF0000FFFFULL , 0x00000000FFFFFFFFULL };
2785
+
2786
+ return matchRISCVBitmanipPat (Op, BitmanipMasks);
2787
+ }
2788
+
2757
2789
// Match the following pattern as a GREVI(W) operation
2758
2790
// (or (BITMANIP_SHL x), (BITMANIP_SRL x))
2759
2791
static SDValue combineORToGREV (SDValue Op, SelectionDAG &DAG,
2760
2792
const RISCVSubtarget &Subtarget) {
2793
+ assert (Subtarget.hasStdExtZbp () && " Expected Zbp extenson" );
2761
2794
EVT VT = Op.getValueType ();
2762
2795
2763
2796
if (VT == Subtarget.getXLenVT () || (Subtarget.is64Bit () && VT == MVT::i32 )) {
2764
- auto LHS = matchRISCVBitmanipPat (Op.getOperand (0 ));
2765
- auto RHS = matchRISCVBitmanipPat (Op.getOperand (1 ));
2797
+ auto LHS = matchGREVIPat (Op.getOperand (0 ));
2798
+ auto RHS = matchGREVIPat (Op.getOperand (1 ));
2766
2799
if (LHS && RHS && LHS->formsPairWith (*RHS)) {
2767
2800
SDLoc DL (Op);
2768
2801
return DAG.getNode (
@@ -2784,6 +2817,7 @@ static SDValue combineORToGREV(SDValue Op, SelectionDAG &DAG,
2784
2817
// 4. (or (rotl/rotr x, bitwidth/2), x)
2785
2818
static SDValue combineORToGORC (SDValue Op, SelectionDAG &DAG,
2786
2819
const RISCVSubtarget &Subtarget) {
2820
+ assert (Subtarget.hasStdExtZbp () && " Expected Zbp extenson" );
2787
2821
EVT VT = Op.getValueType ();
2788
2822
2789
2823
if (VT == Subtarget.getXLenVT () || (Subtarget.is64Bit () && VT == MVT::i32 )) {
@@ -2822,14 +2856,14 @@ static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
2822
2856
return SDValue ();
2823
2857
SDValue OrOp0 = Op0.getOperand (0 );
2824
2858
SDValue OrOp1 = Op0.getOperand (1 );
2825
- auto LHS = matchRISCVBitmanipPat (OrOp0);
2859
+ auto LHS = matchGREVIPat (OrOp0);
2826
2860
// OR is commutable so swap the operands and try again: x might have been
2827
2861
// on the left
2828
2862
if (!LHS) {
2829
2863
std::swap (OrOp0, OrOp1);
2830
- LHS = matchRISCVBitmanipPat (OrOp0);
2864
+ LHS = matchGREVIPat (OrOp0);
2831
2865
}
2832
- auto RHS = matchRISCVBitmanipPat (Op1);
2866
+ auto RHS = matchGREVIPat (Op1);
2833
2867
if (LHS && RHS && LHS->formsPairWith (*RHS) && LHS->Op == OrOp1) {
2834
2868
return DAG.getNode (
2835
2869
RISCVISD::GORCI, DL, VT, LHS->Op ,
@@ -2839,6 +2873,102 @@ static SDValue combineORToGORC(SDValue Op, SelectionDAG &DAG,
2839
2873
return SDValue ();
2840
2874
}
2841
2875
2876
+ // Matches any of the following bit-manipulation patterns:
2877
+ // (and (shl x, 1), (0x22222222 << 1))
2878
+ // (and (srl x, 1), 0x22222222)
2879
+ // (shl (and x, 0x22222222), 1)
2880
+ // (srl (and x, (0x22222222 << 1)), 1)
2881
+ // where the shift amount and mask may vary thus:
2882
+ // [1] = 0x22222222 / 0x44444444
2883
+ // [2] = 0x0C0C0C0C / 0x3C3C3C3C
2884
+ // [4] = 0x00F000F0 / 0x0F000F00
2885
+ // [8] = 0x0000FF00 / 0x00FF0000
2886
+ // [16] = 0x00000000FFFF0000 / 0x0000FFFF00000000 (for RV64)
2887
+ static Optional<RISCVBitmanipPat> matchSHFLPat (SDValue Op) {
2888
+ // These are the unshifted masks which we use to match bit-manipulation
2889
+ // patterns. They may be shifted left in certain circumstances.
2890
+ static const uint64_t BitmanipMasks[] = {
2891
+ 0x2222222222222222ULL , 0x0C0C0C0C0C0C0C0CULL , 0x00F000F000F000F0ULL ,
2892
+ 0x0000FF000000FF00ULL , 0x00000000FFFF0000ULL };
2893
+
2894
+ return matchRISCVBitmanipPat (Op, BitmanipMasks);
2895
+ }
2896
+
2897
+ // Match (or (or (SHFL_SHL x), (SHFL_SHR x)), (SHFL_AND x)
2898
+ static SDValue combineORToSHFL (SDValue Op, SelectionDAG &DAG,
2899
+ const RISCVSubtarget &Subtarget) {
2900
+ assert (Subtarget.hasStdExtZbp () && " Expected Zbp extenson" );
2901
+ EVT VT = Op.getValueType ();
2902
+
2903
+ if (VT != MVT::i32 && VT != Subtarget.getXLenVT ())
2904
+ return SDValue ();
2905
+
2906
+ SDValue Op0 = Op.getOperand (0 );
2907
+ SDValue Op1 = Op.getOperand (1 );
2908
+
2909
+ // Or is commutable so canonicalize the second OR to the LHS.
2910
+ if (Op0.getOpcode () != ISD::OR)
2911
+ std::swap (Op0, Op1);
2912
+ if (Op0.getOpcode () != ISD::OR)
2913
+ return SDValue ();
2914
+
2915
+ // We found an inner OR, so our operands are the operands of the inner OR
2916
+ // and the other operand of the outer OR.
2917
+ SDValue A = Op0.getOperand (0 );
2918
+ SDValue B = Op0.getOperand (1 );
2919
+ SDValue C = Op1;
2920
+
2921
+ auto Match1 = matchSHFLPat (A);
2922
+ auto Match2 = matchSHFLPat (B);
2923
+
2924
+ // If neither matched, we failed.
2925
+ if (!Match1 && !Match2)
2926
+ return SDValue ();
2927
+
2928
+ // We had at least one match. if one failed, try the remaining C operand.
2929
+ if (!Match1) {
2930
+ std::swap (A, C);
2931
+ Match1 = matchSHFLPat (A);
2932
+ if (!Match1)
2933
+ return SDValue ();
2934
+ } else if (!Match2) {
2935
+ std::swap (B, C);
2936
+ Match2 = matchSHFLPat (B);
2937
+ if (!Match2)
2938
+ return SDValue ();
2939
+ }
2940
+ assert (Match1 && Match2);
2941
+
2942
+ // Make sure our matches pair up.
2943
+ if (!Match1->formsPairWith (*Match2))
2944
+ return SDValue ();
2945
+
2946
+ // All the remains is to make sure C is an AND with the same input, that masks
2947
+ // out the bits that are being shuffled.
2948
+ if (C.getOpcode () != ISD::AND || !isa<ConstantSDNode>(C.getOperand (1 )) ||
2949
+ C.getOperand (0 ) != Match1->Op )
2950
+ return SDValue ();
2951
+
2952
+ uint64_t Mask = C.getConstantOperandVal (1 );
2953
+
2954
+ static const uint64_t BitmanipMasks[] = {
2955
+ 0x9999999999999999ULL , 0xC3C3C3C3C3C3C3C3ULL , 0xF00FF00FF00FF00FULL ,
2956
+ 0xFF0000FFFF0000FFULL , 0xFFFF00000000FFFFULL ,
2957
+ };
2958
+
2959
+ unsigned Width = Op.getValueType () == MVT::i64 ? 64 : 32 ;
2960
+ unsigned MaskIdx = Log2_32 (Match1->ShAmt );
2961
+ uint64_t ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t >(Width);
2962
+
2963
+ if (Mask != ExpMask)
2964
+ return SDValue ();
2965
+
2966
+ SDLoc DL (Op);
2967
+ return DAG.getNode (
2968
+ RISCVISD::SHFLI, DL, VT, Match1->Op ,
2969
+ DAG.getTargetConstant (Match1->ShAmt , DL, Subtarget.getXLenVT ()));
2970
+ }
2971
+
2842
2972
// Combine (GREVI (GREVI x, C2), C1) -> (GREVI x, C1^C2) when C1^C2 is
2843
2973
// non-zero, and to x when it is. Any repeated GREVI stage undoes itself.
2844
2974
// Combine (GORCI (GORCI x, C2), C1) -> (GORCI x, C1|C2). Repeated stage does
@@ -3018,6 +3148,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
3018
3148
return GREV;
3019
3149
if (auto GORC = combineORToGORC (SDValue (N, 0 ), DCI.DAG , Subtarget))
3020
3150
return GORC;
3151
+ if (auto SHFL = combineORToSHFL (SDValue (N, 0 ), DCI.DAG , Subtarget))
3152
+ return SHFL;
3021
3153
break ;
3022
3154
case RISCVISD::SELECT_CC: {
3023
3155
// Transform
@@ -3265,6 +3397,19 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
3265
3397
// more precise answer could be calculated for SRAW depending on known
3266
3398
// bits in the shift amount.
3267
3399
return 33 ;
3400
+ case RISCVISD::SHFLI: {
3401
+ // There is no SHFLIW, but a i64 SHFLI with bit 4 of the control word
3402
+ // cleared doesn't affect bit 31. The upper 32 bits will be shuffled, but
3403
+ // will stay within the upper 32 bits. If there were more than 32 sign bits
3404
+ // before there will be at least 33 sign bits after.
3405
+ if (Op.getValueType () == MVT::i64 &&
3406
+ (Op.getConstantOperandVal (1 ) & 0x10 ) == 0 ) {
3407
+ unsigned Tmp = DAG.ComputeNumSignBits (Op.getOperand (0 ), Depth + 1 );
3408
+ if (Tmp > 32 )
3409
+ return 33 ;
3410
+ }
3411
+ break ;
3412
+ }
3268
3413
case RISCVISD::VMV_X_S:
3269
3414
// The number of sign bits of the scalar result is computed by obtaining the
3270
3415
// element type of the input vector operand, subtracting its width from the
@@ -4928,6 +5073,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
4928
5073
NODE_NAME_CASE (GREVIW)
4929
5074
NODE_NAME_CASE (GORCI)
4930
5075
NODE_NAME_CASE (GORCIW)
5076
+ NODE_NAME_CASE (SHFLI)
4931
5077
NODE_NAME_CASE (VMV_V_X_VL)
4932
5078
NODE_NAME_CASE (VFMV_V_F_VL)
4933
5079
NODE_NAME_CASE (VMV_X_S)
0 commit comments