@@ -24838,6 +24838,122 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
24838
24838
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
24839
24839
}
24840
24840
24841
+ // Reassociate the true/false expressions of a CSEL instruction to obtain a
24842
+ // common subexpression with the comparison instruction. For example, change
24843
+ // (CSEL (ADD (ADD x y) -c) f LO (SUBS x c)) to
24844
+ // (CSEL (ADD (SUBS x c) y) f LO (SUBS x c)) such that (SUBS x c) is a common
24845
+ // subexpression.
24846
+ static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
24847
+ SDValue SubsNode = N->getOperand(3);
24848
+ if (SubsNode.getOpcode() != AArch64ISD::SUBS || !SubsNode.hasOneUse())
24849
+ return SDValue();
24850
+ auto *CmpOpConst = dyn_cast<ConstantSDNode>(SubsNode.getOperand(1));
24851
+ if (!CmpOpConst)
24852
+ return SDValue();
24853
+
24854
+ auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
24855
+ bool IsEquality = CC == AArch64CC::EQ || CC == AArch64CC::NE;
24856
+ if (IsEquality && !CmpOpConst->isZero())
24857
+ return SDValue();
24858
+
24859
+ SDValue CmpOpOther = SubsNode.getOperand(0);
24860
+ EVT VT = N->getValueType(0);
24861
+
24862
+ auto Reassociate = [&](SDValue Op, APInt ExpectedConst, SDValue NewCmp) {
24863
+ if (Op.getOpcode() != ISD::ADD)
24864
+ return SDValue();
24865
+ auto *AddOpConst = dyn_cast<ConstantSDNode>(Op.getOperand(1));
24866
+ if (!AddOpConst)
24867
+ return SDValue();
24868
+ if (AddOpConst->getAPIntValue() != ExpectedConst)
24869
+ return SDValue();
24870
+ if (Op.getOperand(0).getOpcode() != ISD::ADD ||
24871
+ !Op.getOperand(0).hasOneUse())
24872
+ return SDValue();
24873
+ SDValue X = Op.getOperand(0).getOperand(0);
24874
+ SDValue Y = Op.getOperand(0).getOperand(1);
24875
+ if (X != CmpOpOther)
24876
+ std::swap(X, Y);
24877
+ if (X != CmpOpOther)
24878
+ return SDValue();
24879
+ SDNodeFlags Flags;
24880
+ if (Op.getOperand(0).getNode()->getFlags().hasNoUnsignedWrap())
24881
+ Flags.setNoUnsignedWrap(true);
24882
+ return DAG.getNode(ISD::ADD, SDLoc(Op), VT, NewCmp.getValue(0), Y, Flags);
24883
+ };
24884
+
24885
+ auto Fold = [&](APInt NewCmpConst, AArch64CC::CondCode NewCC) {
24886
+ SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
24887
+ DAG.getVTList(VT, MVT_CC), CmpOpOther,
24888
+ DAG.getConstant(NewCmpConst, SDLoc(CmpOpConst),
24889
+ CmpOpConst->getValueType(0)));
24890
+
24891
+ APInt ExpectedConst = -NewCmpConst;
24892
+ SDValue TValReassoc = Reassociate(N->getOperand(0), ExpectedConst, NewCmp);
24893
+ SDValue FValReassoc = Reassociate(N->getOperand(1), ExpectedConst, NewCmp);
24894
+ if (!TValReassoc && !FValReassoc)
24895
+ return SDValue();
24896
+ if (TValReassoc)
24897
+ DAG.ReplaceAllUsesWith(N->getOperand(0), TValReassoc);
24898
+ else
24899
+ TValReassoc = N->getOperand(0);
24900
+ if (FValReassoc)
24901
+ DAG.ReplaceAllUsesWith(N->getOperand(1), FValReassoc);
24902
+ else
24903
+ FValReassoc = N->getOperand(1);
24904
+ return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
24905
+ DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
24906
+ NewCmp.getValue(1));
24907
+ };
24908
+
24909
+ // First, try to eliminate the compare instruction by searching for a
24910
+ // subtraction with the same constant.
24911
+ if (!IsEquality) // Not useful for equalities
24912
+ if (SDValue R = Fold(CmpOpConst->getAPIntValue(), CC))
24913
+ return R;
24914
+
24915
+ // Next, search for a subtraction with a slightly different constant. By
24916
+ // adjusting the condition code, we can still eliminate the compare
24917
+ // instruction. Adjusting the constant is only valid if it does not result
24918
+ // in signed/unsigned wrap for signed/unsigned comparisons, respectively.
24919
+ // Since such comparisons are trivially true/false, we should not encounter
24920
+ // them here but check for them nevertheless to be on the safe side.
24921
+ auto CheckedFold = [&](bool Check, APInt NewCmpConst,
24922
+ AArch64CC::CondCode NewCC) {
24923
+ return Check ? Fold(NewCmpConst, NewCC) : SDValue();
24924
+ };
24925
+ switch (CC) {
24926
+ case AArch64CC::EQ:
24927
+ case AArch64CC::LS:
24928
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
24929
+ CmpOpConst->getAPIntValue() + 1, AArch64CC::LO);
24930
+ case AArch64CC::NE:
24931
+ case AArch64CC::HI:
24932
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMaxValue(),
24933
+ CmpOpConst->getAPIntValue() + 1, AArch64CC::HS);
24934
+ case AArch64CC::LO:
24935
+ return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
24936
+ CmpOpConst->getAPIntValue() - 1, AArch64CC::LS);
24937
+ case AArch64CC::HS:
24938
+ return CheckedFold(!CmpOpConst->getAPIntValue().isZero(),
24939
+ CmpOpConst->getAPIntValue() - 1, AArch64CC::HI);
24940
+ case AArch64CC::LT:
24941
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
24942
+ CmpOpConst->getAPIntValue() - 1, AArch64CC::LE);
24943
+ case AArch64CC::LE:
24944
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
24945
+ CmpOpConst->getAPIntValue() + 1, AArch64CC::LT);
24946
+ case AArch64CC::GT:
24947
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMaxSignedValue(),
24948
+ CmpOpConst->getAPIntValue() + 1, AArch64CC::GE);
24949
+ case AArch64CC::GE:
24950
+ return CheckedFold(!CmpOpConst->getAPIntValue().isMinSignedValue(),
24951
+ CmpOpConst->getAPIntValue() - 1, AArch64CC::GT);
24952
+ default:
24953
+ return SDValue();
24954
+ }
24955
+ }
24956
+
24841
24957
// Optimize CSEL instructions
24842
24958
static SDValue performCSELCombine(SDNode *N,
24843
24959
TargetLowering::DAGCombinerInfo &DCI,
@@ -24849,6 +24965,11 @@ static SDValue performCSELCombine(SDNode *N,
24849
24965
if (SDValue R = foldCSELOfCSEL(N, DAG))
24850
24966
return R;
24851
24967
24968
+ // Try to reassociate the true/false expressions so that we can do CSE with
24969
+ // a SUBS instruction used to perform the comparison.
24970
+ if (SDValue R = reassociateCSELOperandsForCSE(N, DAG))
24971
+ return R;
24972
+
24852
24973
// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
24853
24974
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
24854
24975
if (SDValue Folded = foldCSELofCTTZ(N, DAG))
0 commit comments