Skip to content

Commit 1af3f59

Browse files
committed
[DAG] Fold Op(vecreduce(a), vecreduce(b)) into vecreduce(Op(a,b))
So long as the operation is reassociative, we can reassociate a pair of vecreduce nodes, for example from fadd(vecreduce(a), vecreduce(b)) to vecreduce(fadd(a,b)). This will in general save a few instructions, but some architectures (MVE) require the opposite fold, so a shouldReassociateReduction target hook is added to account for it. Only targets that use shouldExpandReduction will be affected. Differential Revision: https://reviews.llvm.org/D141870
1 parent 665ee0c commit 1af3f59

File tree

11 files changed

+321
-419
lines changed

11 files changed

+321
-419
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,12 @@ class TargetLoweringBase {
444444
return true;
445445
}
446446

447+
// Return true if op(vecreduce(x), vecreduce(y)) should be reassociated to
448+
// vecreduce(op(x, y)) for the reduction opcode RedOpc.
// VT is the value type of the reductions' input operand (what the DAG
// combiner passes in). The default implementation always allows the fold;
// targets override this to opt out for particular reductions.
449+
virtual bool shouldReassociateReduction(unsigned RedOpc, EVT VT) const {
450+
return true;
451+
}
452+
447453
/// Return true if it is profitable to convert a select of FP constants into
448454
/// a constant pool load whose address depends on the select condition. The
449455
/// parameter may be used to differentiate a select with FP compare from

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ namespace {
550550
SDValue N1);
551551
SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
552552
SDValue N1, SDNodeFlags Flags);
553+
// Try to fold Opc(vecreduce(x), vecreduce(y)) into vecreduce(Opc(x, y)).
// RedOpc is the reduction opcode both N0 and N1 must have; Opc is the binary
// opcode applied to the reductions' inputs. Flags defaults to empty flags.
SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
554+
EVT VT, SDValue N0, SDValue N1,
555+
SDNodeFlags Flags = SDNodeFlags());
553556

554557
SDValue visitShiftByConstant(SDNode *N);
555558

@@ -1310,6 +1313,25 @@ SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
13101313
return SDValue();
13111314
}
13121315

1316+
// Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y))
// RedOpc is the reduction opcode both N0 and N1 must have, Opc is the binary
// opcode applied to the two reduction inputs, and VT is the result type of the
// new reduction node. Returns an empty SDValue when the fold does not apply.
1317+
// Note that we only expect Flags to be passed from FP operations. For integer
1318+
// operations they need to be dropped.
1319+
SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1320+
const SDLoc &DL, EVT VT, SDValue N0,
1321+
SDValue N1, SDNodeFlags Flags) {
1322+
// The fold applies only when both operands are RedOpc reductions whose inputs
// have the same type, neither reduction has another user, Opc is legal or
// custom on that input type, and the target hook allows the reassociation.
if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc &&
1323+
N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
1324+
N0->hasOneUse() && N1->hasOneUse() &&
1325+
TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) &&
1326+
TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1327+
// NOTE(review): FlagInserter presumably attaches Flags to the nodes created
// while it is in scope — confirm against SelectionDAG.
SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
1328+
// Apply Opc to the two reduction inputs first, then reduce the result once.
return DAG.getNode(RedOpc, DL, VT,
1329+
DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
1330+
N0.getOperand(0), N1.getOperand(0)));
1331+
}
1332+
// No fold was performed.
return SDValue();
1333+
}
1334+
13131335
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
13141336
bool AddTo) {
13151337
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -2650,6 +2672,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
26502672
return Add;
26512673
if (SDValue Add = ReassociateAddOr(N1, N0))
26522674
return Add;
2675+
2676+
// Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2677+
if (SDValue SD =
2678+
reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
2679+
return SD;
26532680
}
26542681
// fold ((0-A) + B) -> B-A
26552682
if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
@@ -4351,6 +4378,11 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
43514378
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
43524379
return RMUL;
43534380

4381+
// Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4382+
if (SDValue SD =
4383+
reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4384+
return SD;
4385+
43544386
// Simplify the operands using demanded-bits information.
43554387
if (SimplifyDemandedBits(SDValue(N, 0)))
43564388
return SDValue(N, 0);
@@ -5486,6 +5518,25 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
54865518
if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
54875519
return S;
54885520

5521+
// Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5522+
auto ReductionOpcode = [](unsigned Opcode) {
5523+
switch (Opcode) {
5524+
case ISD::SMIN:
5525+
return ISD::VECREDUCE_SMIN;
5526+
case ISD::SMAX:
5527+
return ISD::VECREDUCE_SMAX;
5528+
case ISD::UMIN:
5529+
return ISD::VECREDUCE_UMIN;
5530+
case ISD::UMAX:
5531+
return ISD::VECREDUCE_UMAX;
5532+
default:
5533+
llvm_unreachable("Unexpected opcode");
5534+
}
5535+
};
5536+
if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
5537+
SDLoc(N), VT, N0, N1))
5538+
return SD;
5539+
54895540
// Simplify the operands using demanded-bits information.
54905541
if (SimplifyDemandedBits(SDValue(N, 0)))
54915542
return SDValue(N, 0);
@@ -6525,6 +6576,11 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
65256576
if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
65266577
return RAND;
65276578

6579+
// Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
6580+
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, SDLoc(N),
6581+
VT, N0, N1))
6582+
return SD;
6583+
65286584
// fold (and (or x, C), D) -> D if (C & D) == D
65296585
auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
65306586
return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
@@ -7419,6 +7475,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
74197475
if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
74207476
return ROR;
74217477

7478+
// Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7479+
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, SDLoc(N),
7480+
VT, N0, N1))
7481+
return SD;
7482+
74227483
// Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
74237484
// iff (c1 & c2) != 0 or c1/c2 are undef.
74247485
auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
@@ -8903,6 +8964,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
89038964
if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
89048965
return RXOR;
89058966

8967+
// Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
8968+
if (SDValue SD =
8969+
reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
8970+
return SD;
8971+
89068972
// fold (a^b) -> (a|b) iff a and b share no bits.
89078973
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
89088974
DAG.haveNoCommonBitsSet(N0, N1))
@@ -15621,6 +15687,11 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
1562115687
DAG.getConstantFP(4.0, DL, VT));
1562215688
}
1562315689
}
15690+
15691+
// Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
15692+
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
15693+
VT, N0, N1, Flags))
15694+
return SD;
1562415695
} // enable-unsafe-fp-math
1562515696

1562615697
// FADD -> FMA combines:
@@ -15795,6 +15866,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
1579515866
SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
1579615867
return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
1579715868
}
15869+
15870+
// Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
15871+
if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
15872+
VT, N0, N1, Flags))
15873+
return SD;
1579815874
}
1579915875

1580015876
// fold (fmul X, 2.0) -> (fadd X, X)
@@ -16845,6 +16921,14 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
1684516921
}
1684616922
}
1684716923

16924+
const TargetOptions &Options = DAG.getTarget().Options;
16925+
if ((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16926+
(Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
16927+
if (SDValue SD = reassociateReduction(IsMin ? ISD::VECREDUCE_FMIN
16928+
: ISD::VECREDUCE_FMAX,
16929+
Opc, SDLoc(N), VT, N0, N1, Flags))
16930+
return SD;
16931+
1684816932
return SDValue();
1684916933
}
1685016934

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ class VectorType;
617617
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
618618
}
619619

620+
// Allow op(vecreduce(x), vecreduce(y)) -> vecreduce(op(x, y)) for every
// reduction opcode except VECREDUCE_ADD; VT is not consulted.
// NOTE(review): presumably add reductions are excluded because MVE prefers
// the un-reassociated form — confirm.
bool shouldReassociateReduction(unsigned Opc, EVT VT) const override {
621+
return Opc != ISD::VECREDUCE_ADD;
622+
}
623+
620624
/// Returns true if an argument of type Ty needs to be passed in a
621625
/// contiguous block of registers in calling convention CallConv.
622626
bool functionArgumentNeedsConsecutiveRegisters(

llvm/test/CodeGen/AArch64/aarch64-addv.ll

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,9 @@ define i32 @oversized_ADDV_512(ptr %arr) {
102102
define i8 @addv_combine_i8(<8 x i8> %a1, <8 x i8> %a2) {
103103
; CHECK-LABEL: addv_combine_i8:
104104
; CHECK: // %bb.0: // %entry
105+
; CHECK-NEXT: add v0.8b, v0.8b, v1.8b
105106
; CHECK-NEXT: addv b0, v0.8b
106-
; CHECK-NEXT: addv b1, v1.8b
107-
; CHECK-NEXT: fmov w8, s0
108-
; CHECK-NEXT: fmov w9, s1
109-
; CHECK-NEXT: add w0, w8, w9
107+
; CHECK-NEXT: fmov w0, s0
110108
; CHECK-NEXT: ret
111109
entry:
112110
%rdx.1 = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %a1)
@@ -118,11 +116,9 @@ entry:
118116
define i16 @addv_combine_i16(<4 x i16> %a1, <4 x i16> %a2) {
119117
; CHECK-LABEL: addv_combine_i16:
120118
; CHECK: // %bb.0: // %entry
119+
; CHECK-NEXT: add v0.4h, v0.4h, v1.4h
121120
; CHECK-NEXT: addv h0, v0.4h
122-
; CHECK-NEXT: addv h1, v1.4h
123-
; CHECK-NEXT: fmov w8, s0
124-
; CHECK-NEXT: fmov w9, s1
125-
; CHECK-NEXT: add w0, w8, w9
121+
; CHECK-NEXT: fmov w0, s0
126122
; CHECK-NEXT: ret
127123
entry:
128124
%rdx.1 = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %a1)

0 commit comments

Comments
 (0)