Commit 112e49b
[DAGCombiner] Transform (icmp eq/ne (and X,C0),(shift X,C1)) to use rotate or to get better constants.

If `C0` is a mask and `C1` shifts out all the masked bits (so we are essentially comparing two subsets of `X`), we can arbitrarily re-order the shift as `srl` or `shl`.

If `C1` (the shift amount) is a power of 2, we can replace the and+shift with a rotate. Otherwise, based on target preference, we can arbitrarily swap `shl` and `srl` in/out to get better constants.

On x86 we can use this re-ordering to:
1) Get better `and` constants for `C0` (zero-extended moves or avoiding imm64).
2) Convert `srl` to `shl` if the `shl` will be implementable with `lea` or `add` (both of which can be preferable).

Proofs: https://alive2.llvm.org/ce/z/qzGM_w

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D152116
1 parent 0c2d28a commit 112e49b
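
A quick way to see why the forms are interchangeable is the 64-bit "compare the two halves of X" example quoted in the commit message. The standalone C++ check below is illustrative only (it is not part of the patch); rotl64 is a hand-rolled stand-in for an ISD::ROTL node, and the rotate form is only valid here because the shift amount (32) is a power of 2.

#include <cassert>
#include <cstdint>

// Rotate left by R bits (assumes 0 < R < 64).
static uint64_t rotl64(uint64_t X, unsigned R) {
  return (X << R) | (X >> (64 - R));
}

static void checkHalvesEqualForms(uint64_t X) {
  // Shift+and form with srl: low 32 bits vs. high 32 bits.
  bool AndSrl = (X & UINT32_MAX) == (X >> 32);
  // srl<->shl swapped form: the mask flips to the complementary high half.
  bool AndShl = (X & 0xFFFFFFFF00000000ULL) == (X << 32);
  // Rotate form: rotating by 32 swaps the halves, so equality holds exactly
  // when the two halves match.
  bool Rot = X == rotl64(X, 32);
  assert(AndSrl == AndShl && AndShl == Rot);
}

int main() {
  for (uint64_t X : {0x1234567812345678ULL, 0x1234567887654321ULL, 0ULL, ~0ULL})
    checkHalvesEqualForms(X);
}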

File tree

5 files changed: +279 -86 lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 18 additions & 0 deletions

@@ -828,6 +828,24 @@ class TargetLoweringBase {
     return N->getOpcode() == ISD::FDIV;
   }
 
+  // Given:
+  //    (icmp eq/ne (and X, C0), (shift X, C1))
+  // or
+  //    (icmp eq/ne X, (rotate X, CPow2))
+
+  // If C0 is a mask or shifted mask and the shift amount (C1) isolates the
+  // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
+  // do we prefer the shift to be shift-right, shift-left, or rotate?
+  // Note: It is only valid to convert the rotate version to the shift version
+  // iff the shift amount (`C1`) is a power of 2 (including 0).
+  // If ShiftOpc (the current opcode) is returned, do nothing.
+  virtual unsigned preferedOpcodeForCmpEqPiecesOfOperand(
+      EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+      const APInt &ShiftOrRotateAmt,
+      const std::optional<APInt> &AndMask) const {
+    return ShiftOpc;
+  }
+
   /// These two forms are equivalent:
   ///   sub %y, (xor %x, -1)
   ///   add (add %x, 1), %y
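
To make the hook's contract concrete, a hypothetical backend override could look like the fragment below. MyTargetLowering is an invented class used purely for illustration (it is not part of this patch, and the fragment only compiles inside an LLVM target); the important points are that a rotate opcode may only be returned when MayTransformRotate is true, and that returning ShiftOpc leaves the node untouched.

unsigned MyTargetLowering::preferedOpcodeForCmpEqPiecesOfOperand(
    EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
    const APInt &ShiftOrRotateAmt, const std::optional<APInt> &AndMask) const {
  // Suppose this target has cheap rotates: ask for the rotate form whenever
  // the shift+and form may legally become one (shift amount is a power of 2).
  if ((ShiftOpc == ISD::SHL || ShiftOpc == ISD::SRL) && MayTransformRotate)
    return ISD::ROTL;
  // Otherwise express no preference and keep whatever the DAG already has.
  return ShiftOpc;
}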

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 115 additions & 15 deletions

@@ -12443,27 +12443,127 @@ SDValue DAGCombiner::visitSETCC(SDNode *N) {
 
   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
   EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
 
-  SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
-                                   SDLoc(N), !PreferSetCC);
-
-  if (!Combined)
-    return SDValue();
+  SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, SDLoc(N), !PreferSetCC);
 
-  // If we prefer to have a setcc, and we don't, we'll try our best to
-  // recreate one using rebuildSetCC.
-  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
-    SDValue NewSetCC = rebuildSetCC(Combined);
+  if (Combined) {
+    // If we prefer to have a setcc, and we don't, we'll try our best to
+    // recreate one using rebuildSetCC.
+    if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
+      SDValue NewSetCC = rebuildSetCC(Combined);
 
-    // We don't have anything interesting to combine to.
-    if (NewSetCC.getNode() == N)
-      return SDValue();
+      // We don't have anything interesting to combine to.
+      if (NewSetCC.getNode() == N)
+        return SDValue();
 
-    if (NewSetCC)
-      return NewSetCC;
+      if (NewSetCC)
+        return NewSetCC;
+    }
+    return Combined;
   }
 
-  return Combined;
+  // Optimize
+  //    1) (icmp eq/ne (and X, C0), (shift X, C1))
+  // or
+  //    2) (icmp eq/ne X, (rotate X, C1))
+  // If C0 is a mask or shifted mask and the shift amount (C1) isolates the
+  // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
+  // then:
+  // If C1 is a power of 2, the rotate and shift+and versions are
+  // equivalent, so we can interchange them depending on target preference.
+  // Otherwise, if we have the shift+and version we can interchange srl/shl,
+  // which in turn affects the constant C0. We can use this to get better
+  // constants, again determined by target preference.
+  if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
+    auto IsAndWithShift = [](SDValue A, SDValue B) {
+      return A.getOpcode() == ISD::AND &&
+             (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
+             A.getOperand(0) == B.getOperand(0);
+    };
+    auto IsRotateWithOp = [](SDValue A, SDValue B) {
+      return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
+             B.getOperand(0) == A;
+    };
+    SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
+    bool IsRotate = false;
+
+    // Find either shift+and or rotate pattern.
+    if (IsAndWithShift(N0, N1)) {
+      AndOrOp = N0;
+      ShiftOrRotate = N1;
+    } else if (IsAndWithShift(N1, N0)) {
+      AndOrOp = N1;
+      ShiftOrRotate = N0;
+    } else if (IsRotateWithOp(N0, N1)) {
+      IsRotate = true;
+      AndOrOp = N0;
+      ShiftOrRotate = N1;
+    } else if (IsRotateWithOp(N1, N0)) {
+      IsRotate = true;
+      AndOrOp = N1;
+      ShiftOrRotate = N0;
+    }
+
+    if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
+        (IsRotate || AndOrOp.hasOneUse())) {
+      EVT OpVT = N0.getValueType();
+      // Get the constant shift/rotate amount and possibly the mask (if it's
+      // the shift+and variant).
+      auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
+        ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
+                                                    /*AllowTrunc*/ false);
+        if (CNode == nullptr)
+          return std::nullopt;
+        return CNode->getAPIntValue();
+      };
+      std::optional<APInt> AndCMask =
+          IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
+      std::optional<APInt> ShiftCAmt =
+          GetAPIntValue(ShiftOrRotate.getOperand(1));
+      unsigned NumBits = OpVT.getScalarSizeInBits();
+
+      // We found constants.
+      if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
+        unsigned ShiftOpc = ShiftOrRotate.getOpcode();
+        // Check that the constants meet the constraints.
+        bool CanTransform =
+            IsRotate ||
+            (*ShiftCAmt == (~*AndCMask).popcount() && ShiftOpc == ISD::SHL
+                 ? (~*AndCMask).isMask()
+                 : AndCMask->isMask());
+
+        // See if the target prefers another shift/rotate opcode.
+        unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
+            OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
+        // Transform is valid and we have a new preference.
+        if (CanTransform && NewShiftOpc != ShiftOpc) {
+          SDLoc DL(N);
+          SDValue NewShiftOrRotate =
+              DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
+                          ShiftOrRotate.getOperand(1));
+          SDValue NewAndOrOp = SDValue();
+
+          if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
+            APInt NewMask =
+                NewShiftOpc == ISD::SHL
+                    ? APInt::getHighBitsSet(NumBits,
+                                            NumBits - ShiftCAmt->getZExtValue())
+                    : APInt::getLowBitsSet(NumBits,
+                                           NumBits - ShiftCAmt->getZExtValue());
+            NewAndOrOp =
+                DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
+                            DAG.getConstant(NewMask, DL, OpVT));
+          } else {
+            NewAndOrOp = ShiftOrRotate.getOperand(0);
+          }
+
+          return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
+        }
+      }
+    }
+  }
+  return SDValue();
 }
 
 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
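
When the shift amount is not a power of 2 there is no rotate form, and the combine above can only flip between the srl and shl variants, rebuilding the mask with APInt::getHighBitsSet / APInt::getLowBitsSet. The standalone check below is an assumed example (not part of the patch) showing that flip for a 64-bit value with mask 0xFF and shift amount 56.

#include <cassert>
#include <cstdint>

// srl variant: the and-mask is the low-bits mask 0xFF, i.e. compare the low
// byte against the top byte.
static bool srlForm(uint64_t X) {
  return (X & 0xFFULL) == (X >> 56);
}

// shl variant produced by the combine: the mask becomes the complementary
// high-bits mask, getHighBitsSet(64, 64 - 56) == 0xFF00000000000000.
static bool shlForm(uint64_t X) {
  return (X & 0xFF00000000000000ULL) == (X << 56);
}

int main() {
  for (uint64_t X : {0ULL, 0xAB000000000000ABULL, 0xAB000000000000CDULL, ~0ULL})
    assert(srlForm(X) == shlForm(X));
}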

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 66 additions & 0 deletions

@@ -3257,6 +3257,72 @@ bool X86TargetLowering::
   return NewShiftOpcode == ISD::SHL;
 }
 
+unsigned X86TargetLowering::preferedOpcodeForCmpEqPiecesOfOperand(
+    EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+    const APInt &ShiftOrRotateAmt, const std::optional<APInt> &AndMask) const {
+  if (!VT.isInteger())
+    return ShiftOpc;
+
+  bool PreferRotate = false;
+  if (VT.isVector()) {
+    // For vectors, if we have rotate instruction support, then it is
+    // definitely best. Otherwise it is not clear what is best, so just don't
+    // make changes.
+    PreferRotate = Subtarget.hasAVX512() && (VT.getScalarType() == MVT::i32 ||
+                                             VT.getScalarType() == MVT::i64);
+  } else {
+    // For scalar, if we have BMI2, prefer rotate for rorx. Otherwise prefer
+    // rotate unless we have a zext mask+shr.
+    PreferRotate = Subtarget.hasBMI2();
+    if (!PreferRotate) {
+      unsigned MaskBits =
+          VT.getScalarSizeInBits() - ShiftOrRotateAmt.getZExtValue();
+      PreferRotate = (MaskBits != 8) && (MaskBits != 16) && (MaskBits != 32);
+    }
+  }
+
+  if (ShiftOpc == ISD::SHL || ShiftOpc == ISD::SRL) {
+    assert(AndMask.has_value() && "Null andmask when querying about shift+and");
+
+    if (PreferRotate && MayTransformRotate)
+      return ISD::ROTL;
+
+    // If vector, we don't really get much benefit from swapping around
+    // constants. Maybe we could check if the DAG has the flipped node already
+    // in the future.
+    if (VT.isVector())
+      return ShiftOpc;
+
+    // See if it is beneficial to swap the shift type.
+    if (ShiftOpc == ISD::SHL) {
+      // If the current setup has an imm64 mask, then the inverse will have
+      // at least an imm32 mask (or be a zext i32 -> i64).
+      if (VT == MVT::i64)
+        return AndMask->getSignificantBits() > 32 ? ISD::SRL : ShiftOpc;
+
+      // We can only benefit if we require at least 7 bits for the mask. We
+      // don't want to replace shl of 1, 2, or 3 as they can be implemented
+      // with lea/add.
+      return ShiftOrRotateAmt.uge(7) ? ISD::SRL : ShiftOpc;
+    }
+
+    if (VT == MVT::i64)
+      // Keep an exactly 32-bit imm64; this is a zext i32 -> i64, which is
+      // extremely efficient.
+      return AndMask->getSignificantBits() > 33 ? ISD::SHL : ShiftOpc;
+
+    // Keep small shifts as shl so we can generate add/lea.
+    return ShiftOrRotateAmt.ult(7) ? ISD::SHL : ShiftOpc;
+  }
+
+  // We prefer rotate for vectors, or if we won't get a zext mask with SRL
+  // (PreferRotate will be set in the latter case).
+  if (PreferRotate || VT.isVector())
+    return ShiftOpc;
+
+  // Non-vector type and we have a zext mask with SRL.
+  return ISD::SRL;
+}
+
 bool X86TargetLowering::preferScalarizeSplat(SDNode *N) const {
   return N->getOpcode() != ISD::FP_EXTEND;
 }
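
As a hedged illustration of the scalar heuristic (reasoning only, not verified codegen): for a 32-bit value and shift amount 3, the hook above asks the combiner to rewrite the srl variant as the shl variant, since a shift-left by 3 can typically be folded into an lea or an add chain, per the "Keep small shifts as shl" comment. The standalone check below (an assumed example, not part of the patch) confirms the two variants are the same predicate.

#include <cassert>
#include <cstdint>

// Incoming srl variant: and-mask 0x1FFFFFFF, srl by 3.
static bool srlForm(uint32_t X) {
  return (X & 0x1FFFFFFFu) == (X >> 3);
}

// Variant the x86 hook prefers for shift amounts below 7: the shift becomes
// `X << 3`, which instruction selection can often implement with lea/add,
// and the mask flips to getHighBitsSet(32, 32 - 3) == 0xFFFFFFF8.
static bool shlForm(uint32_t X) {
  return (X & 0xFFFFFFF8u) == (X << 3);
}

int main() {
  for (uint32_t X : {0u, 0x12345678u, 0x92492492u, ~0u})
    assert(srlForm(X) == shlForm(X));
}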

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 5 additions & 0 deletions

@@ -1138,6 +1138,11 @@ namespace llvm {
         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
         SelectionDAG &DAG) const override;
 
+    unsigned preferedOpcodeForCmpEqPiecesOfOperand(
+        EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+        const APInt &ShiftOrRotateAmt,
+        const std::optional<APInt> &AndMask) const override;
+
     bool preferScalarizeSplat(SDNode *N) const override;
 
     bool shouldFoldConstantShiftPairToMask(const SDNode *N,
