Skip to content

Commit 4eb7a81

Browse files
committed
[GlobalISel] Fold G_ICMP if possible
This patch tries to fold `G_ICMP` if possible.
1 parent 20e0bac commit 4eb7a81

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ std::optional<APFloat> ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy,
313313
std::optional<SmallVector<unsigned>>
314314
ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI);
315315

316+
std::optional<SmallVector<APInt>>
317+
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
318+
const MachineRegisterInfo &MRI);
319+
316320
/// Test if the given value is known to have exactly one bit set. This differs
317321
/// from computeKnownBits in that it doesn't necessarily determine which bit is
318322
/// set.

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
174174
switch (Opc) {
175175
default:
176176
break;
177+
case TargetOpcode::G_ICMP: {
178+
assert(SrcOps.size() == 3 && "Invalid sources");
179+
assert(DstOps.size() == 1 && "Invalid dsts");
180+
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
181+
182+
if (std::optional<SmallVector<APInt>> Cst =
183+
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
184+
SrcOps[2].getReg(), *getMRI())) {
185+
if (SrcTy.isVector())
186+
return buildBuildVectorConstant(DstOps[0], *Cst);
187+
return buildConstant(DstOps[0], Cst->front());
188+
}
189+
break;
190+
}
177191
case TargetOpcode::G_ADD:
178192
case TargetOpcode::G_PTR_ADD:
179193
case TargetOpcode::G_AND:

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,74 @@ llvm::ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI) {
996996
return std::nullopt;
997997
}
998998

999+
std::optional<SmallVector<APInt>>
1000+
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1001+
const MachineRegisterInfo &MRI) {
1002+
LLT Ty = MRI.getType(Op1);
1003+
if (Ty != MRI.getType(Op2))
1004+
return std::nullopt;
1005+
1006+
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1007+
Register RHS) -> std::optional<APInt> {
1008+
auto LHSCst = getIConstantVRegVal(LHS, MRI);
1009+
auto RHSCst = getIConstantVRegVal(RHS, MRI);
1010+
if (!LHSCst || !RHSCst)
1011+
return std::nullopt;
1012+
1013+
switch (Pred) {
1014+
case CmpInst::Predicate::ICMP_EQ:
1015+
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
1016+
case CmpInst::Predicate::ICMP_NE:
1017+
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
1018+
case CmpInst::Predicate::ICMP_UGT:
1019+
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
1020+
case CmpInst::Predicate::ICMP_UGE:
1021+
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
1022+
case CmpInst::Predicate::ICMP_ULT:
1023+
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
1024+
case CmpInst::Predicate::ICMP_ULE:
1025+
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
1026+
case CmpInst::Predicate::ICMP_SGT:
1027+
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
1028+
case CmpInst::Predicate::ICMP_SGE:
1029+
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
1030+
case CmpInst::Predicate::ICMP_SLT:
1031+
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
1032+
case CmpInst::Predicate::ICMP_SLE:
1033+
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
1034+
default:
1035+
return std::nullopt;
1036+
}
1037+
};
1038+
1039+
SmallVector<APInt> FoldedICmps;
1040+
1041+
if (Ty.isVector()) {
1042+
// Try to constant fold each element.
1043+
auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
1044+
auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
1045+
if (!BV1 || !BV2)
1046+
return std::nullopt;
1047+
assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
1048+
for (unsigned I = 0; I < BV1->getNumSources(); ++I) {
1049+
if (auto MaybeFold =
1050+
TryFoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I))) {
1051+
FoldedICmps.emplace_back(*MaybeFold);
1052+
continue;
1053+
}
1054+
return std::nullopt;
1055+
}
1056+
return FoldedICmps;
1057+
}
1058+
1059+
if (auto MaybeCst = TryFoldScalar(Op1, Op2)) {
1060+
FoldedICmps.emplace_back(*MaybeCst);
1061+
return FoldedICmps;
1062+
}
1063+
1064+
return std::nullopt;
1065+
}
1066+
9991067
bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
10001068
GISelKnownBits *KB) {
10011069
std::optional<DefinitionAndSourceRegister> DefSrcReg =

0 commit comments

Comments
 (0)