Skip to content

Commit 464fed3

Browse files
committed
[GlobalISel] Take the result size into account when const folding icmp
1 parent 4da5e9d commit 464fed3

File tree

4 files changed: +36 additions, -15 deletions

4 files changed

+36
-15
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
325325

326326
std::optional<SmallVector<APInt>>
327327
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
328+
unsigned DstSizeInBits, unsigned ExtOp,
328329
const MachineRegisterInfo &MRI);
329330

330331
/// Test if the given value is known to have exactly one bit set. This differs

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
189189
assert(SrcOps.size() == 3 && "Invalid sources");
190190
assert(DstOps.size() == 1 && "Invalid dsts");
191191
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
192+
LLT DstTy = DstOps[0].getLLTTy(*getMRI());
193+
auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
192194

193-
if (std::optional<SmallVector<APInt>> Cst =
194-
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
195-
SrcOps[2].getReg(), *getMRI())) {
195+
if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
196+
SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
197+
DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
196198
if (SrcTy.isVector())
197199
return buildBuildVectorConstant(DstOps[0], *Cst);
198200
return buildConstant(DstOps[0], Cst->front());

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1027,39 +1027,45 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
10271027

10281028
std::optional<SmallVector<APInt>>
10291029
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1030+
unsigned DstSizeInBits, unsigned ExtOp,
10301031
const MachineRegisterInfo &MRI) {
1032+
assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
1033+
ExtOp == TargetOpcode::G_ANYEXT);
1034+
10311035
LLT Ty = MRI.getType(Op1);
10321036
if (Ty != MRI.getType(Op2))
10331037
return std::nullopt;
10341038

1035-
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1036-
Register RHS) -> std::optional<APInt> {
1039+
const int64_t Sign = ExtOp == TargetOpcode::G_SEXT ? -1 : 1;
1040+
1041+
auto TryFoldScalar = [&MRI, Pred, DstSizeInBits, Sign](
1042+
Register LHS, Register RHS) -> std::optional<APInt> {
10371043
auto LHSCst = getIConstantVRegVal(LHS, MRI);
10381044
auto RHSCst = getIConstantVRegVal(RHS, MRI);
10391045
if (!LHSCst || !RHSCst)
10401046
return std::nullopt;
10411047

10421048
switch (Pred) {
10431049
case CmpInst::Predicate::ICMP_EQ:
1044-
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
1050+
return APInt(DstSizeInBits, Sign * LHSCst->eq(*RHSCst), true);
10451051
case CmpInst::Predicate::ICMP_NE:
1046-
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
1052+
return APInt(DstSizeInBits, Sign * LHSCst->ne(*RHSCst), true);
10471053
case CmpInst::Predicate::ICMP_UGT:
1048-
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
1054+
return APInt(DstSizeInBits, Sign * LHSCst->ugt(*RHSCst), true);
10491055
case CmpInst::Predicate::ICMP_UGE:
1050-
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
1056+
return APInt(DstSizeInBits, Sign * LHSCst->uge(*RHSCst), true);
10511057
case CmpInst::Predicate::ICMP_ULT:
1052-
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
1058+
return APInt(DstSizeInBits, Sign * LHSCst->ult(*RHSCst), true);
10531059
case CmpInst::Predicate::ICMP_ULE:
1054-
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
1060+
return APInt(DstSizeInBits, Sign * LHSCst->ule(*RHSCst), true);
10551061
case CmpInst::Predicate::ICMP_SGT:
1056-
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
1062+
return APInt(DstSizeInBits, Sign * LHSCst->sgt(*RHSCst), true);
10571063
case CmpInst::Predicate::ICMP_SGE:
1058-
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
1064+
return APInt(DstSizeInBits, Sign * LHSCst->sge(*RHSCst), true);
10591065
case CmpInst::Predicate::ICMP_SLT:
1060-
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
1066+
return APInt(DstSizeInBits, Sign * LHSCst->slt(*RHSCst), true);
10611067
case CmpInst::Predicate::ICMP_SLE:
1062-
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
1068+
return APInt(DstSizeInBits, Sign * LHSCst->sle(*RHSCst), true);
10631069
default:
10641070
return std::nullopt;
10651071
}

llvm/unittests/CodeGen/GlobalISel/CSETest.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
500500
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
501501
}
502502

503+
{
504+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
505+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
506+
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
507+
}
508+
509+
{
510+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
511+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
512+
EXPECT_FALSE(I->getOperand(1).getCImm()->getZExtValue());
513+
}
514+
503515
LLT VecTy = LLT::fixed_vector(2, s32);
504516
LLT DstTy = LLT::fixed_vector(2, s1);
505517
auto Three = CSEB.buildConstant(s32, 3);

0 commit comments

Comments
 (0)