@@ -3796,7 +3796,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
3796
3796
/// Match "(X shl/srl V1) & V2" where V2 may not be present.
3797
3797
static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
3798
3798
if (Op.getOpcode() == ISD::AND) {
3799
- if (isa<ConstantSDNode> (Op.getOperand(1))) {
3799
+ if (isConstOrConstSplat (Op.getOperand(1))) {
3800
3800
Mask = Op.getOperand(1);
3801
3801
Op = Op.getOperand(0);
3802
3802
} else {
@@ -3813,105 +3813,106 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
3813
3813
}
3814
3814
3815
3815
// Return true if we can prove that, whenever Neg and Pos are both in the
3816
- // range [0, OpSize ), Neg == (Pos == 0 ? 0 : OpSize - Pos). This means that
3816
+ // range [0, EltSize ), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
3817
3817
// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
3818
3818
//
3819
3819
// (or (shift1 X, Neg), (shift2 X, Pos))
3820
3820
//
3821
3821
// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
3822
- // in direction shift1 by Neg. The range [0, OpSize ) means that we only need
3822
+ // in direction shift1 by Neg. The range [0, EltSize ) means that we only need
3823
3823
// to consider shift amounts with defined behavior.
3824
- static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned OpSize ) {
3825
- // If OpSize is a power of 2 then:
3824
+ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize ) {
3825
+ // If EltSize is a power of 2 then:
3826
3826
//
3827
- // (a) (Pos == 0 ? 0 : OpSize - Pos) == (OpSize - Pos) & (OpSize - 1)
3828
- // (b) Neg == Neg & (OpSize - 1) whenever Neg is in [0, OpSize ).
3827
+ // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
3828
+ // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize ).
3829
3829
//
3830
- // So if OpSize is a power of 2 and Neg is (and Neg', OpSize -1), we check
3830
+ // So if EltSize is a power of 2 and Neg is (and Neg', EltSize -1), we check
3831
3831
// for the stronger condition:
3832
3832
//
3833
- // Neg & (OpSize - 1) == (OpSize - Pos) & (OpSize - 1) [A]
3833
+ // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
3834
3834
//
3835
- // for all Neg and Pos. Since Neg & (OpSize - 1) == Neg' & (OpSize - 1)
3835
+ // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
3836
3836
// we can just replace Neg with Neg' for the rest of the function.
3837
3837
//
3838
3838
// In other cases we check for the even stronger condition:
3839
3839
//
3840
- // Neg == OpSize - Pos [B]
3840
+ // Neg == EltSize - Pos [B]
3841
3841
//
3842
3842
// for all Neg and Pos. Note that the (or ...) then invokes undefined
3843
- // behavior if Pos == 0 (and consequently Neg == OpSize ).
3843
+ // behavior if Pos == 0 (and consequently Neg == EltSize ).
3844
3844
//
3845
- // We could actually use [A] whenever OpSize is a power of 2, but the
3845
+ // We could actually use [A] whenever EltSize is a power of 2, but the
3846
3846
// only extra cases that it would match are those uninteresting ones
3847
3847
// where Neg and Pos are never in range at the same time. E.g. for
3848
- // OpSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
3848
+ // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
3849
3849
// as well as (sub 32, Pos), but:
3850
3850
//
3851
3851
// (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
3852
3852
//
3853
3853
// always invokes undefined behavior for 32-bit X.
3854
3854
//
3855
- // Below, Mask == OpSize - 1 when using [A] and is all-ones otherwise.
3855
+ // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
3856
3856
unsigned MaskLoBits = 0;
3857
- if (Neg.getOpcode() == ISD::AND &&
3858
- isPowerOf2_64(OpSize) &&
3859
- Neg.getOperand(1).getOpcode() == ISD::Constant &&
3860
- cast<ConstantSDNode>(Neg.getOperand(1))->getAPIntValue() == OpSize - 1) {
3861
- Neg = Neg.getOperand(0);
3862
- MaskLoBits = Log2_64(OpSize);
3857
+ if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
3858
+ if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
3859
+ if (NegC->getAPIntValue() == EltSize - 1) {
3860
+ Neg = Neg.getOperand(0);
3861
+ MaskLoBits = Log2_64(EltSize);
3862
+ }
3863
+ }
3863
3864
}
3864
3865
3865
3866
// Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
3866
3867
if (Neg.getOpcode() != ISD::SUB)
3867
3868
return 0;
3868
- ConstantSDNode *NegC = dyn_cast<ConstantSDNode> (Neg.getOperand(0));
3869
+ ConstantSDNode *NegC = isConstOrConstSplat (Neg.getOperand(0));
3869
3870
if (!NegC)
3870
3871
return 0;
3871
3872
SDValue NegOp1 = Neg.getOperand(1);
3872
3873
3873
- // On the RHS of [A], if Pos is Pos' & (OpSize - 1), just replace Pos with
3874
+ // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
3874
3875
// Pos'. The truncation is redundant for the purpose of the equality.
3875
- if (MaskLoBits &&
3876
- Pos.getOpcode() == ISD::AND &&
3877
- Pos.getOperand(1).getOpcode() == ISD::Constant &&
3878
- cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() == OpSize - 1)
3879
- Pos = Pos.getOperand(0);
3876
+ if (MaskLoBits && Pos.getOpcode() == ISD::AND)
3877
+ if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
3878
+ if (PosC->getAPIntValue() == EltSize - 1)
3879
+ Pos = Pos.getOperand(0);
3880
3880
3881
3881
// The condition we need is now:
3882
3882
//
3883
- // (NegC - NegOp1) & Mask == (OpSize - Pos) & Mask
3883
+ // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
3884
3884
//
3885
3885
// If NegOp1 == Pos then we need:
3886
3886
//
3887
- // OpSize & Mask == NegC & Mask
3887
+ // EltSize & Mask == NegC & Mask
3888
3888
//
3889
3889
// (because "x & Mask" is a truncation and distributes through subtraction).
3890
3890
APInt Width;
3891
3891
if (Pos == NegOp1)
3892
3892
Width = NegC->getAPIntValue();
3893
+
3893
3894
// Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
3894
3895
// Then the condition we want to prove becomes:
3895
3896
//
3896
- // (NegC - NegOp1) & Mask == (OpSize - (NegOp1 + PosC)) & Mask
3897
+ // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
3897
3898
//
3898
3899
// which, again because "x & Mask" is a truncation, becomes:
3899
3900
//
3900
- // NegC & Mask == (OpSize - PosC) & Mask
3901
- // OpSize & Mask == (NegC + PosC) & Mask
3902
- else if (Pos.getOpcode() == ISD::ADD &&
3903
- Pos.getOperand(0) == NegOp1 &&
3904
- Pos.getOperand(1).getOpcode() == ISD::Constant)
3905
- Width = (cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() +
3906
- NegC->getAPIntValue()) ;
3907
- else
3901
+ // NegC & Mask == (EltSize - PosC) & Mask
3902
+ // EltSize & Mask == (NegC + PosC) & Mask
3903
+ else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
3904
+ if (ConstantSDNode *PosC = isConstOrConstSplat( Pos.getOperand(1)))
3905
+ Width = PosC->getAPIntValue() + NegC->getAPIntValue();
3906
+ else
3907
+ return false ;
3908
+ } else
3908
3909
return false;
3909
3910
3910
- // Now we just need to check that OpSize & Mask == Width & Mask.
3911
+ // Now we just need to check that EltSize & Mask == Width & Mask.
3911
3912
if (MaskLoBits)
3912
- // Opsize & Mask is 0 since Mask is Opsize - 1.
3913
+ // EltSize & Mask is 0 since Mask is EltSize - 1.
3913
3914
return Width.getLoBits(MaskLoBits) == 0;
3914
- return Width == OpSize ;
3915
+ return Width == EltSize ;
3915
3916
}
3916
3917
3917
3918
// A subroutine of MatchRotate used once we have found an OR of two opposite
@@ -3931,7 +3932,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
3931
3932
// (srl x, (*ext y))) ->
3932
3933
// (rotr x, y) or (rotl x, (sub 32, y))
3933
3934
EVT VT = Shifted.getValueType();
3934
- if (matchRotateSub(InnerPos, InnerNeg, VT.getSizeInBits ())) {
3935
+ if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits ())) {
3935
3936
bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
3936
3937
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
3937
3938
HasPos ? Pos : Neg).getNode();
@@ -3974,38 +3975,37 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
3974
3975
if (RHSShift.getOpcode() == ISD::SHL) {
3975
3976
std::swap(LHS, RHS);
3976
3977
std::swap(LHSShift, RHSShift);
3977
- std::swap(LHSMask , RHSMask );
3978
+ std::swap(LHSMask, RHSMask);
3978
3979
}
3979
3980
3980
- unsigned OpSizeInBits = VT.getSizeInBits ();
3981
+ unsigned EltSizeInBits = VT.getScalarSizeInBits ();
3981
3982
SDValue LHSShiftArg = LHSShift.getOperand(0);
3982
3983
SDValue LHSShiftAmt = LHSShift.getOperand(1);
3983
3984
SDValue RHSShiftArg = RHSShift.getOperand(0);
3984
3985
SDValue RHSShiftAmt = RHSShift.getOperand(1);
3985
3986
3986
3987
// fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
3987
3988
// fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
3988
- if (LHSShiftAmt.getOpcode() == ISD::Constant &&
3989
- RHSShiftAmt.getOpcode() == ISD::Constant) {
3990
- uint64_t LShVal = cast<ConstantSDNode>(LHSShiftAmt)->getZExtValue();
3991
- uint64_t RShVal = cast<ConstantSDNode>(RHSShiftAmt)->getZExtValue();
3992
- if ((LShVal + RShVal) != OpSizeInBits)
3989
+ if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) {
3990
+ uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue();
3991
+ uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue();
3992
+ if ((LShVal + RShVal) != EltSizeInBits)
3993
3993
return nullptr;
3994
3994
3995
3995
SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
3996
3996
LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
3997
3997
3998
3998
// If there is an AND of either shifted operand, apply it to the result.
3999
3999
if (LHSMask.getNode() || RHSMask.getNode()) {
4000
- APInt Mask = APInt::getAllOnesValue(OpSizeInBits );
4000
+ APInt Mask = APInt::getAllOnesValue(EltSizeInBits );
4001
4001
4002
4002
if (LHSMask.getNode()) {
4003
- APInt RHSBits = APInt::getLowBitsSet(OpSizeInBits , LShVal);
4004
- Mask &= cast<ConstantSDNode> (LHSMask)->getAPIntValue() | RHSBits;
4003
+ APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits , LShVal);
4004
+ Mask &= isConstOrConstSplat (LHSMask)->getAPIntValue() | RHSBits;
4005
4005
}
4006
4006
if (RHSMask.getNode()) {
4007
- APInt LHSBits = APInt::getHighBitsSet(OpSizeInBits , RShVal);
4008
- Mask &= cast<ConstantSDNode> (RHSMask)->getAPIntValue() | LHSBits;
4007
+ APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits , RShVal);
4008
+ Mask &= isConstOrConstSplat (RHSMask)->getAPIntValue() | LHSBits;
4009
4009
}
4010
4010
4011
4011
Rot = DAG.getNode(ISD::AND, DL, VT, Rot, DAG.getConstant(Mask, DL, VT));
0 commit comments