Skip to content

[GlobalISel] Take the result size into account when const folding icmp #134365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,

/// Tries to constant fold an integer comparison of \p Op1 and \p Op2 using
/// the predicate \p Pred.
///
/// Every folded value is an APInt that is \p DstScalarSizeInBits bits wide:
///  - a false comparison folds to zero;
///  - a true comparison folds to all-ones when \p ExtOp is G_SEXT, and to the
///    value 1 (only bit 0 set) otherwise.
/// \p ExtOp must be one of G_SEXT, G_ZEXT or G_ANYEXT (asserted by the
/// definition).
///
/// \returns std::nullopt when an operand is not a known integer constant or
/// when \p Pred is not one of the ICMP_EQ/NE/UGT/UGE/ULT/ULE/SGT/SGE/SLT/SLE
/// predicates.
///
/// NOTE(review): presumably the returned vector holds one element for scalar
/// operands and one element per lane for vector operands -- confirm against
/// the definition in Utils.cpp.
std::optional<SmallVector<APInt>>
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
unsigned DstScalarSizeInBits, unsigned ExtOp,
const MachineRegisterInfo &MRI);

/// Test if the given value is known to have exactly one bit set. This differs
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
assert(SrcOps.size() == 3 && "Invalid sources");
assert(DstOps.size() == 1 && "Invalid dsts");
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
LLT DstTy = DstOps[0].getLLTTy(*getMRI());
auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);

if (std::optional<SmallVector<APInt>> Cst =
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
SrcOps[2].getReg(), *getMRI())) {
if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
if (SrcTy.isVector())
return buildBuildVectorConstant(DstOps[0], *Cst);
return buildConstant(DstOps[0], Cst->front());
Expand Down
45 changes: 28 additions & 17 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,39 +1027,50 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,

std::optional<SmallVector<APInt>>
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
unsigned DstScalarSizeInBits, unsigned ExtOp,
const MachineRegisterInfo &MRI) {
LLT Ty = MRI.getType(Op1);
if (Ty != MRI.getType(Op2))
return std::nullopt;
assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
ExtOp == TargetOpcode::G_ANYEXT);

auto TryFoldScalar = [&MRI, Pred](Register LHS,
Register RHS) -> std::optional<APInt> {
auto LHSCst = getIConstantVRegVal(LHS, MRI);
const LLT Ty = MRI.getType(Op1);

auto GetICmpResultCst = [&](bool IsTrue) {
if (IsTrue)
return ExtOp == TargetOpcode::G_SEXT
? APInt::getAllOnes(DstScalarSizeInBits)
: APInt::getOneBitSet(DstScalarSizeInBits, 0);
return APInt::getZero(DstScalarSizeInBits);
};

auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
auto RHSCst = getIConstantVRegVal(RHS, MRI);
if (!LHSCst || !RHSCst)
if (!RHSCst)
return std::nullopt;
auto LHSCst = getIConstantVRegVal(LHS, MRI);
if (!LHSCst)
return std::nullopt;

switch (Pred) {
case CmpInst::Predicate::ICMP_EQ:
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
return GetICmpResultCst(LHSCst->eq(*RHSCst));
case CmpInst::Predicate::ICMP_NE:
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
return GetICmpResultCst(LHSCst->ne(*RHSCst));
case CmpInst::Predicate::ICMP_UGT:
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
return GetICmpResultCst(LHSCst->ugt(*RHSCst));
case CmpInst::Predicate::ICMP_UGE:
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
return GetICmpResultCst(LHSCst->uge(*RHSCst));
case CmpInst::Predicate::ICMP_ULT:
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
return GetICmpResultCst(LHSCst->ult(*RHSCst));
case CmpInst::Predicate::ICMP_ULE:
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
return GetICmpResultCst(LHSCst->ule(*RHSCst));
case CmpInst::Predicate::ICMP_SGT:
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
return GetICmpResultCst(LHSCst->sgt(*RHSCst));
case CmpInst::Predicate::ICMP_SGE:
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
return GetICmpResultCst(LHSCst->sge(*RHSCst));
case CmpInst::Predicate::ICMP_SLT:
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
return GetICmpResultCst(LHSCst->slt(*RHSCst));
case CmpInst::Predicate::ICMP_SLE:
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
return GetICmpResultCst(LHSCst->sle(*RHSCst));
default:
return std::nullopt;
}
Expand Down
45 changes: 45 additions & 0 deletions llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
}

{
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1);
}

{
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check a vector case?

LLT VecTy = LLT::fixed_vector(2, s32);
LLT DstTy = LLT::fixed_vector(2, s1);
auto Three = CSEB.buildConstant(s32, 3);
Expand All @@ -508,6 +520,8 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
auto OneTwo = CSEB.buildBuildVector(VecTy, {One.getReg(0), Two.getReg(0)});
auto TwoThree =
CSEB.buildBuildVector(VecTy, {Two.getReg(0), Three.getReg(0)});
auto OneThree =
CSEB.buildBuildVector(VecTy, {One.getReg(0), Three.getReg(0)});
auto MinusOneOne =
CSEB.buildBuildVector(VecTy, {MinusOne.getReg(0), MinusOne.getReg(0)});
auto MinusOneTwo =
Expand Down Expand Up @@ -547,6 +561,36 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
// ICMP_SLE
CSEB.buildICmp(CmpInst::Predicate::ICMP_SLE, DstTy, MinusOneTwo, MinusOneOne);

{
auto I =
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneOne, TwoThree);
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
EXPECT_EQ(HiCst.getSExtValue(), 0);
EXPECT_EQ(LoCst.getSExtValue(), 0);
}

{
auto I =
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneThree, TwoThree);
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
EXPECT_EQ(HiCst.getSExtValue(), 0);
EXPECT_EQ(LoCst.getSExtValue(), -1);
}

{
auto I =
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, TwoThree, TwoThree);
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
EXPECT_EQ(HiCst.getSExtValue(), -1);
EXPECT_EQ(LoCst.getSExtValue(), -1);
}

auto CheckStr = R"(
; CHECK: [[One:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
; CHECK: [[Two:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
Expand All @@ -558,6 +602,7 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[One]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Two]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[Two]]:_(s32), [[Three]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Three]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusOne]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusTwo]]:_(s32)
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusTwo]]:_(s32), [[MinusThree]]:_(s32)
Expand Down
Loading