Skip to content
This repository was archived by the owner on Apr 23, 2020. It is now read-only.

Commit d0ca754

Browse files
committed
[X86][XOP] Add support for lowering vector rotations
This patch adds support for lowering to the XOP VPROT / VPROTI vector bit rotation instructions. This has required changes to the DAGCombiner rotation pattern matching to support vector types - so far I've only changed it to support splat vectors, but generalising this further is feasible in the future. Differential Revision: http://reviews.llvm.org/D13851 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251188 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 109f9e6 commit d0ca754

File tree

4 files changed

+228
-357
lines changed

4 files changed

+228
-357
lines changed

lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3796,7 +3796,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
37963796
/// Match "(X shl/srl V1) & V2" where V2 may not be present.
37973797
static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
37983798
if (Op.getOpcode() == ISD::AND) {
3799-
if (isa<ConstantSDNode>(Op.getOperand(1))) {
3799+
if (isConstOrConstSplat(Op.getOperand(1))) {
38003800
Mask = Op.getOperand(1);
38013801
Op = Op.getOperand(0);
38023802
} else {
@@ -3813,105 +3813,106 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
38133813
}
38143814

38153815
// Return true if we can prove that, whenever Neg and Pos are both in the
3816-
// range [0, OpSize), Neg == (Pos == 0 ? 0 : OpSize - Pos). This means that
3816+
// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that
38173817
// for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
38183818
//
38193819
// (or (shift1 X, Neg), (shift2 X, Pos))
38203820
//
38213821
// reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
3822-
// in direction shift1 by Neg. The range [0, OpSize) means that we only need
3822+
// in direction shift1 by Neg. The range [0, EltSize) means that we only need
38233823
// to consider shift amounts with defined behavior.
3824-
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned OpSize) {
3825-
// If OpSize is a power of 2 then:
3824+
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
3825+
// If EltSize is a power of 2 then:
38263826
//
3827-
// (a) (Pos == 0 ? 0 : OpSize - Pos) == (OpSize - Pos) & (OpSize - 1)
3828-
// (b) Neg == Neg & (OpSize - 1) whenever Neg is in [0, OpSize).
3827+
// (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
3828+
// (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
38293829
//
3830-
// So if OpSize is a power of 2 and Neg is (and Neg', OpSize-1), we check
3830+
// So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
38313831
// for the stronger condition:
38323832
//
3833-
// Neg & (OpSize - 1) == (OpSize - Pos) & (OpSize - 1) [A]
3833+
// Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A]
38343834
//
3835-
// for all Neg and Pos. Since Neg & (OpSize - 1) == Neg' & (OpSize - 1)
3835+
// for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
38363836
// we can just replace Neg with Neg' for the rest of the function.
38373837
//
38383838
// In other cases we check for the even stronger condition:
38393839
//
3840-
// Neg == OpSize - Pos [B]
3840+
// Neg == EltSize - Pos [B]
38413841
//
38423842
// for all Neg and Pos. Note that the (or ...) then invokes undefined
3843-
// behavior if Pos == 0 (and consequently Neg == OpSize).
3843+
// behavior if Pos == 0 (and consequently Neg == EltSize).
38443844
//
3845-
// We could actually use [A] whenever OpSize is a power of 2, but the
3845+
// We could actually use [A] whenever EltSize is a power of 2, but the
38463846
// only extra cases that it would match are those uninteresting ones
38473847
// where Neg and Pos are never in range at the same time. E.g. for
3848-
// OpSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
3848+
// EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
38493849
// as well as (sub 32, Pos), but:
38503850
//
38513851
// (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
38523852
//
38533853
// always invokes undefined behavior for 32-bit X.
38543854
//
3855-
// Below, Mask == OpSize - 1 when using [A] and is all-ones otherwise.
3855+
// Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
38563856
unsigned MaskLoBits = 0;
3857-
if (Neg.getOpcode() == ISD::AND &&
3858-
isPowerOf2_64(OpSize) &&
3859-
Neg.getOperand(1).getOpcode() == ISD::Constant &&
3860-
cast<ConstantSDNode>(Neg.getOperand(1))->getAPIntValue() == OpSize - 1) {
3861-
Neg = Neg.getOperand(0);
3862-
MaskLoBits = Log2_64(OpSize);
3857+
if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
3858+
if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
3859+
if (NegC->getAPIntValue() == EltSize - 1) {
3860+
Neg = Neg.getOperand(0);
3861+
MaskLoBits = Log2_64(EltSize);
3862+
}
3863+
}
38633864
}
38643865

38653866
// Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
38663867
if (Neg.getOpcode() != ISD::SUB)
38673868
return 0;
3868-
ConstantSDNode *NegC = dyn_cast<ConstantSDNode>(Neg.getOperand(0));
3869+
ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
38693870
if (!NegC)
38703871
return 0;
38713872
SDValue NegOp1 = Neg.getOperand(1);
38723873

3873-
// On the RHS of [A], if Pos is Pos' & (OpSize - 1), just replace Pos with
3874+
// On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
38743875
// Pos'. The truncation is redundant for the purpose of the equality.
3875-
if (MaskLoBits &&
3876-
Pos.getOpcode() == ISD::AND &&
3877-
Pos.getOperand(1).getOpcode() == ISD::Constant &&
3878-
cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() == OpSize - 1)
3879-
Pos = Pos.getOperand(0);
3876+
if (MaskLoBits && Pos.getOpcode() == ISD::AND)
3877+
if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
3878+
if (PosC->getAPIntValue() == EltSize - 1)
3879+
Pos = Pos.getOperand(0);
38803880

38813881
// The condition we need is now:
38823882
//
3883-
// (NegC - NegOp1) & Mask == (OpSize - Pos) & Mask
3883+
// (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
38843884
//
38853885
// If NegOp1 == Pos then we need:
38863886
//
3887-
// OpSize & Mask == NegC & Mask
3887+
// EltSize & Mask == NegC & Mask
38883888
//
38893889
// (because "x & Mask" is a truncation and distributes through subtraction).
38903890
APInt Width;
38913891
if (Pos == NegOp1)
38923892
Width = NegC->getAPIntValue();
3893+
38933894
// Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
38943895
// Then the condition we want to prove becomes:
38953896
//
3896-
// (NegC - NegOp1) & Mask == (OpSize - (NegOp1 + PosC)) & Mask
3897+
// (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
38973898
//
38983899
// which, again because "x & Mask" is a truncation, becomes:
38993900
//
3900-
// NegC & Mask == (OpSize - PosC) & Mask
3901-
// OpSize & Mask == (NegC + PosC) & Mask
3902-
else if (Pos.getOpcode() == ISD::ADD &&
3903-
Pos.getOperand(0) == NegOp1 &&
3904-
Pos.getOperand(1).getOpcode() == ISD::Constant)
3905-
Width = (cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() +
3906-
NegC->getAPIntValue());
3907-
else
3901+
// NegC & Mask == (EltSize - PosC) & Mask
3902+
// EltSize & Mask == (NegC + PosC) & Mask
3903+
else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
3904+
if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
3905+
Width = PosC->getAPIntValue() + NegC->getAPIntValue();
3906+
else
3907+
return false;
3908+
} else
39083909
return false;
39093910

3910-
// Now we just need to check that OpSize & Mask == Width & Mask.
3911+
// Now we just need to check that EltSize & Mask == Width & Mask.
39113912
if (MaskLoBits)
3912-
// Opsize & Mask is 0 since Mask is Opsize - 1.
3913+
// EltSize & Mask is 0 since Mask is EltSize - 1.
39133914
return Width.getLoBits(MaskLoBits) == 0;
3914-
return Width == OpSize;
3915+
return Width == EltSize;
39153916
}
39163917

39173918
// A subroutine of MatchRotate used once we have found an OR of two opposite
@@ -3931,7 +3932,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
39313932
// (srl x, (*ext y))) ->
39323933
// (rotr x, y) or (rotl x, (sub 32, y))
39333934
EVT VT = Shifted.getValueType();
3934-
if (matchRotateSub(InnerPos, InnerNeg, VT.getSizeInBits())) {
3935+
if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) {
39353936
bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
39363937
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
39373938
HasPos ? Pos : Neg).getNode();
@@ -3974,38 +3975,37 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
39743975
if (RHSShift.getOpcode() == ISD::SHL) {
39753976
std::swap(LHS, RHS);
39763977
std::swap(LHSShift, RHSShift);
3977-
std::swap(LHSMask , RHSMask );
3978+
std::swap(LHSMask, RHSMask);
39783979
}
39793980

3980-
unsigned OpSizeInBits = VT.getSizeInBits();
3981+
unsigned EltSizeInBits = VT.getScalarSizeInBits();
39813982
SDValue LHSShiftArg = LHSShift.getOperand(0);
39823983
SDValue LHSShiftAmt = LHSShift.getOperand(1);
39833984
SDValue RHSShiftArg = RHSShift.getOperand(0);
39843985
SDValue RHSShiftAmt = RHSShift.getOperand(1);
39853986

39863987
// fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
39873988
// fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
3988-
if (LHSShiftAmt.getOpcode() == ISD::Constant &&
3989-
RHSShiftAmt.getOpcode() == ISD::Constant) {
3990-
uint64_t LShVal = cast<ConstantSDNode>(LHSShiftAmt)->getZExtValue();
3991-
uint64_t RShVal = cast<ConstantSDNode>(RHSShiftAmt)->getZExtValue();
3992-
if ((LShVal + RShVal) != OpSizeInBits)
3989+
if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) {
3990+
uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue();
3991+
uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue();
3992+
if ((LShVal + RShVal) != EltSizeInBits)
39933993
return nullptr;
39943994

39953995
SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
39963996
LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
39973997

39983998
// If there is an AND of either shifted operand, apply it to the result.
39993999
if (LHSMask.getNode() || RHSMask.getNode()) {
4000-
APInt Mask = APInt::getAllOnesValue(OpSizeInBits);
4000+
APInt Mask = APInt::getAllOnesValue(EltSizeInBits);
40014001

40024002
if (LHSMask.getNode()) {
4003-
APInt RHSBits = APInt::getLowBitsSet(OpSizeInBits, LShVal);
4004-
Mask &= cast<ConstantSDNode>(LHSMask)->getAPIntValue() | RHSBits;
4003+
APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal);
4004+
Mask &= isConstOrConstSplat(LHSMask)->getAPIntValue() | RHSBits;
40054005
}
40064006
if (RHSMask.getNode()) {
4007-
APInt LHSBits = APInt::getHighBitsSet(OpSizeInBits, RShVal);
4008-
Mask &= cast<ConstantSDNode>(RHSMask)->getAPIntValue() | LHSBits;
4007+
APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal);
4008+
Mask &= isConstOrConstSplat(RHSMask)->getAPIntValue() | LHSBits;
40094009
}
40104010

40114011
Rot = DAG.getNode(ISD::AND, DL, VT, Rot, DAG.getConstant(Mask, DL, VT));

lib/Target/X86/X86ISelLowering.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,17 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
10501050
setOperationAction(ISD::SRA, MVT::v4i32, Custom);
10511051
}
10521052

1053+
if (Subtarget->hasXOP()) {
1054+
setOperationAction(ISD::ROTL, MVT::v16i8, Custom);
1055+
setOperationAction(ISD::ROTL, MVT::v8i16, Custom);
1056+
setOperationAction(ISD::ROTL, MVT::v4i32, Custom);
1057+
setOperationAction(ISD::ROTL, MVT::v2i64, Custom);
1058+
setOperationAction(ISD::ROTL, MVT::v32i8, Custom);
1059+
setOperationAction(ISD::ROTL, MVT::v16i16, Custom);
1060+
setOperationAction(ISD::ROTL, MVT::v8i32, Custom);
1061+
setOperationAction(ISD::ROTL, MVT::v4i64, Custom);
1062+
}
1063+
10531064
if (!Subtarget->useSoftFloat() && Subtarget->hasFp256()) {
10541065
addRegisterClass(MVT::v32i8, &X86::VR256RegClass);
10551066
addRegisterClass(MVT::v16i16, &X86::VR256RegClass);
@@ -18817,6 +18828,41 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
1881718828
return SDValue();
1881818829
}
1881918830

18831+
static SDValue LowerRotate(SDValue Op, const X86Subtarget *Subtarget,
18832+
SelectionDAG &DAG) {
18833+
MVT VT = Op.getSimpleValueType();
18834+
SDLoc DL(Op);
18835+
SDValue R = Op.getOperand(0);
18836+
SDValue Amt = Op.getOperand(1);
18837+
unsigned Opc = Op.getOpcode();
18838+
18839+
assert(VT.isVector() && "Custom lowering only for vector rotates!");
18840+
assert(Subtarget->hasXOP() && "XOP support required for vector rotates!");
18841+
assert((Opc == ISD::ROTL) && "Only ROTL supported");
18842+
18843+
// XOP has 128-bit vector variable + immediate rotates.
18844+
// +ve/-ve Amt = rotate left/right.
18845+
18846+
// Split 256-bit integers.
18847+
if (VT.getSizeInBits() == 256)
18848+
return Lower256IntArith(Op, DAG);
18849+
18850+
assert(VT.getSizeInBits() == 128 && "Only rotate 128-bit vectors!");
18851+
18852+
// Attempt to rotate by immediate.
18853+
if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
18854+
if (auto *RotateConst = BVAmt->getConstantSplatNode()) {
18855+
uint64_t RotateAmt = RotateConst->getAPIntValue().getZExtValue();
18856+
assert(RotateAmt < VT.getScalarSizeInBits() && "Rotation out of range");
18857+
return DAG.getNode(X86ISD::VPROTI, DL, VT, R,
18858+
DAG.getConstant(RotateAmt, DL, MVT::i8));
18859+
}
18860+
}
18861+
18862+
// Use general rotate by variable (per-element).
18863+
return DAG.getNode(X86ISD::VPROT, DL, VT, R, Amt);
18864+
}
18865+
1882018866
static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
1882118867
// Lower the "add/sub/mul with overflow" instruction into a regular ins plus
1882218868
// a "setcc" instruction that checks the overflow flag. The "brcond" lowering
@@ -19675,6 +19721,7 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1967519721
case ISD::MUL: return LowerMUL(Op, Subtarget, DAG);
1967619722
case ISD::UMUL_LOHI:
1967719723
case ISD::SMUL_LOHI: return LowerMUL_LOHI(Op, Subtarget, DAG);
19724+
case ISD::ROTL: return LowerRotate(Op, Subtarget, DAG);
1967819725
case ISD::SRA:
1967919726
case ISD::SRL:
1968019727
case ISD::SHL: return LowerShift(Op, Subtarget, DAG);

0 commit comments

Comments
 (0)