Skip to content

Commit 3a106e5

Browse files
authored
[GlobalISel] Fold G_ICMP if possible (#86357)
This patch tries to fold `G_ICMP` if possible.
1 parent 360f7f5 commit 3a106e5

24 files changed

+694
-393
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ std::optional<SmallVector<unsigned>>
315315
ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
316316
std::function<unsigned(APInt)> CB);
317317

318+
std::optional<SmallVector<APInt>>
319+
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
320+
const MachineRegisterInfo &MRI);
321+
318322
/// Test if the given value is known to have exactly one bit set. This differs
319323
/// from computeKnownBits in that it doesn't necessarily determine which bit is
320324
/// set.

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
174174
switch (Opc) {
175175
default:
176176
break;
177+
case TargetOpcode::G_ICMP: {
178+
assert(SrcOps.size() == 3 && "Invalid sources");
179+
assert(DstOps.size() == 1 && "Invalid dsts");
180+
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
181+
182+
if (std::optional<SmallVector<APInt>> Cst =
183+
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
184+
SrcOps[2].getReg(), *getMRI())) {
185+
if (SrcTy.isVector())
186+
return buildBuildVectorConstant(DstOps[0], *Cst);
187+
return buildConstant(DstOps[0], Cst->front());
188+
}
189+
break;
190+
}
177191
case TargetOpcode::G_ADD:
178192
case TargetOpcode::G_PTR_ADD:
179193
case TargetOpcode::G_AND:

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3768,9 +3768,11 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
37683768
}
37693769
case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
37703770
auto [OldValRes, SuccessRes, Addr, CmpVal, NewVal] = MI.getFirst5Regs();
3771-
MIRBuilder.buildAtomicCmpXchg(OldValRes, Addr, CmpVal, NewVal,
3771+
Register NewOldValRes = MRI.cloneVirtualRegister(OldValRes);
3772+
MIRBuilder.buildAtomicCmpXchg(NewOldValRes, Addr, CmpVal, NewVal,
37723773
**MI.memoperands_begin());
3773-
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, OldValRes, CmpVal);
3774+
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, NewOldValRes, CmpVal);
3775+
MIRBuilder.buildCopy(OldValRes, NewOldValRes);
37743776
MI.eraseFromParent();
37753777
return Legalized;
37763778
}
@@ -3789,8 +3791,12 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
37893791
case G_UADDO: {
37903792
auto [Res, CarryOut, LHS, RHS] = MI.getFirst4Regs();
37913793

3792-
MIRBuilder.buildAdd(Res, LHS, RHS);
3793-
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, RHS);
3794+
Register NewRes = MRI.cloneVirtualRegister(Res);
3795+
3796+
MIRBuilder.buildAdd(NewRes, LHS, RHS);
3797+
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, NewRes, RHS);
3798+
3799+
MIRBuilder.buildCopy(Res, NewRes);
37943800

37953801
MI.eraseFromParent();
37963802
return Legalized;
@@ -3800,6 +3806,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
38003806
const LLT CondTy = MRI.getType(CarryOut);
38013807
const LLT Ty = MRI.getType(Res);
38023808

3809+
Register NewRes = MRI.cloneVirtualRegister(Res);
3810+
38033811
// Initial add of the two operands.
38043812
auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);
38053813

@@ -3808,15 +3816,18 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
38083816

38093817
// Add the sum and the carry.
38103818
auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
3811-
MIRBuilder.buildAdd(Res, TmpRes, ZExtCarryIn);
3819+
MIRBuilder.buildAdd(NewRes, TmpRes, ZExtCarryIn);
38123820

38133821
// Second check for carry. We can only carry if the initial sum is all 1s
38143822
// and the carry is set, resulting in a new sum of 0.
38153823
auto Zero = MIRBuilder.buildConstant(Ty, 0);
3816-
auto ResEqZero = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, Res, Zero);
3824+
auto ResEqZero =
3825+
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, NewRes, Zero);
38173826
auto Carry2 = MIRBuilder.buildAnd(CondTy, ResEqZero, CarryIn);
38183827
MIRBuilder.buildOr(CarryOut, Carry, Carry2);
38193828

3829+
MIRBuilder.buildCopy(Res, NewRes);
3830+
38203831
MI.eraseFromParent();
38213832
return Legalized;
38223833
}
@@ -7671,10 +7682,12 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
76717682
LLT Ty = Dst0Ty;
76727683
LLT BoolTy = Dst1Ty;
76737684

7685+
Register NewDst0 = MRI.cloneVirtualRegister(Dst0);
7686+
76747687
if (IsAdd)
7675-
MIRBuilder.buildAdd(Dst0, LHS, RHS);
7688+
MIRBuilder.buildAdd(NewDst0, LHS, RHS);
76767689
else
7677-
MIRBuilder.buildSub(Dst0, LHS, RHS);
7690+
MIRBuilder.buildSub(NewDst0, LHS, RHS);
76787691

76797692
// TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.
76807693

@@ -7687,12 +7700,15 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
76877700
// (LHS) if and only if the other operand (RHS) is (non-zero) positive,
76887701
// otherwise there will be overflow.
76897702
auto ResultLowerThanLHS =
7690-
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, Dst0, LHS);
7703+
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, NewDst0, LHS);
76917704
auto ConditionRHS = MIRBuilder.buildICmp(
76927705
IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);
76937706

76947707
MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);
7708+
7709+
MIRBuilder.buildCopy(Dst0, NewDst0);
76957710
MI.eraseFromParent();
7711+
76967712
return Legalized;
76977713
}
76987714

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,74 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
997997
return std::nullopt;
998998
}
999999

1000+
std::optional<SmallVector<APInt>>
1001+
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1002+
const MachineRegisterInfo &MRI) {
1003+
LLT Ty = MRI.getType(Op1);
1004+
if (Ty != MRI.getType(Op2))
1005+
return std::nullopt;
1006+
1007+
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1008+
Register RHS) -> std::optional<APInt> {
1009+
auto LHSCst = getIConstantVRegVal(LHS, MRI);
1010+
auto RHSCst = getIConstantVRegVal(RHS, MRI);
1011+
if (!LHSCst || !RHSCst)
1012+
return std::nullopt;
1013+
1014+
switch (Pred) {
1015+
case CmpInst::Predicate::ICMP_EQ:
1016+
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
1017+
case CmpInst::Predicate::ICMP_NE:
1018+
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
1019+
case CmpInst::Predicate::ICMP_UGT:
1020+
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
1021+
case CmpInst::Predicate::ICMP_UGE:
1022+
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
1023+
case CmpInst::Predicate::ICMP_ULT:
1024+
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
1025+
case CmpInst::Predicate::ICMP_ULE:
1026+
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
1027+
case CmpInst::Predicate::ICMP_SGT:
1028+
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
1029+
case CmpInst::Predicate::ICMP_SGE:
1030+
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
1031+
case CmpInst::Predicate::ICMP_SLT:
1032+
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
1033+
case CmpInst::Predicate::ICMP_SLE:
1034+
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
1035+
default:
1036+
return std::nullopt;
1037+
}
1038+
};
1039+
1040+
SmallVector<APInt> FoldedICmps;
1041+
1042+
if (Ty.isVector()) {
1043+
// Try to constant fold each element.
1044+
auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
1045+
auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
1046+
if (!BV1 || !BV2)
1047+
return std::nullopt;
1048+
assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
1049+
for (unsigned I = 0; I < BV1->getNumSources(); ++I) {
1050+
if (auto MaybeFold =
1051+
TryFoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I))) {
1052+
FoldedICmps.emplace_back(*MaybeFold);
1053+
continue;
1054+
}
1055+
return std::nullopt;
1056+
}
1057+
return FoldedICmps;
1058+
}
1059+
1060+
if (auto MaybeCst = TryFoldScalar(Op1, Op2)) {
1061+
FoldedICmps.emplace_back(*MaybeCst);
1062+
return FoldedICmps;
1063+
}
1064+
1065+
return std::nullopt;
1066+
}
1067+
10001068
bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
10011069
GISelKnownBits *KB) {
10021070
std::optional<DefinitionAndSourceRegister> DefSrcReg =

0 commit comments

Comments
 (0)