Commit ed7c97e

Recommit "[DAGCombiner] Transform (icmp eq/ne (and X,C0),(shift X,C1)) to use rotate or to get better constants." (2nd Try)
Added a missing check that the mask and shift amount add up to the correct bitwidth, as well as test cases for the bug. Closes llvm#71729
1 parent 160a13a commit ed7c97e
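
For context, an illustrative sketch (not part of the commit): at the source level the targeted pattern compares one group of bits of a value against the remaining bits, and when the shift amount is a power of 2 the mask+shift form and the rotate form test the same condition. The small standalone program below only demonstrates that equivalence for the 64-bit, shift-by-32 case mentioned in the code comments; the helper names are made up for illustration.

#include <cassert>
#include <cstdint>

// "(x64 & UINT32_MAX) == (x64 >> 32)": low 32 bits equal the high 32 bits.
static bool cmpMaskAndShift(uint64_t X) { return (X & UINT32_MAX) == (X >> 32); }

// The same question phrased as a rotate: swapping the halves leaves X unchanged.
static bool cmpRotate(uint64_t X) {
  uint64_t Rot32 = (X << 32) | (X >> 32); // rotl(X, 32)
  return X == Rot32;
}

int main() {
  const uint64_t Tests[] = {0, 1, 0x1234567812345678ULL, 0xFFFFFFFF00000000ULL,
                            0x00000000FFFFFFFFULL, ~0ULL};
  for (uint64_t X : Tests)
    assert(cmpMaskAndShift(X) == cmpRotate(X));
  return 0;
}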

File tree

5 files changed (+305 -100 lines changed)


llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 18 additions & 0 deletions
@@ -832,6 +832,24 @@ class TargetLoweringBase {
     return N->getOpcode() == ISD::FDIV;
   }
 
+  // Given:
+  //    (icmp eq/ne (and X, C0), (shift X, C1))
+  // or
+  //    (icmp eq/ne X, (rotate X, CPow2))
+
+  // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
+  // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`),
+  // do we prefer the shift to be shift-right, shift-left, or rotate?
+  // Note: It's only valid to convert the rotate version to the shift version
+  // iff the shift-amt (`C1`) is a power of 2 (including 0).
+  // If ShiftOpc (current Opcode) is returned, do nothing.
+  virtual unsigned preferedOpcodeForCmpEqPiecesOfOperand(
+      EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+      const APInt &ShiftOrRotateAmt,
+      const std::optional<APInt> &AndMask) const {
+    return ShiftOpc;
+  }
+
   /// These two forms are equivalent:
   ///   sub %y, (xor %x, -1)
   ///   add (add %x, 1), %y
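
To make the contract of the new hook concrete, here is a hypothetical override for an imaginary target (MyTargetLowering is made up; this is a sketch that assumes the LLVM headers and is not part of this commit): the override returns the opcode it would rather see, or ShiftOpc unchanged to keep the DAG as-is, and it may only move between the rotate and shift forms when MayTransformRotate is true.

// Hypothetical example only: a target with cheap rotates asks the combiner to
// use the rotate form whenever that conversion is legal.
unsigned MyTargetLowering::preferedOpcodeForCmpEqPiecesOfOperand(
    EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
    const APInt &ShiftOrRotateAmt, const std::optional<APInt> &AndMask) const {
  // Only the shift+and form carries a mask; the rotate form passes std::nullopt.
  if ((ShiftOpc == ISD::SHL || ShiftOpc == ISD::SRL) && MayTransformRotate)
    return ISD::ROTL; // (and X, C0) == (shift X, C1) becomes X == (rotl X, C1)
  return ShiftOpc;    // otherwise keep whatever is already there
}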

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 120 additions & 15 deletions
@@ -12466,27 +12466,132 @@ SDValue DAGCombiner::visitSETCC(SDNode *N) {
 
   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
   EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
 
-  SDValue Combined = SimplifySetCC(VT, N->getOperand(0), N->getOperand(1), Cond,
-                                   SDLoc(N), !PreferSetCC);
-
-  if (!Combined)
-    return SDValue();
+  SDValue Combined = SimplifySetCC(VT, N0, N1, Cond, SDLoc(N), !PreferSetCC);
 
-  // If we prefer to have a setcc, and we don't, we'll try our best to
-  // recreate one using rebuildSetCC.
-  if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
-    SDValue NewSetCC = rebuildSetCC(Combined);
+  if (Combined) {
+    // If we prefer to have a setcc, and we don't, we'll try our best to
+    // recreate one using rebuildSetCC.
+    if (PreferSetCC && Combined.getOpcode() != ISD::SETCC) {
+      SDValue NewSetCC = rebuildSetCC(Combined);
 
-    // We don't have anything interesting to combine to.
-    if (NewSetCC.getNode() == N)
-      return SDValue();
+      // We don't have anything interesting to combine to.
+      if (NewSetCC.getNode() == N)
+        return SDValue();
 
-    if (NewSetCC)
-      return NewSetCC;
+      if (NewSetCC)
+        return NewSetCC;
+    }
+    return Combined;
   }
 
-  return Combined;
+  // Optimize
+  //    1) (icmp eq/ne (and X, C0), (shift X, C1))
+  // or
+  //    2) (icmp eq/ne X, (rotate X, C1))
+  // If C0 is a mask or shifted mask and the shift amt (C1) isolates the
+  // remaining bits (i.e. something like `(x64 & UINT32_MAX) == (x64 >> 32)`).
+  // Then:
+  // If C1 is a power of 2, then the rotate and shift+and versions are
+  // equivalent, so we can interchange them depending on target preference.
+  // Otherwise, if we have the shift+and version we can interchange srl/shl,
+  // which in turn affects the constant C0. We can use this to get better
+  // constants, again determined by target preference.
+  if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
+    auto IsAndWithShift = [](SDValue A, SDValue B) {
+      return A.getOpcode() == ISD::AND &&
+             (B.getOpcode() == ISD::SRL || B.getOpcode() == ISD::SHL) &&
+             A.getOperand(0) == B.getOperand(0);
+    };
+    auto IsRotateWithOp = [](SDValue A, SDValue B) {
+      return (B.getOpcode() == ISD::ROTL || B.getOpcode() == ISD::ROTR) &&
+             B.getOperand(0) == A;
+    };
+    SDValue AndOrOp = SDValue(), ShiftOrRotate = SDValue();
+    bool IsRotate = false;
+
+    // Find either shift+and or rotate pattern.
+    if (IsAndWithShift(N0, N1)) {
+      AndOrOp = N0;
+      ShiftOrRotate = N1;
+    } else if (IsAndWithShift(N1, N0)) {
+      AndOrOp = N1;
+      ShiftOrRotate = N0;
+    } else if (IsRotateWithOp(N0, N1)) {
+      IsRotate = true;
+      AndOrOp = N0;
+      ShiftOrRotate = N1;
+    } else if (IsRotateWithOp(N1, N0)) {
+      IsRotate = true;
+      AndOrOp = N1;
+      ShiftOrRotate = N0;
+    }
+
+    if (AndOrOp && ShiftOrRotate && ShiftOrRotate.hasOneUse() &&
+        (IsRotate || AndOrOp.hasOneUse())) {
+      EVT OpVT = N0.getValueType();
+      // Get constant shift/rotate amount and possibly mask (if it's the
+      // shift+and variant).
+      auto GetAPIntValue = [](SDValue Op) -> std::optional<APInt> {
+        ConstantSDNode *CNode = isConstOrConstSplat(Op, /*AllowUndefs*/ false,
+                                                    /*AllowTrunc*/ false);
+        if (CNode == nullptr)
+          return std::nullopt;
+        return CNode->getAPIntValue();
+      };
+      std::optional<APInt> AndCMask =
+          IsRotate ? std::nullopt : GetAPIntValue(AndOrOp.getOperand(1));
+      std::optional<APInt> ShiftCAmt =
+          GetAPIntValue(ShiftOrRotate.getOperand(1));
+      unsigned NumBits = OpVT.getScalarSizeInBits();
+
+      // We found constants.
+      if (ShiftCAmt && (IsRotate || AndCMask) && ShiftCAmt->ult(NumBits)) {
+        unsigned ShiftOpc = ShiftOrRotate.getOpcode();
+        // Check that the constants meet the constraints.
+        bool CanTransform = IsRotate;
+        if (!CanTransform) {
+          // Check that mask and shift complement each other.
+          CanTransform = *ShiftCAmt == (~*AndCMask).popcount();
+          // Check that we are comparing all bits.
+          CanTransform &= (*ShiftCAmt + AndCMask->popcount()) == NumBits;
+          // Check that the and mask is correct for the shift.
+          CanTransform &=
+              ShiftOpc == ISD::SHL ? (~*AndCMask).isMask() : AndCMask->isMask();
+        }
+
+        // See if target prefers another shift/rotate opcode.
+        unsigned NewShiftOpc = TLI.preferedOpcodeForCmpEqPiecesOfOperand(
+            OpVT, ShiftOpc, ShiftCAmt->isPowerOf2(), *ShiftCAmt, AndCMask);
+        // Transform is valid and we have a new preference.
+        if (CanTransform && NewShiftOpc != ShiftOpc) {
+          SDLoc DL(N);
+          SDValue NewShiftOrRotate =
+              DAG.getNode(NewShiftOpc, DL, OpVT, ShiftOrRotate.getOperand(0),
+                          ShiftOrRotate.getOperand(1));
+          SDValue NewAndOrOp = SDValue();
+
+          if (NewShiftOpc == ISD::SHL || NewShiftOpc == ISD::SRL) {
+            APInt NewMask =
+                NewShiftOpc == ISD::SHL
+                    ? APInt::getHighBitsSet(NumBits,
+                                            NumBits - ShiftCAmt->getZExtValue())
+                    : APInt::getLowBitsSet(NumBits,
+                                           NumBits - ShiftCAmt->getZExtValue());
+            NewAndOrOp =
+                DAG.getNode(ISD::AND, DL, OpVT, ShiftOrRotate.getOperand(0),
+                            DAG.getConstant(NewMask, DL, OpVT));
+          } else {
+            NewAndOrOp = ShiftOrRotate.getOperand(0);
+          }
+
+          return DAG.getSetCC(DL, VT, NewAndOrOp, NewShiftOrRotate, Cond);
+        }
+      }
+    }
+  }
+  return SDValue();
 }
 
 SDValue DAGCombiner::visitSETCCCARRY(SDNode *N) {
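
The comment block above mentions interchanging srl/shl to get better constants; the two shifted forms use complementary masks but test the same condition. A small standalone check of that equivalence for the 64-bit, shift-by-32 case (illustrative only, not part of the commit):

#include <cassert>
#include <cstdint>

// shl form: high half (kept in place by the mask) vs. low half shifted up.
static bool shlForm(uint64_t X) {
  return (X & 0xFFFFFFFF00000000ULL) == (X << 32);
}
// srl form: low half (kept by the complementary mask) vs. high half shifted down.
static bool srlForm(uint64_t X) {
  return (X & 0x00000000FFFFFFFFULL) == (X >> 32);
}

int main() {
  const uint64_t Tests[] = {0, 0x0000000100000001ULL, 0xABCD1234ABCD1234ULL,
                            0xFFFFFFFF00000000ULL, ~0ULL};
  for (uint64_t X : Tests)
    assert(shlForm(X) == srlForm(X)); // both ask: are the two halves equal?
  return 0;
}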

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 67 additions & 0 deletions
@@ -3263,6 +3263,73 @@ bool X86TargetLowering::
   return NewShiftOpcode == ISD::SHL;
 }
 
+unsigned X86TargetLowering::preferedOpcodeForCmpEqPiecesOfOperand(
+    EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+    const APInt &ShiftOrRotateAmt, const std::optional<APInt> &AndMask) const {
+  if (!VT.isInteger())
+    return ShiftOpc;
+
+  bool PreferRotate = false;
+  if (VT.isVector()) {
+    // For vectors, if we have rotate instruction support, then it's definitely
+    // best. Otherwise it's not clear what's best, so just don't make changes.
+    PreferRotate = Subtarget.hasAVX512() && (VT.getScalarType() == MVT::i32 ||
+                                             VT.getScalarType() == MVT::i64);
+  } else {
+    // For scalar, if we have bmi prefer rotate for rorx. Otherwise prefer
+    // rotate unless we have a zext mask+shr.
+    PreferRotate = Subtarget.hasBMI2();
+    if (!PreferRotate) {
+      unsigned MaskBits =
+          VT.getScalarSizeInBits() - ShiftOrRotateAmt.getZExtValue();
+      PreferRotate = (MaskBits != 8) && (MaskBits != 16) && (MaskBits != 32);
+    }
+  }
+
+  if (ShiftOpc == ISD::SHL || ShiftOpc == ISD::SRL) {
+    assert(AndMask.has_value() && "Null andmask when querying about shift+and");
+
+    if (PreferRotate && MayTransformRotate)
+      return ISD::ROTL;
+
+    // If vector we don't really get much benefit swapping around constants.
+    // Maybe we could check if the DAG has the flipped node already in the
+    // future.
+    if (VT.isVector())
+      return ShiftOpc;
+
+    // See if it's beneficial to swap the shift type.
+    if (ShiftOpc == ISD::SHL) {
+      // If the current setup has imm64 mask, then inverse will have
+      // at least imm32 mask (or be zext i32 -> i64).
+      if (VT == MVT::i64)
+        return AndMask->getSignificantBits() > 32 ? (unsigned)ISD::SRL
+                                                  : ShiftOpc;
+
+      // We only benefit if the mask requires at least 7 bits. We
+      // don't want to replace shl of 1,2,3 as they can be implemented
+      // with lea/add.
+      return ShiftOrRotateAmt.uge(7) ? (unsigned)ISD::SRL : ShiftOpc;
+    }
+
+    if (VT == MVT::i64)
+      // Keep exactly 32-bit imm64, this is zext i32 -> i64 which is
+      // extremely efficient.
+      return AndMask->getSignificantBits() > 33 ? (unsigned)ISD::SHL : ShiftOpc;
+
+    // Keep small shifts as shl so we can generate add/lea.
+    return ShiftOrRotateAmt.ult(7) ? (unsigned)ISD::SHL : ShiftOpc;
+  }
+
+  // We prefer rotate for vectors, or if we won't get a zext mask with SRL
+  // (PreferRotate will be set in the latter case).
+  if (PreferRotate || VT.isVector())
+    return ShiftOpc;
+
+  // Non-vector type and we have a zext mask with SRL.
+  return ISD::SRL;
+}
+
 bool X86TargetLowering::preferScalarizeSplat(SDNode *N) const {
   return N->getOpcode() != ISD::FP_EXTEND;
 }

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 5 additions & 0 deletions
@@ -1138,6 +1138,11 @@ namespace llvm {
         unsigned OldShiftOpcode, unsigned NewShiftOpcode,
         SelectionDAG &DAG) const override;
 
+    unsigned preferedOpcodeForCmpEqPiecesOfOperand(
+        EVT VT, unsigned ShiftOpc, bool MayTransformRotate,
+        const APInt &ShiftOrRotateAmt,
+        const std::optional<APInt> &AndMask) const override;
+
     bool preferScalarizeSplat(SDNode *N) const override;
 
     bool shouldFoldConstantShiftPairToMask(const SDNode *N,
