Skip to content

Commit 0926d94

Browse files
authored
[GlobalISel] Take the result size into account when const folding icmp (#134365)
The current implementation always creates a 1-bit constant for the result of the `G_ICMP`, which causes issues if the destination register size is larger than that. With asserts enabled, it causes a crash in `buildConstant`:
```
llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp:322: virtual MachineInstrBuilder llvm::MachineIRBuilder::buildConstant(const DstOp &, const ConstantInt &): Assertion `EltTy.getScalarSizeInBits() == Val.getBitWidth() && "creating constant with the wrong size"' failed.
```
1 parent 1ba89ad commit 0926d94

File tree

4 files changed

+79
-20
lines changed

4 files changed

+79
-20
lines changed

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,7 @@ ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
325325

326326
std::optional<SmallVector<APInt>>
327327
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
328+
unsigned DstScalarSizeInBits, unsigned ExtOp,
328329
const MachineRegisterInfo &MRI);
329330

330331
/// Test if the given value is known to have exactly one bit set. This differs

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -189,10 +189,12 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
189189
assert(SrcOps.size() == 3 && "Invalid sources");
190190
assert(DstOps.size() == 1 && "Invalid dsts");
191191
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
192+
LLT DstTy = DstOps[0].getLLTTy(*getMRI());
193+
auto BoolExtOp = getBoolExtOp(SrcTy.isVector(), false);
192194

193-
if (std::optional<SmallVector<APInt>> Cst =
194-
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
195-
SrcOps[2].getReg(), *getMRI())) {
195+
if (std::optional<SmallVector<APInt>> Cst = ConstantFoldICmp(
196+
SrcOps[0].getPredicate(), SrcOps[1].getReg(), SrcOps[2].getReg(),
197+
DstTy.getScalarSizeInBits(), BoolExtOp, *getMRI())) {
196198
if (SrcTy.isVector())
197199
return buildBuildVectorConstant(DstOps[0], *Cst);
198200
return buildConstant(DstOps[0], Cst->front());

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 28 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1027,39 +1027,50 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
10271027

10281028
std::optional<SmallVector<APInt>>
10291029
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1030+
unsigned DstScalarSizeInBits, unsigned ExtOp,
10301031
const MachineRegisterInfo &MRI) {
1031-
LLT Ty = MRI.getType(Op1);
1032-
if (Ty != MRI.getType(Op2))
1033-
return std::nullopt;
1032+
assert(ExtOp == TargetOpcode::G_SEXT || ExtOp == TargetOpcode::G_ZEXT ||
1033+
ExtOp == TargetOpcode::G_ANYEXT);
10341034

1035-
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1036-
Register RHS) -> std::optional<APInt> {
1037-
auto LHSCst = getIConstantVRegVal(LHS, MRI);
1035+
const LLT Ty = MRI.getType(Op1);
1036+
1037+
auto GetICmpResultCst = [&](bool IsTrue) {
1038+
if (IsTrue)
1039+
return ExtOp == TargetOpcode::G_SEXT
1040+
? APInt::getAllOnes(DstScalarSizeInBits)
1041+
: APInt::getOneBitSet(DstScalarSizeInBits, 0);
1042+
return APInt::getZero(DstScalarSizeInBits);
1043+
};
1044+
1045+
auto TryFoldScalar = [&](Register LHS, Register RHS) -> std::optional<APInt> {
10381046
auto RHSCst = getIConstantVRegVal(RHS, MRI);
1039-
if (!LHSCst || !RHSCst)
1047+
if (!RHSCst)
1048+
return std::nullopt;
1049+
auto LHSCst = getIConstantVRegVal(LHS, MRI);
1050+
if (!LHSCst)
10401051
return std::nullopt;
10411052

10421053
switch (Pred) {
10431054
case CmpInst::Predicate::ICMP_EQ:
1044-
return APInt(/*numBits=*/1, LHSCst->eq(*RHSCst));
1055+
return GetICmpResultCst(LHSCst->eq(*RHSCst));
10451056
case CmpInst::Predicate::ICMP_NE:
1046-
return APInt(/*numBits=*/1, LHSCst->ne(*RHSCst));
1057+
return GetICmpResultCst(LHSCst->ne(*RHSCst));
10471058
case CmpInst::Predicate::ICMP_UGT:
1048-
return APInt(/*numBits=*/1, LHSCst->ugt(*RHSCst));
1059+
return GetICmpResultCst(LHSCst->ugt(*RHSCst));
10491060
case CmpInst::Predicate::ICMP_UGE:
1050-
return APInt(/*numBits=*/1, LHSCst->uge(*RHSCst));
1061+
return GetICmpResultCst(LHSCst->uge(*RHSCst));
10511062
case CmpInst::Predicate::ICMP_ULT:
1052-
return APInt(/*numBits=*/1, LHSCst->ult(*RHSCst));
1063+
return GetICmpResultCst(LHSCst->ult(*RHSCst));
10531064
case CmpInst::Predicate::ICMP_ULE:
1054-
return APInt(/*numBits=*/1, LHSCst->ule(*RHSCst));
1065+
return GetICmpResultCst(LHSCst->ule(*RHSCst));
10551066
case CmpInst::Predicate::ICMP_SGT:
1056-
return APInt(/*numBits=*/1, LHSCst->sgt(*RHSCst));
1067+
return GetICmpResultCst(LHSCst->sgt(*RHSCst));
10571068
case CmpInst::Predicate::ICMP_SGE:
1058-
return APInt(/*numBits=*/1, LHSCst->sge(*RHSCst));
1069+
return GetICmpResultCst(LHSCst->sge(*RHSCst));
10591070
case CmpInst::Predicate::ICMP_SLT:
1060-
return APInt(/*numBits=*/1, LHSCst->slt(*RHSCst));
1071+
return GetICmpResultCst(LHSCst->slt(*RHSCst));
10611072
case CmpInst::Predicate::ICMP_SLE:
1062-
return APInt(/*numBits=*/1, LHSCst->sle(*RHSCst));
1073+
return GetICmpResultCst(LHSCst->sle(*RHSCst));
10631074
default:
10641075
return std::nullopt;
10651076
}

llvm/unittests/CodeGen/GlobalISel/CSETest.cpp

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -500,6 +500,18 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
500500
EXPECT_TRUE(I->getOperand(1).getCImm()->getZExtValue());
501501
}
502502

503+
{
504+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, One);
505+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
506+
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 1);
507+
}
508+
509+
{
510+
auto I = CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, s32, One, Two);
511+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_CONSTANT);
512+
EXPECT_EQ(I->getOperand(1).getCImm()->getZExtValue(), 0);
513+
}
514+
503515
LLT VecTy = LLT::fixed_vector(2, s32);
504516
LLT DstTy = LLT::fixed_vector(2, s1);
505517
auto Three = CSEB.buildConstant(s32, 3);
@@ -508,6 +520,8 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
508520
auto OneTwo = CSEB.buildBuildVector(VecTy, {One.getReg(0), Two.getReg(0)});
509521
auto TwoThree =
510522
CSEB.buildBuildVector(VecTy, {Two.getReg(0), Three.getReg(0)});
523+
auto OneThree =
524+
CSEB.buildBuildVector(VecTy, {One.getReg(0), Three.getReg(0)});
511525
auto MinusOneOne =
512526
CSEB.buildBuildVector(VecTy, {MinusOne.getReg(0), MinusOne.getReg(0)});
513527
auto MinusOneTwo =
@@ -547,6 +561,36 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
547561
// ICMP_SLE
548562
CSEB.buildICmp(CmpInst::Predicate::ICMP_SLE, DstTy, MinusOneTwo, MinusOneOne);
549563

564+
{
565+
auto I =
566+
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneOne, TwoThree);
567+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
568+
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
569+
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
570+
EXPECT_EQ(HiCst.getSExtValue(), 0);
571+
EXPECT_EQ(LoCst.getSExtValue(), 0);
572+
}
573+
574+
{
575+
auto I =
576+
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, OneThree, TwoThree);
577+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
578+
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
579+
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
580+
EXPECT_EQ(HiCst.getSExtValue(), 0);
581+
EXPECT_EQ(LoCst.getSExtValue(), -1);
582+
}
583+
584+
{
585+
auto I =
586+
CSEB.buildICmp(CmpInst::Predicate::ICMP_EQ, VecTy, TwoThree, TwoThree);
587+
EXPECT_TRUE(I->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
588+
const APInt HiCst = *getIConstantVRegVal(I->getOperand(1).getReg(), *MRI);
589+
const APInt LoCst = *getIConstantVRegVal(I->getOperand(2).getReg(), *MRI);
590+
EXPECT_EQ(HiCst.getSExtValue(), -1);
591+
EXPECT_EQ(LoCst.getSExtValue(), -1);
592+
}
593+
550594
auto CheckStr = R"(
551595
; CHECK: [[One:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
552596
; CHECK: [[Two:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
@@ -558,6 +602,7 @@ TEST_F(AArch64GISelMITest, TestConstantFoldICMP) {
558602
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[One]]:_(s32)
559603
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Two]]:_(s32)
560604
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[Two]]:_(s32), [[Three]]:_(s32)
605+
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[One]]:_(s32), [[Three]]:_(s32)
561606
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusOne]]:_(s32)
562607
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusOne]]:_(s32), [[MinusTwo]]:_(s32)
563608
; CHECK: {{%[0-9]+}}:_(<2 x s32>) = G_BUILD_VECTOR [[MinusTwo]]:_(s32), [[MinusThree]]:_(s32)

0 commit comments

Comments
 (0)