@@ -13706,8 +13706,8 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
13706
13706
if (VT != Subtarget.getXLenVT())
13707
13707
return SDValue();
13708
13708
13709
- if (!Subtarget.hasStdExtZba() && !Subtarget.hasVendorXTHeadBa())
13710
- return SDValue ();
13709
+ const bool HasShlAdd =
13710
+ Subtarget.hasStdExtZba() || Subtarget.hasVendorXTHeadBa ();
13711
13711
13712
13712
ConstantSDNode *CNode = dyn_cast<ConstantSDNode>(N->getOperand(1));
13713
13713
if (!CNode)
@@ -13720,107 +13720,123 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
13720
13720
// other target properly freezes X in these cases either.
13721
13721
SDValue X = N->getOperand(0);
13722
13722
13723
- for (uint64_t Divisor : {3, 5, 9}) {
13724
- if (MulAmt % Divisor != 0)
13725
- continue;
13726
- uint64_t MulAmt2 = MulAmt / Divisor;
13727
- // 3/5/9 * 2^N -> shl (shXadd X, X), N
13728
- if (isPowerOf2_64(MulAmt2)) {
13729
- SDLoc DL(N);
13730
- SDValue X = N->getOperand(0);
13731
- // Put the shift first if we can fold a zext into the
13732
- // shift forming a slli.uw.
13733
- if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13734
- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13735
- SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13736
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13737
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13738
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), Shl);
13723
+ if (HasShlAdd) {
13724
+ for (uint64_t Divisor : {3, 5, 9}) {
13725
+ if (MulAmt % Divisor != 0)
13726
+ continue;
13727
+ uint64_t MulAmt2 = MulAmt / Divisor;
13728
+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
13729
+ if (isPowerOf2_64(MulAmt2)) {
13730
+ SDLoc DL(N);
13731
+ SDValue X = N->getOperand(0);
13732
+ // Put the shift first if we can fold a zext into the
13733
+ // shift forming a slli.uw.
13734
+ if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13735
+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13736
+ SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13737
+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13738
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13739
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
13740
+ Shl);
13741
+ }
13742
+ // Otherwise, put rhe shl second so that it can fold with following
13743
+ // instructions (e.g. sext or add).
13744
+ SDValue Mul359 =
13745
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13746
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13747
+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13748
+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13749
+ }
13750
+
13751
+ // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13752
+ if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13753
+ SDLoc DL(N);
13754
+ SDValue Mul359 =
13755
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13756
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13757
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13758
+ DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13759
+ Mul359);
13739
13760
}
13740
- // Otherwise, put rhe shl second so that it can fold with following
13741
- // instructions (e.g. sext or add).
13742
- SDValue Mul359 =
13743
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13744
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13745
- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13746
- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13747
13761
}
13748
13762
13749
- // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13750
- if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13751
- SDLoc DL(N);
13752
- SDValue Mul359 =
13753
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13754
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13755
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13756
- DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13757
- Mul359);
13758
- }
13759
- }
13760
-
13761
- // If this is a power 2 + 2/4/8, we can use a shift followed by a single
13762
- // shXadd. First check if this a sum of two power of 2s because that's
13763
- // easy. Then count how many zeros are up to the first bit.
13764
- if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13765
- unsigned ScaleShift = llvm::countr_zero(MulAmt);
13766
- if (ScaleShift >= 1 && ScaleShift < 4) {
13767
- unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13768
- SDLoc DL(N);
13769
- SDValue Shift1 =
13770
- DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13771
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13772
- DAG.getConstant(ScaleShift, DL, VT), Shift1);
13763
+ // If this is a power 2 + 2/4/8, we can use a shift followed by a single
13764
+ // shXadd. First check if this a sum of two power of 2s because that's
13765
+ // easy. Then count how many zeros are up to the first bit.
13766
+ if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13767
+ unsigned ScaleShift = llvm::countr_zero(MulAmt);
13768
+ if (ScaleShift >= 1 && ScaleShift < 4) {
13769
+ unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13770
+ SDLoc DL(N);
13771
+ SDValue Shift1 =
13772
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13773
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13774
+ DAG.getConstant(ScaleShift, DL, VT), Shift1);
13775
+ }
13773
13776
}
13774
- }
13775
13777
13776
- // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13777
- // This is the two instruction form, there are also three instruction
13778
- // variants we could implement. e.g.
13779
- // (2^(1,2,3) * 3,5,9 + 1) << C2
13780
- // 2^(C1>3) * 3,5,9 +/- 1
13781
- for (uint64_t Divisor : {3, 5, 9}) {
13782
- uint64_t C = MulAmt - 1;
13783
- if (C <= Divisor)
13784
- continue;
13785
- unsigned TZ = llvm::countr_zero(C);
13786
- if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13787
- SDLoc DL(N);
13788
- SDValue Mul359 =
13789
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13790
- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13791
- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13792
- DAG.getConstant(TZ, DL, VT), X);
13778
+ // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13779
+ // This is the two instruction form, there are also three instruction
13780
+ // variants we could implement. e.g.
13781
+ // (2^(1,2,3) * 3,5,9 + 1) << C2
13782
+ // 2^(C1>3) * 3,5,9 +/- 1
13783
+ for (uint64_t Divisor : {3, 5, 9}) {
13784
+ uint64_t C = MulAmt - 1;
13785
+ if (C <= Divisor)
13786
+ continue;
13787
+ unsigned TZ = llvm::countr_zero(C);
13788
+ if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13789
+ SDLoc DL(N);
13790
+ SDValue Mul359 =
13791
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13792
+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13793
+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13794
+ DAG.getConstant(TZ, DL, VT), X);
13795
+ }
13793
13796
}
13794
- }
13795
13797
13796
- // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13797
- if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13798
- unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13799
- if (ScaleShift >= 1 && ScaleShift < 4) {
13800
- unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13801
- SDLoc DL(N);
13802
- SDValue Shift1 =
13803
- DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13804
- return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13805
- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13806
- DAG.getConstant(ScaleShift, DL, VT), X));
13798
+ // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13799
+ if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13800
+ unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13801
+ if (ScaleShift >= 1 && ScaleShift < 4) {
13802
+ unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13803
+ SDLoc DL(N);
13804
+ SDValue Shift1 =
13805
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13806
+ return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13807
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13808
+ DAG.getConstant(ScaleShift, DL, VT), X));
13809
+ }
13807
13810
}
13808
- }
13809
13811
13810
- // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13811
- for (uint64_t Offset : {3, 5, 9}) {
13812
- if (isPowerOf2_64(MulAmt + Offset)) {
13813
- SDLoc DL(N);
13814
- SDValue Shift1 =
13815
- DAG.getNode(ISD::SHL, DL, VT, X,
13816
- DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13817
- SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13818
- DAG.getConstant(Log2_64(Offset - 1), DL, VT),
13819
- X);
13820
- return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13812
+ // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13813
+ for (uint64_t Offset : {3, 5, 9}) {
13814
+ if (isPowerOf2_64(MulAmt + Offset)) {
13815
+ SDLoc DL(N);
13816
+ SDValue Shift1 =
13817
+ DAG.getNode(ISD::SHL, DL, VT, X,
13818
+ DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13819
+ SDValue Mul359 =
13820
+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13821
+ DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
13822
+ return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13823
+ }
13821
13824
}
13822
13825
}
13823
13826
13827
+ // 2^N - 2^M -> (sub (shl X, C1), (shl X, C2))
13828
+ uint64_t MulAmtLowBit = MulAmt & (-MulAmt);
13829
+ if (isPowerOf2_64(MulAmt + MulAmtLowBit)) {
13830
+ uint64_t ShiftAmt1 = MulAmt + MulAmtLowBit;
13831
+ SDLoc DL(N);
13832
+ SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13833
+ DAG.getConstant(Log2_64(ShiftAmt1), DL, VT));
13834
+ SDValue Shift2 =
13835
+ DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13836
+ DAG.getConstant(Log2_64(MulAmtLowBit), DL, VT));
13837
+ return DAG.getNode(ISD::SUB, DL, VT, Shift1, Shift2);
13838
+ }
13839
+
13824
13840
return SDValue();
13825
13841
}
13826
13842
0 commit comments