Commit 3304d51

[RISCV] Add performMULCombine to perform strength-reduction
The RISC-V backend thus far does not provide strength-reduction, which has led to a long (but not complete) list of 3-instruction patterns that utilize the shift-and-add instructions from Zba and XTHeadBa for strength-reduction.

This adds the logic to perform strength-reduction through the DAG combine for ISD::MUL. Initially, we wire this up for XTheadBa only, until this has had some time to settle and get real-world test exposure.

The following strength-reduction strategies are currently supported:
- XTheadBa
  - C = (n + 1)              // th.addsl
  - C = (n + 1)k             // th.addsl, slli
  - C = (n + 1)(m + 1)       // th.addsl, th.addsl
  - C = (n + 1)(m + 1)k      // th.addsl, th.addsl, slli
  - C = ((n + 1)m + 1)       // th.addsl, th.addsl
  - C = ((n + 1)m + 1)k      // th.addsl, th.addsl, slli
- base ISA
  - C with 2 set bits        // slli, slli, add (possibly slli, th.addsl)

Even though the slli+slli+add sequence would be supported without XTheadBa, it is currently gated behind XTheadBa to avoid having to update a large number of test cases (i.e., anything that has a multiplication with a constant where only 2 bits are set) in this commit.

With the strength reduction now being performed in performMULCombine, we drop the (now redundant) patterns from RISCVInstrInfoXTHead.td.

Depends on D143029

Differential Revision: https://reviews.llvm.org/D143394
1 parent e25b30d commit 3304d51
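
To make the strategy list above concrete, here is a small illustration of the algebra these decompositions rely on. It is a sketch, not part of the commit; shAdd, mulBy200, and mulBy11 are invented helper names, with shAdd standing in for the rd = rs1 + (rs2 << n) operation that th.addsl provides.

#include <cassert>
#include <cstdint>

// Stand-in for th.addsl: add rs1 to rs2 shifted left by n (n in {1, 2, 3}).
static uint64_t shAdd(uint64_t rs1, uint64_t rs2, unsigned n) {
  return rs1 + (rs2 << n);
}

// 200 = (4 + 1)(4 + 1) * 8  ->  th.addsl, th.addsl, slli
static uint64_t mulBy200(uint64_t x) {
  uint64_t t = shAdd(x, x, 2); // t = 5 * x
  t = shAdd(t, t, 2);          // t = 25 * x
  return t << 3;               // 200 * x
}

// 11 = (4 + 1) * 2 + 1       ->  th.addsl, th.addsl
static uint64_t mulBy11(uint64_t x) {
  uint64_t t = shAdd(x, x, 2); // t = 5 * x
  return shAdd(x, t, 1);       // x + 10 * x = 11 * x
}

int main() {
  for (uint64_t x = 0; x < 1000; ++x) {
    assert(mulBy200(x) == 200 * x);
    assert(mulBy11(x) == 11 * x);
  }
  return 0;
}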

File tree

2 files changed: +131, -62 lines


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 131 additions & 1 deletion
@@ -1011,7 +1011,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setJumpIsExpensive();
 
   setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT, ISD::MUL});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -8569,6 +8569,134 @@ static SDValue combineDeMorganOfBoolean(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
 }
 
+static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  SDLoc DL(N);
+  const MVT XLenVT = Subtarget.getXLenVT();
+  const EVT VT = N->getValueType(0);
+
+  // An MUL is usually smaller than any alternative sequence for legal type.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (DAG.getMachineFunction().getFunction().hasMinSize() &&
+      TLI.isOperationLegal(ISD::MUL, VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  ConstantSDNode *ConstOp = dyn_cast<ConstantSDNode>(N1);
+  // Any optimization requires a constant RHS.
+  if (!ConstOp)
+    return SDValue();
+
+  const APInt &C = ConstOp->getAPIntValue();
+  // A multiply-by-pow2 will be reduced to a shift by the
+  // architecture-independent code.
+  if (C.isPowerOf2())
+    return SDValue();
+
+  // The below optimizations only work for non-negative constants
+  if (!C.isNonNegative())
+    return SDValue();
+
+  auto Shl = [&](SDValue Value, unsigned ShiftAmount) {
+    if (!ShiftAmount)
+      return Value;
+
+    SDValue ShiftAmountConst = DAG.getConstant(ShiftAmount, DL, XLenVT);
+    return DAG.getNode(ISD::SHL, DL, Value.getValueType(), Value,
+                       ShiftAmountConst);
+  };
+  auto Add = [&](SDValue Addend1, SDValue Addend2) {
+    return DAG.getNode(ISD::ADD, DL, Addend1.getValueType(), Addend1, Addend2);
+  };
+
+  if (Subtarget.hasVendorXTHeadBa()) {
+    // We try to simplify using shift-and-add instructions into up to
+    // 3 instructions (e.g. 2x shift-and-add and 1x shift).
+
+    auto isDivisibleByShiftedAddConst = [&](APInt C, APInt &N,
+                                            APInt &Quotient) {
+      unsigned BitWidth = C.getBitWidth();
+      for (unsigned i = 3; i >= 1; --i) {
+        APInt X(BitWidth, (1 << i) + 1);
+        APInt Remainder;
+        APInt::sdivrem(C, X, Quotient, Remainder);
+        if (Remainder == 0) {
+          N = X;
+          return true;
+        }
+      }
+      return false;
+    };
+    auto isShiftedAddConst = [&](APInt C, APInt &N) {
+      APInt Quotient;
+      return isDivisibleByShiftedAddConst(C, N, Quotient) && Quotient == 1;
+    };
+    auto isSmallShiftAmount = [](APInt C) {
+      return (C == 2) || (C == 4) || (C == 8);
+    };
+
+    auto ShiftAndAdd = [&](SDValue Value, unsigned ShiftAmount,
+                           SDValue Addend) {
+      return Add(Shl(Value, ShiftAmount), Addend);
+    };
+    auto AnyExt = [&](SDValue Value) {
+      return DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Value);
+    };
+    auto Trunc = [&](SDValue Value) {
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Value);
+    };
+
+    unsigned TrailingZeroes = C.countTrailingZeros();
+    const APInt ShiftedC = C.ashr(TrailingZeroes);
+    const APInt ShiftedCMinusOne = ShiftedC - 1;
+
+    // the below comments use the following notation:
+    //   n, m .. a shift-amount for a shift-and-add instruction
+    //           (i.e. in { 2, 4, 8 })
+    //   k .. a power-of-2 that is equivalent to shifting by
+    //        TrailingZeroes bits
+    //   i, j .. a power-of-2
+
+    APInt ShiftAmt1;
+    APInt ShiftAmt2;
+    APInt Quotient;
+
+    // C = (m + 1) * k
+    if (isShiftedAddConst(ShiftedC, ShiftAmt1)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+    // C = (m + 1) * (n + 1) * k
+    if (isDivisibleByShiftedAddConst(ShiftedC, ShiftAmt1, Quotient) &&
+        isShiftedAddConst(Quotient, ShiftAmt2)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      Result = ShiftAndAdd(Result, ShiftAmt2.logBase2(), Result);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+    // C = ((m + 1) * n + 1) * k
+    if (isDivisibleByShiftedAddConst(ShiftedCMinusOne, ShiftAmt1, ShiftAmt2) &&
+        isSmallShiftAmount(ShiftAmt2)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      Result = ShiftAndAdd(Result, Quotient.logBase2(), Op0);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+
+    // C has 2 bits set: synthesize using 2 shifts and 1 add (which may
+    // see one of the shifts merged into a shift-and-add, if feasible)
+    if (C.countPopulation() == 2) {
+      APInt HighBit(C.getBitWidth(), (1 << C.logBase2()));
+      APInt LowBit = C - HighBit;
+      return Add(Shl(N0, HighBit.logBase2()), Shl(N0, LowBit.logBase2()));
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
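
As a side note on how the constant is split, the sketch below (not from the commit; plain integers instead of APInt, invented names) mimics the isDivisibleByShiftedAddConst search plus the TrailingZeroes handling above, and prints how a few constants from the removed TableGen patterns decompose into (2^i + 1) factors and a power of two.

#include <cstdint>
#include <cstdio>

// Return a divisor of c of the form (1 << i) + 1 (9, 5, or 3, preferring the
// larger shift amount), or 0 if none divides c; *quotient receives c / divisor.
static uint64_t shiftedAddDivisor(uint64_t c, uint64_t *quotient) {
  for (unsigned i = 3; i >= 1; --i) {
    uint64_t x = (1u << i) + 1;
    if (c % x == 0) {
      *quotient = c / x;
      return x;
    }
  }
  return 0;
}

int main() {
  for (uint64_t c : {40u, 45u, 81u, 200u}) {
    // Peel off trailing zero bits first; they become the final slli ("k").
    unsigned tz = 0;
    uint64_t shifted = c;
    while ((shifted & 1) == 0) {
      shifted >>= 1;
      ++tz;
    }
    uint64_t q = 0;
    uint64_t n = shiftedAddDivisor(shifted, &q);
    if (n != 0)
      printf("%llu = %llu * %llu * (1 << %u)\n", (unsigned long long)c,
             (unsigned long long)n, (unsigned long long)q, tz);
  }
  return 0;
}

For 200, for example, this prints 200 = 5 * 5 * (1 << 3), i.e. two th.addsl and one slli.
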
@@ -10218,6 +10346,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return performADDCombine(N, DAG, Subtarget);
   case ISD::SUB:
     return performSUBCombine(N, DAG, Subtarget);
+  case ISD::MUL:
+    return performMULCombine(N, DAG, Subtarget);
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR:
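
The C.countPopulation() == 2 branch in performMULCombine above is the base-ISA fallback from the commit message. A minimal model of it follows; it is not from the commit, mulTwoBitConst is an invented helper, and GCC/Clang builtins are assumed for the bit scans.

#include <cassert>
#include <cstdint>

// Multiply by a constant with exactly two set bits using two shifts and one
// add (slli, slli, add).
static uint64_t mulTwoBitConst(uint64_t x, uint64_t c) {
  assert(__builtin_popcountll(c) == 2);
  unsigned hi = 63 - __builtin_clzll(c); // index of the high set bit
  unsigned lo = __builtin_ctzll(c);      // index of the low set bit
  return (x << hi) + (x << lo);
}

int main() {
  for (uint64_t x = 0; x < 100; ++x)
    for (uint64_t c : {6u, 12u, 36u, 40u, 72u})
      assert(mulTwoBitConst(x, c) == c * x);
  return 0;
}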

llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td

Lines changed: 0 additions & 61 deletions
@@ -161,67 +161,6 @@ def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2),
           (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>;
 def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
           (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
-
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 3)>;
-
-def : Pat<(add GPR:$r, CSImm12MulBy4:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy2XForm CSImm12MulBy4:$i)), 2)>;
-def : Pat<(add GPR:$r, CSImm12MulBy8:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy3XForm CSImm12MulBy8:$i)), 3)>;
-
-def : Pat<(mul GPR:$r, C3LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 1),
-                (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C5LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 2),
-                (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C9LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 3),
-                (TrailingZeros C9LeftShift:$i))>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 1), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
-          (SLLI (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2),
-                          (TH_ADDSL GPR:$r, GPR:$r, 2), 2), 3)>;
 } // Predicates = [HasVendorXTHeadBa]
 
 defm PseudoTHVdotVMAQA : VPseudoVMAQA_VV_VX;
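
For reference, a tiny check (not part of the diff; shAdd is an invented helper mirroring the rd = rs1 + (rs2 << n) shape of TH_ADDSL) that two of the removed expansions are plain shift-and-add algebra of the kind performMULCombine can now rederive:

#include <cassert>
#include <cstdint>

static uint64_t shAdd(uint64_t rs1, uint64_t rs2, unsigned n) {
  return rs1 + (rs2 << n);
}

int main() {
  for (uint64_t r = 0; r < 1000; ++r) {
    // (TH_ADDSL (TH_ADDSL r, r, 3), (TH_ADDSL r, r, 3), 2) computes 45 * r.
    assert(shAdd(shAdd(r, r, 3), shAdd(r, r, 3), 2) == 45 * r);
    // (TH_ADDSL (TH_ADDSL r, r, 3), (TH_ADDSL r, r, 3), 3) computes 81 * r.
    assert(shAdd(shAdd(r, r, 3), shAdd(r, r, 3), 3) == 81 * r);
  }
  return 0;
}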
