@@ -550,6 +550,9 @@ namespace {
550
550
SDValue N1);
551
551
SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
552
552
SDValue N1, SDNodeFlags Flags);
553
+ // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y)); RedOpc is the vecreduce opcode. Flags are only expected from FP callers.
+ SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
554
+ EVT VT, SDValue N0, SDValue N1,
555
+ SDNodeFlags Flags = SDNodeFlags());
553
556
554
557
SDValue visitShiftByConstant(SDNode *N);
555
558
@@ -1310,6 +1313,25 @@ SDValue DAGCombiner::reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
1310
1313
return SDValue();
1311
1314
}
1312
1315
1316
+ // Try to fold Opc(vecreduce(x), vecreduce(y)) -> vecreduce(Opc(x, y)), where RedOpc is the vecreduce opcode. Returns an empty SDValue when the fold is not performed.
1317
+ // The fold is only done when N0 and N1 are both single-use RedOpc nodes whose inputs have the same value type, Opc is legal or custom on that type, and TLI.shouldReassociateReduction accepts it.
1318
+ // Note that we only expect Flags to be passed from FP operations. For integer operations they need to be dropped.
1319
+ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1320
+ const SDLoc &DL, EVT VT, SDValue N0,
1321
+ SDValue N1, SDNodeFlags Flags) {
1322
+ if (N0.getOpcode() == RedOpc && N1.getOpcode() == RedOpc && // both operands are the same kind of reduction
1323
+ N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() && // reduction inputs x and y have the same type
1324
+ N0->hasOneUse() && N1->hasOneUse() && // neither reduction has another user
1325
+ TLI.isOperationLegalOrCustom(Opc, N0.getOperand(0).getValueType()) && // Opc must be available on the input (pre-reduction) type
1326
+ TLI.shouldReassociateReduction(RedOpc, N0.getOperand(0).getValueType())) {
1327
+ SelectionDAG::FlagInserter FlagsInserter(DAG, Flags); // created before the new nodes; presumably attaches Flags to them - TODO confirm
1328
+ return DAG.getNode(RedOpc, DL, VT, // outer node: a single reduction producing the result type VT
1329
+ DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(), // inner node: Opc applied to the two reduction inputs
1330
+ N0.getOperand(0), N1.getOperand(0)));
1331
+ }
1332
+ return SDValue(); // pattern did not match; no fold
1333
+ }
1334
+
1313
1335
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1314
1336
bool AddTo) {
1315
1337
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -2650,6 +2672,11 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
2650
2672
return Add;
2651
2673
if (SDValue Add = ReassociateAddOr(N1, N0))
2652
2674
return Add;
2675
+
2676
+ // Fold add(vecreduce(x), vecreduce(y)) -> vecreduce(add(x, y))
2677
+ if (SDValue SD =
2678
+ reassociateReduction(ISD::VECREDUCE_ADD, ISD::ADD, DL, VT, N0, N1))
2679
+ return SD;
2653
2680
}
2654
2681
// fold ((0-A) + B) -> B-A
2655
2682
if (N0.getOpcode() == ISD::SUB && isNullOrNullSplat(N0.getOperand(0)))
@@ -4351,6 +4378,11 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
4351
4378
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
4352
4379
return RMUL;
4353
4380
4381
+ // Fold mul(vecreduce(x), vecreduce(y)) -> vecreduce(mul(x, y))
4382
+ if (SDValue SD =
4383
+ reassociateReduction(ISD::VECREDUCE_MUL, ISD::MUL, DL, VT, N0, N1))
4384
+ return SD;
4385
+
4354
4386
// Simplify the operands using demanded-bits information.
4355
4387
if (SimplifyDemandedBits(SDValue(N, 0)))
4356
4388
return SDValue(N, 0);
@@ -5486,6 +5518,25 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
5486
5518
if (SDValue S = PerformUMinFpToSatCombine(N0, N1, N0, N1, ISD::SETULT, DAG))
5487
5519
return S;
5488
5520
5521
+ // Fold min/max(vecreduce(x), vecreduce(y)) -> vecreduce(min/max(x, y))
5522
+ auto ReductionOpcode = [](unsigned Opcode) {
5523
+ switch (Opcode) {
5524
+ case ISD::SMIN:
5525
+ return ISD::VECREDUCE_SMIN;
5526
+ case ISD::SMAX:
5527
+ return ISD::VECREDUCE_SMAX;
5528
+ case ISD::UMIN:
5529
+ return ISD::VECREDUCE_UMIN;
5530
+ case ISD::UMAX:
5531
+ return ISD::VECREDUCE_UMAX;
5532
+ default:
5533
+ llvm_unreachable("Unexpected opcode");
5534
+ }
5535
+ };
5536
+ if (SDValue SD = reassociateReduction(ReductionOpcode(Opcode), Opcode,
5537
+ SDLoc(N), VT, N0, N1))
5538
+ return SD;
5539
+
5489
5540
// Simplify the operands using demanded-bits information.
5490
5541
if (SimplifyDemandedBits(SDValue(N, 0)))
5491
5542
return SDValue(N, 0);
@@ -6525,6 +6576,11 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
6525
6576
if (SDValue RAND = reassociateOps(ISD::AND, SDLoc(N), N0, N1, N->getFlags()))
6526
6577
return RAND;
6527
6578
6579
+ // Fold and(vecreduce(x), vecreduce(y)) -> vecreduce(and(x, y))
6580
+ if (SDValue SD = reassociateReduction(ISD::VECREDUCE_AND, ISD::AND, SDLoc(N),
6581
+ VT, N0, N1))
6582
+ return SD;
6583
+
6528
6584
// fold (and (or x, C), D) -> D if (C & D) == D
6529
6585
auto MatchSubset = [](ConstantSDNode *LHS, ConstantSDNode *RHS) {
6530
6586
return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
@@ -7419,6 +7475,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
7419
7475
if (SDValue ROR = reassociateOps(ISD::OR, SDLoc(N), N0, N1, N->getFlags()))
7420
7476
return ROR;
7421
7477
7478
+ // Fold or(vecreduce(x), vecreduce(y)) -> vecreduce(or(x, y))
7479
+ if (SDValue SD = reassociateReduction(ISD::VECREDUCE_OR, ISD::OR, SDLoc(N),
7480
+ VT, N0, N1))
7481
+ return SD;
7482
+
7422
7483
// Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2)
7423
7484
// iff (c1 & c2) != 0 or c1/c2 are undef.
7424
7485
auto MatchIntersect = [](ConstantSDNode *C1, ConstantSDNode *C2) {
@@ -8903,6 +8964,11 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
8903
8964
if (SDValue RXOR = reassociateOps(ISD::XOR, DL, N0, N1, N->getFlags()))
8904
8965
return RXOR;
8905
8966
8967
+ // Fold xor(vecreduce(x), vecreduce(y)) -> vecreduce(xor(x, y))
8968
+ if (SDValue SD =
8969
+ reassociateReduction(ISD::VECREDUCE_XOR, ISD::XOR, DL, VT, N0, N1))
8970
+ return SD;
8971
+
8906
8972
// fold (a^b) -> (a|b) iff a and b share no bits.
8907
8973
if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
8908
8974
DAG.haveNoCommonBitsSet(N0, N1))
@@ -15621,6 +15687,11 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
15621
15687
DAG.getConstantFP(4.0, DL, VT));
15622
15688
}
15623
15689
}
15690
+
15691
+ // Fold fadd(vecreduce(x), vecreduce(y)) -> vecreduce(fadd(x, y))
15692
+ if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FADD, ISD::FADD, DL,
15693
+ VT, N0, N1, Flags))
15694
+ return SD;
15624
15695
} // enable-unsafe-fp-math
15625
15696
15626
15697
// FADD -> FMA combines:
@@ -15795,6 +15866,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
15795
15866
SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
15796
15867
return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
15797
15868
}
15869
+
15870
+ // Fold fmul(vecreduce(x), vecreduce(y)) -> vecreduce(fmul(x, y))
15871
+ if (SDValue SD = reassociateReduction(ISD::VECREDUCE_FMUL, ISD::FMUL, DL,
15872
+ VT, N0, N1, Flags))
15873
+ return SD;
15798
15874
}
15799
15875
15800
15876
// fold (fmul X, 2.0) -> (fadd X, X)
@@ -16845,6 +16921,14 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
16845
16921
}
16846
16922
}
16847
16923
16924
+ const TargetOptions &Options = DAG.getTarget().Options;
16925
+ if ((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
16926
+ (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
16927
+ if (SDValue SD = reassociateReduction(IsMin ? ISD::VECREDUCE_FMIN
16928
+ : ISD::VECREDUCE_FMAX,
16929
+ Opc, SDLoc(N), VT, N0, N1, Flags))
16930
+ return SD;
16931
+
16848
16932
return SDValue();
16849
16933
}
16850
16934
0 commit comments