Skip to content

Commit 364f781

Browse files
authored
[GlobalIsel] Combine logic of icmps (llvm#77855)
Inspired by InstCombinerImpl::foldAndOrOfICmpsUsingRanges with some adaptations to MIR.
1 parent ca1da36 commit 364f781

File tree

7 files changed

+647
-99
lines changed

7 files changed

+647
-99
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,12 @@ class CombinerHelper {
814814
/// Combine selects.
815815
bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);
816816

817+
/// Combine ands.
818+
bool matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo);
819+
820+
/// Combine ors.
821+
bool matchOr(MachineInstr &MI, BuildFnTy &MatchInfo);
822+
817823
private:
818824
/// Checks for legality of an indexed variant of \p LdSt.
819825
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -919,6 +925,12 @@ class CombinerHelper {
919925
bool AllowUndefs);
920926

921927
std::optional<APInt> getConstantOrConstantSplatVector(Register Src);
928+
929+
/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
930+
/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
931+
/// into a single comparison using range-based reasoning.
932+
bool tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
933+
BuildFnTy &MatchInfo);
922934
};
923935
} // namespace llvm
924936

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,134 @@ class GPhi : public GenericMachineInstr {
592592
}
593593
};
594594

595+
/// Represents a binary operation, i.e, x = y op z.
596+
class GBinOp : public GenericMachineInstr {
597+
public:
598+
Register getLHSReg() const { return getReg(1); }
599+
Register getRHSReg() const { return getReg(2); }
600+
601+
static bool classof(const MachineInstr *MI) {
602+
switch (MI->getOpcode()) {
603+
// Integer.
604+
case TargetOpcode::G_ADD:
605+
case TargetOpcode::G_SUB:
606+
case TargetOpcode::G_MUL:
607+
case TargetOpcode::G_SDIV:
608+
case TargetOpcode::G_UDIV:
609+
case TargetOpcode::G_SREM:
610+
case TargetOpcode::G_UREM:
611+
case TargetOpcode::G_SMIN:
612+
case TargetOpcode::G_SMAX:
613+
case TargetOpcode::G_UMIN:
614+
case TargetOpcode::G_UMAX:
615+
// Floating point.
616+
case TargetOpcode::G_FMINNUM:
617+
case TargetOpcode::G_FMAXNUM:
618+
case TargetOpcode::G_FMINNUM_IEEE:
619+
case TargetOpcode::G_FMAXNUM_IEEE:
620+
case TargetOpcode::G_FMINIMUM:
621+
case TargetOpcode::G_FMAXIMUM:
622+
case TargetOpcode::G_FADD:
623+
case TargetOpcode::G_FSUB:
624+
case TargetOpcode::G_FMUL:
625+
case TargetOpcode::G_FDIV:
626+
case TargetOpcode::G_FPOW:
627+
// Logical.
628+
case TargetOpcode::G_AND:
629+
case TargetOpcode::G_OR:
630+
case TargetOpcode::G_XOR:
631+
return true;
632+
default:
633+
return false;
634+
}
635+
};
636+
};
637+
638+
/// Represents an integer binary operation.
639+
class GIntBinOp : public GBinOp {
640+
public:
641+
static bool classof(const MachineInstr *MI) {
642+
switch (MI->getOpcode()) {
643+
case TargetOpcode::G_ADD:
644+
case TargetOpcode::G_SUB:
645+
case TargetOpcode::G_MUL:
646+
case TargetOpcode::G_SDIV:
647+
case TargetOpcode::G_UDIV:
648+
case TargetOpcode::G_SREM:
649+
case TargetOpcode::G_UREM:
650+
case TargetOpcode::G_SMIN:
651+
case TargetOpcode::G_SMAX:
652+
case TargetOpcode::G_UMIN:
653+
case TargetOpcode::G_UMAX:
654+
return true;
655+
default:
656+
return false;
657+
}
658+
};
659+
};
660+
661+
/// Represents a floating point binary operation.
662+
class GFBinOp : public GBinOp {
663+
public:
664+
static bool classof(const MachineInstr *MI) {
665+
switch (MI->getOpcode()) {
666+
case TargetOpcode::G_FMINNUM:
667+
case TargetOpcode::G_FMAXNUM:
668+
case TargetOpcode::G_FMINNUM_IEEE:
669+
case TargetOpcode::G_FMAXNUM_IEEE:
670+
case TargetOpcode::G_FMINIMUM:
671+
case TargetOpcode::G_FMAXIMUM:
672+
case TargetOpcode::G_FADD:
673+
case TargetOpcode::G_FSUB:
674+
case TargetOpcode::G_FMUL:
675+
case TargetOpcode::G_FDIV:
676+
case TargetOpcode::G_FPOW:
677+
return true;
678+
default:
679+
return false;
680+
}
681+
};
682+
};
683+
684+
/// Represents a logical binary operation.
685+
class GLogicalBinOp : public GBinOp {
686+
public:
687+
static bool classof(const MachineInstr *MI) {
688+
switch (MI->getOpcode()) {
689+
case TargetOpcode::G_AND:
690+
case TargetOpcode::G_OR:
691+
case TargetOpcode::G_XOR:
692+
return true;
693+
default:
694+
return false;
695+
}
696+
};
697+
};
698+
699+
/// Represents an integer addition.
700+
class GAdd : public GIntBinOp {
701+
public:
702+
static bool classof(const MachineInstr *MI) {
703+
return MI->getOpcode() == TargetOpcode::G_ADD;
704+
};
705+
};
706+
707+
/// Represents a logical and.
708+
class GAnd : public GLogicalBinOp {
709+
public:
710+
static bool classof(const MachineInstr *MI) {
711+
return MI->getOpcode() == TargetOpcode::G_AND;
712+
};
713+
};
714+
715+
/// Represents a logical or.
716+
class GOr : public GLogicalBinOp {
717+
public:
718+
static bool classof(const MachineInstr *MI) {
719+
return MI->getOpcode() == TargetOpcode::G_OR;
720+
};
721+
};
722+
595723
} // namespace llvm
596724

597725
#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,18 @@ def match_selects : GICombineRule<
12411241
[{ return Helper.matchSelect(*${root}, ${matchinfo}); }]),
12421242
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
12431243

1244+
// Combine rule rooted at G_AND: CombinerHelper::matchAnd decides whether the
// instruction can be rewritten and, if so, records the rewrite in $matchinfo,
// which applyBuildFn then executes.
def match_ands : GICombineRule<
  (defs root:$root, build_fn_matchinfo:$matchinfo),
  (match
    (wip_match_opcode G_AND):$root,
    [{ return Helper.matchAnd(*${root}, ${matchinfo}); }]),
  (apply
    [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
1249+
1250+
// Combine rule rooted at G_OR: CombinerHelper::matchOr decides whether the
// instruction can be rewritten and, if so, records the rewrite in $matchinfo,
// which applyBuildFn then executes.
def match_ors : GICombineRule<
  (defs root:$root, build_fn_matchinfo:$matchinfo),
  (match
    (wip_match_opcode G_OR):$root,
    [{ return Helper.matchOr(*${root}, ${matchinfo}); }]),
  (apply
    [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
1255+
12441256
// FIXME: These should use the custom predicate feature once it lands.
12451257
def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
12461258
undef_to_negative_one,
@@ -1314,7 +1326,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
13141326
intdiv_combines, mulh_combines, redundant_neg_operands,
13151327
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
13161328
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
1317-
fsub_to_fneg, commute_constant_to_rhs]>;
1329+
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors]>;
13181330

13191331
// A combine group used to for prelegalizer combiners at -O0. The combines in
13201332
// this group have been selected based on experiments to balance code size and

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
#include "llvm/CodeGen/TargetInstrInfo.h"
2929
#include "llvm/CodeGen/TargetLowering.h"
3030
#include "llvm/CodeGen/TargetOpcodes.h"
31+
#include "llvm/IR/ConstantRange.h"
3132
#include "llvm/IR/DataLayout.h"
3233
#include "llvm/IR/InstrTypes.h"
3334
#include "llvm/Support/Casting.h"
3435
#include "llvm/Support/DivisionByConstantInfo.h"
36+
#include "llvm/Support/ErrorHandling.h"
3537
#include "llvm/Support/MathExtras.h"
3638
#include "llvm/Target/TargetMachine.h"
3739
#include <cmath>
@@ -6651,3 +6653,181 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
66516653

66526654
return false;
66536655
}
6656+
6657+
/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
6658+
/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
6659+
/// into a single comparison using range-based reasoning.
6660+
/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
6661+
bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
6662+
BuildFnTy &MatchInfo) {
6663+
assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
6664+
bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
6665+
Register DstReg = Logic->getReg(0);
6666+
Register LHS = Logic->getLHSReg();
6667+
Register RHS = Logic->getRHSReg();
6668+
unsigned Flags = Logic->getFlags();
6669+
6670+
// We need an G_ICMP on the LHS register.
6671+
GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
6672+
if (!Cmp1)
6673+
return false;
6674+
6675+
// We need an G_ICMP on the RHS register.
6676+
GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
6677+
if (!Cmp2)
6678+
return false;
6679+
6680+
// We want to fold the icmps.
6681+
if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
6682+
!MRI.hasOneNonDBGUse(Cmp2->getReg(0)))
6683+
return false;
6684+
6685+
APInt C1;
6686+
APInt C2;
6687+
std::optional<ValueAndVReg> MaybeC1 =
6688+
getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
6689+
if (!MaybeC1)
6690+
return false;
6691+
C1 = MaybeC1->Value;
6692+
6693+
std::optional<ValueAndVReg> MaybeC2 =
6694+
getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
6695+
if (!MaybeC2)
6696+
return false;
6697+
C2 = MaybeC2->Value;
6698+
6699+
Register R1 = Cmp1->getLHSReg();
6700+
Register R2 = Cmp2->getLHSReg();
6701+
CmpInst::Predicate Pred1 = Cmp1->getCond();
6702+
CmpInst::Predicate Pred2 = Cmp2->getCond();
6703+
LLT CmpTy = MRI.getType(Cmp1->getReg(0));
6704+
LLT CmpOperandTy = MRI.getType(R1);
6705+
6706+
// We build ands, adds, and constants of type CmpOperandTy.
6707+
// They must be legal to build.
6708+
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
6709+
!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
6710+
!isConstantLegalOrBeforeLegalizer(CmpOperandTy))
6711+
return false;
6712+
6713+
// Look through add of a constant offset on R1, R2, or both operands. This
6714+
// allows us to interpret the R + C' < C'' range idiom into a proper range.
6715+
std::optional<APInt> Offset1;
6716+
std::optional<APInt> Offset2;
6717+
if (R1 != R2) {
6718+
if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
6719+
std::optional<ValueAndVReg> MaybeOffset1 =
6720+
getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
6721+
if (MaybeOffset1) {
6722+
R1 = Add->getLHSReg();
6723+
Offset1 = MaybeOffset1->Value;
6724+
}
6725+
}
6726+
if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
6727+
std::optional<ValueAndVReg> MaybeOffset2 =
6728+
getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
6729+
if (MaybeOffset2) {
6730+
R2 = Add->getLHSReg();
6731+
Offset2 = MaybeOffset2->Value;
6732+
}
6733+
}
6734+
}
6735+
6736+
if (R1 != R2)
6737+
return false;
6738+
6739+
// We calculate the icmp ranges including maybe offsets.
6740+
ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
6741+
IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
6742+
if (Offset1)
6743+
CR1 = CR1.subtract(*Offset1);
6744+
6745+
ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
6746+
IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
6747+
if (Offset2)
6748+
CR2 = CR2.subtract(*Offset2);
6749+
6750+
bool CreateMask = false;
6751+
APInt LowerDiff;
6752+
std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
6753+
if (!CR) {
6754+
// We need non-wrapping ranges.
6755+
if (CR1.isWrappedSet() || CR2.isWrappedSet())
6756+
return false;
6757+
6758+
// Check whether we have equal-size ranges that only differ by one bit.
6759+
// In that case we can apply a mask to map one range onto the other.
6760+
LowerDiff = CR1.getLower() ^ CR2.getLower();
6761+
APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
6762+
APInt CR1Size = CR1.getUpper() - CR1.getLower();
6763+
if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
6764+
CR1Size != CR2.getUpper() - CR2.getLower())
6765+
return false;
6766+
6767+
CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
6768+
CreateMask = true;
6769+
}
6770+
6771+
if (IsAnd)
6772+
CR = CR->inverse();
6773+
6774+
CmpInst::Predicate NewPred;
6775+
APInt NewC, Offset;
6776+
CR->getEquivalentICmp(NewPred, NewC, Offset);
6777+
6778+
// We take the result type of one of the original icmps, CmpTy, for
6779+
// the to be build icmp. The operand type, CmpOperandTy, is used for
6780+
// the other instructions and constants to be build. The types of
6781+
// the parameters and output are the same for add and and. CmpTy
6782+
// and the type of DstReg might differ. That is why we zext or trunc
6783+
// the icmp into the destination register.
6784+
6785+
MatchInfo = [=](MachineIRBuilder &B) {
6786+
if (CreateMask && Offset != 0) {
6787+
auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
6788+
auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
6789+
auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
6790+
auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags);
6791+
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
6792+
auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
6793+
B.buildZExtOrTrunc(DstReg, ICmp);
6794+
} else if (CreateMask && Offset == 0) {
6795+
auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
6796+
auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
6797+
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
6798+
auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
6799+
B.buildZExtOrTrunc(DstReg, ICmp);
6800+
} else if (!CreateMask && Offset != 0) {
6801+
auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
6802+
auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
6803+
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
6804+
auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
6805+
B.buildZExtOrTrunc(DstReg, ICmp);
6806+
} else if (!CreateMask && Offset == 0) {
6807+
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
6808+
auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
6809+
B.buildZExtOrTrunc(DstReg, ICmp);
6810+
} else {
6811+
llvm_unreachable("unexpected configuration of CreateMask and Offset");
6812+
}
6813+
};
6814+
return true;
6815+
}
6816+
6817+
/// Match combines rooted at a G_AND.
///
/// \param MI        The root instruction; it is cast to GAnd.
/// \param MatchInfo Receives the build callback when a fold is found.
/// \returns true iff the range-based fold of two icmps applies.
bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) {
  auto *AndMI = cast<GAnd>(&MI);

  // Currently the only fold tried here is the range-based icmp merge.
  return tryFoldAndOrOrICmpsUsingRanges(AndMI, MatchInfo);
}
6825+
6826+
/// Match combines rooted at a G_OR.
///
/// \param MI        The root instruction; it is cast to GOr.
/// \param MatchInfo Receives the build callback when a fold is found.
/// \returns true iff the range-based fold of two icmps applies.
bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
  auto *OrMI = cast<GOr>(&MI);

  // Currently the only fold tried here is the range-based icmp merge.
  return tryFoldAndOrOrICmpsUsingRanges(OrMI, MatchInfo);
}

0 commit comments

Comments
 (0)