Skip to content

Commit 99807b1

Browse files
committed
[GlobalISel] Take the result size into account when const folding icmp
1 parent 4da5e9d commit 99807b1

File tree

4 files changed: 39 additions and 18 deletions.

4 files changed: 39 additions and 18 deletions.

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
325325

326326
std::optional<SmallVector<APInt>>
327327
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
328+
unsigned DstScalarSizeInBits, unsigned ExtOp,
328329
const MachineRegisterInfo &MRI);
329330

330331
/// Test if the given value is known to have exactly one bit set. This differs

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
189189
assert(SrcOps.size() == 3 && "Invalid sources");
190190
assert(DstOps.size() == 1 && "Invalid dsts");
191191
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
192+
LLT DstTy = DstOps[0].getLLTTy(*getMRI());
193+
auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
192194

193-
if (std::optional<SmallVector<APInt>> Cst =
194-
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
195-
SrcOps[2].getReg(), *getMRI())) {
195+
if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
196+
SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
197+
DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
196198
if (SrcTy.isVector())
197199
return buildBuildVectorConstant(DstOps[0], *Cst);
198200
return buildConstant(DstOps[0], Cst->front());

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1027,39 +1027,45 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
10271027

10281028
std::optional<SmallVector<APInt>>
10291029
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1030+
unsigned DstScalarSizeInBits, unsigned ExtOp,
10301031
const MachineRegisterInfo &MRI) {
1031-
LLT Ty = MRI.getType(Op1);
1032-
if (Ty != MRI.getType(Op2))
1033-
return std::nullopt;
1032+
assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
1033+
ExtOp == TargetOpcode::G_ANYEXT);
1034+
1035+
const LLT Ty = MRI.getType(Op1);
10341036

1035-
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1036-
Register RHS) -> std::optional<APInt> {
1037+
auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
10371038
auto LHSCst = getIConstantVRegVal(LHS, MRI);
10381039
auto RHSCst = getIConstantVRegVal(RHS, MRI);
10391040
if (!LHSCst || !RHSCst)
10401041
return std::nullopt;
10411042

1043+
const APInt FalseCst = APInt::getZero(DstScalarSizeInBits);
1044+
const APInt TrueCst = (ExtOp == TargetOpcode::G_SEXT)
1045+
? APInt::getAllOnes(DstScalarSizeInBits)
1046+
: APInt::getOneBitSet(DstScalarSizeInBits, 0);
1047+
10421048
switch (Pred) {
10431049
case CmpInst::Predicate::ICMP_EQ:
1044-
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
1050+
return LHSCst->eq(*RHSCst) ? TrueCst : FalseCst;
10451051
case CmpInst::Predicate::ICMP_NE:
1046-
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
1052+
return LHSCst->ne(*RHSCst) ? TrueCst : FalseCst;
10471053
case CmpInst::Predicate::ICMP_UGT:
1048-
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
1054+
return LHSCst->ugt(*RHSCst) ? TrueCst : FalseCst;
10491055
case CmpInst::Predicate::ICMP_UGE:
1050-
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
1056+
return LHSCst->uge(*RHSCst) ? TrueCst : FalseCst;
10511057
case CmpInst::Predicate::ICMP_ULT:
1052-
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
1058+
return LHSCst->ult(*RHSCst) ? TrueCst : FalseCst;
10531059
case CmpInst::Predicate::ICMP_ULE:
1054-
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
1060+
return LHSCst->ule(*RHSCst) ? TrueCst : FalseCst;
10551061
case CmpInst::Predicate::ICMP_SGT:
1056-
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
1062+
return LHSCst->sgt(*RHSCst) ? TrueCst : FalseCst;
10571063
case CmpInst::Predicate::ICMP_SGE:
1058-
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
1064+
return LHSCst->sge(*RHSCst) ? TrueCst : FalseCst;
10591065
case CmpInst::Predicate::ICMP_SLT:
1060-
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
1066+
return LHSCst->slt(*RHSCst) ? TrueCst : FalseCst;
10611067
case CmpInst::Predicate::ICMP_SLE:
1062-
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
1068+
return LHSCst->sle(*RHSCst) ? TrueCst : FalseCst;
10631069
default:
10641070
return std::nullopt;
10651071
}

llvm/unittests/CodeGen/GlobalISel/CSETest.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
500500
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
501501
}
502502

503+
{
504+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
505+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
506+
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1);
507+
}
508+
509+
{
510+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
511+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
512+
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0);
513+
}
514+
503515
LLT VecTy = LLT::fixed_vector(2, s32);
504516
LLT DstTy = LLT::fixed_vector(2, s1);
505517
auto Three = CSEB.buildConstant(s32, 3);

0 commit comments

Comments
 (0)