Skip to content

Commit bf7aef8

Browse files
committed
[InstCombine] Fix behavior for (fmul (sitfp x), 0)
Bug was introduced in #82555 We where missing check that the constant was non-zero for signed + mul transform.
1 parent 6451085 commit bf7aef8

File tree

4 files changed

+151
-12
lines changed

4 files changed

+151
-12
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,15 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
330330
/// This helper class is used to match constant scalars, vector splats,
331331
/// and fixed width vectors that satisfy a specified predicate.
332332
/// For fixed width vector constants, undefined elements are ignored.
333-
template <typename Predicate, typename ConstantVal>
333+
template <typename Predicate, typename ConstantVal, bool AllowUndef>
334334
struct cstval_pred_ty : public Predicate {
335335
template <typename ITy> bool match(ITy *V) {
336336
if (const auto *CV = dyn_cast<ConstantVal>(V))
337337
return this->isValue(CV->getValue());
338338
if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
339339
if (const auto *C = dyn_cast<Constant>(V)) {
340-
if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
340+
if (const auto *CV =
341+
dyn_cast_or_null<ConstantVal>(C->getSplatValue(AllowUndef)))
341342
return this->isValue(CV->getValue());
342343

343344
// Number of elements of a scalable vector unknown at compile time
@@ -353,8 +354,11 @@ struct cstval_pred_ty : public Predicate {
353354
Constant *Elt = C->getAggregateElement(i);
354355
if (!Elt)
355356
return false;
356-
if (isa<UndefValue>(Elt))
357+
if (isa<UndefValue>(Elt)) {
358+
if (!AllowUndef)
359+
return false;
357360
continue;
361+
}
358362
auto *CV = dyn_cast<ConstantVal>(Elt);
359363
if (!CV || !this->isValue(CV->getValue()))
360364
return false;
@@ -368,19 +372,20 @@ struct cstval_pred_ty : public Predicate {
368372
};
369373

370374
/// specialization of cstval_pred_ty for ConstantInt
371-
template <typename Predicate>
372-
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
375+
template <typename Predicate, bool AllowUndef = true>
376+
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndef>;
373377

374378
/// specialization of cstval_pred_ty for ConstantFP
375-
template <typename Predicate>
376-
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
379+
template <typename Predicate, bool AllowUndef = true>
380+
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndef>;
377381

378382
/// This helper class is used to match scalar and vector constants that
379383
/// satisfy a specified predicate, and bind them to an APInt.
380384
template <typename Predicate> struct api_pred_ty : public Predicate {
381385
const APInt *&Res;
382-
383-
api_pred_ty(const APInt *&R) : Res(R) {}
386+
bool AllowUndef;
387+
api_pred_ty(const APInt *&R, bool AllowUndef = true)
388+
: Res(R), AllowUndef(AllowUndef) {}
384389

385390
template <typename ITy> bool match(ITy *V) {
386391
if (const auto *CI = dyn_cast<ConstantInt>(V))
@@ -390,7 +395,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
390395
}
391396
if (V->getType()->isVectorTy())
392397
if (const auto *C = dyn_cast<Constant>(V))
393-
if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
398+
if (auto *CI =
399+
dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef)))
394400
if (this->isValue(CI->getValue())) {
395401
Res = &CI->getValue();
396402
return true;
@@ -544,6 +550,30 @@ struct is_zero {
544550
/// For vectors, this includes constants with undefined elements.
545551
inline is_zero m_Zero() { return is_zero(); }
546552

553+
struct is_non_zero {
554+
bool isValue(const APInt &C) { return !C.isZero(); }
555+
};
556+
557+
/// Match any constant s.t all elements are non-zero. For a scalar, this is the
558+
/// same as !m_Zero. For vectors is ensures that !m_Zero holds for all elements.
559+
/// This does not include undefined elements.
560+
inline cst_pred_ty<is_non_zero, false> m_NonZero() {
561+
return cst_pred_ty<is_non_zero, /*AllowUndef=*/false>();
562+
}
563+
inline api_pred_ty<is_non_zero> m_NonZero(const APInt *&V) {
564+
return api_pred_ty<is_non_zero>(V, /*AllowUndef=*/false);
565+
}
566+
567+
/// Match any constant s.t all elements are non-zero. For a scalar, this is the
568+
/// same as !m_Zero. For vectors is ensures that !m_Zero holds for all elements.
569+
/// This includes undefined elements.
570+
inline cst_pred_ty<is_non_zero> m_NonZeroAllowUndef() {
571+
return cst_pred_ty<is_non_zero>();
572+
}
573+
inline api_pred_ty<is_non_zero> m_NonZeroAllowUndef(const APInt *&V) {
574+
return V;
575+
}
576+
547577
struct is_power2 {
548578
bool isValue(const APInt &C) { return C.isPowerOf2(); }
549579
};

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,6 +1491,11 @@ Instruction *InstCombinerImpl::foldFBinOpOfIntCastsFromSign(
14911491
Op1IntC, FPTy, DL) != Op1FpC)
14921492
return nullptr;
14931493

1494+
// Signed + Mul req non-zero
1495+
if (OpsFromSigned && BO.getOpcode() == Instruction::FMul &&
1496+
!match(Op1IntC, m_NonZero()))
1497+
return nullptr;
1498+
14941499
// First try to keep sign of cast the same.
14951500
IntOps[1] = Op1IntC;
14961501
}

llvm/test/Transforms/InstCombine/binop-itofp.ll

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,8 +1010,11 @@ define float @test_ui_add_with_signed_constant(i32 %shr.i) {
10101010
define float @missed_nonzero_check_on_constant_for_si_fmul(i1 %c, i1 %.b, ptr %g_2345) {
10111011
; CHECK-LABEL: @missed_nonzero_check_on_constant_for_si_fmul(
10121012
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C:%.*]], i32 65529, i32 53264
1013+
; CHECK-NEXT: [[CONV_I:%.*]] = trunc i32 [[SEL]] to i16
1014+
; CHECK-NEXT: [[CONV1_I:%.*]] = sitofp i16 [[CONV_I]] to float
1015+
; CHECK-NEXT: [[MUL3_I_I:%.*]] = fmul float [[CONV1_I]], 0.000000e+00
10131016
; CHECK-NEXT: store i32 [[SEL]], ptr [[G_2345:%.*]], align 4
1014-
; CHECK-NEXT: ret float 0.000000e+00
1017+
; CHECK-NEXT: ret float [[MUL3_I_I]]
10151018
;
10161019
%sel = select i1 %c, i32 65529, i32 53264
10171020
%conv.i = trunc i32 %sel to i16
@@ -1024,8 +1027,13 @@ define float @missed_nonzero_check_on_constant_for_si_fmul(i1 %c, i1 %.b, ptr %g
10241027
define <2 x float> @missed_nonzero_check_on_constant_for_si_fmul_vec(i1 %c, i1 %.b, ptr %g_2345) {
10251028
; CHECK-LABEL: @missed_nonzero_check_on_constant_for_si_fmul_vec(
10261029
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C:%.*]], i32 65529, i32 53264
1030+
; CHECK-NEXT: [[CONV_I_S:%.*]] = trunc i32 [[SEL]] to i16
1031+
; CHECK-NEXT: [[CONV_I_V:%.*]] = insertelement <2 x i16> poison, i16 [[CONV_I_S]], i64 0
1032+
; CHECK-NEXT: [[CONV_I:%.*]] = shufflevector <2 x i16> [[CONV_I_V]], <2 x i16> poison, <2 x i32> zeroinitializer
1033+
; CHECK-NEXT: [[CONV1_I:%.*]] = sitofp <2 x i16> [[CONV_I]] to <2 x float>
1034+
; CHECK-NEXT: [[MUL3_I_I:%.*]] = fmul <2 x float> [[CONV1_I]], zeroinitializer
10271035
; CHECK-NEXT: store i32 [[SEL]], ptr [[G_2345:%.*]], align 4
1028-
; CHECK-NEXT: ret <2 x float> zeroinitializer
1036+
; CHECK-NEXT: ret <2 x float> [[MUL3_I_I]]
10291037
;
10301038
%sel = select i1 %c, i32 65529, i32 53264
10311039
%conv.i.s = trunc i32 %sel to i16

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,102 @@ TEST_F(PatternMatchTest, Power2) {
614614
EXPECT_TRUE(m_NegatedPower2OrZero().match(CZero));
615615
}
616616

617+
TEST_F(PatternMatchTest, NonZero) {
618+
Type *I8Ty = IRB.getInt8Ty();
619+
620+
EXPECT_FALSE(m_NonZero().match(ConstantInt::get(I8Ty, 0)));
621+
EXPECT_TRUE(m_NonZero().match(ConstantInt::get(I8Ty, 1)));
622+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantInt::get(I8Ty, 0)));
623+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantInt::get(I8Ty, 1)));
624+
625+
EXPECT_FALSE(m_NonZero().match(UndefValue::get(I8Ty)));
626+
EXPECT_FALSE(m_NonZero().match(PoisonValue::get(I8Ty)));
627+
EXPECT_FALSE(m_NonZeroAllowUndef().match(UndefValue::get(I8Ty)));
628+
EXPECT_FALSE(m_NonZeroAllowUndef().match(PoisonValue::get(I8Ty)));
629+
630+
{
631+
SmallVector<Constant *, 2> VecElemIdxs;
632+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
633+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 1));
634+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
635+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
636+
}
637+
638+
{
639+
SmallVector<Constant *, 2> VecElemIdxs;
640+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
641+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
642+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
643+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
644+
}
645+
646+
{
647+
SmallVector<Constant *, 2> VecElemIdxs;
648+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 1));
649+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
650+
EXPECT_TRUE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
651+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
652+
}
653+
654+
{
655+
SmallVector<Constant *, 2> VecElemIdxs;
656+
VecElemIdxs.push_back(UndefValue::get(I8Ty));
657+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
658+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
659+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
660+
}
661+
662+
{
663+
SmallVector<Constant *, 3> VecElemIdxs;
664+
VecElemIdxs.push_back(UndefValue::get(I8Ty));
665+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
666+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 3));
667+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
668+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
669+
}
670+
671+
{
672+
SmallVector<Constant *, 2> VecElemIdxs;
673+
VecElemIdxs.push_back(PoisonValue::get(I8Ty));
674+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
675+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
676+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
677+
}
678+
679+
{
680+
SmallVector<Constant *, 3> VecElemIdxs;
681+
VecElemIdxs.push_back(PoisonValue::get(I8Ty));
682+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
683+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 3));
684+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
685+
EXPECT_TRUE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
686+
}
687+
688+
{
689+
SmallVector<Constant *, 2> VecElemIdxs;
690+
VecElemIdxs.push_back(UndefValue::get(I8Ty));
691+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
692+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
693+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
694+
}
695+
696+
{
697+
SmallVector<Constant *, 2> VecElemIdxs;
698+
VecElemIdxs.push_back(PoisonValue::get(I8Ty));
699+
VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
700+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
701+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
702+
}
703+
704+
{
705+
SmallVector<Constant *, 2> VecElemIdxs;
706+
VecElemIdxs.push_back(PoisonValue::get(I8Ty));
707+
VecElemIdxs.push_back(UndefValue::get(I8Ty));
708+
EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
709+
EXPECT_FALSE(m_NonZeroAllowUndef().match(ConstantVector::get(VecElemIdxs)));
710+
}
711+
}
712+
617713
TEST_F(PatternMatchTest, Not) {
618714
Value *C1 = IRB.getInt32(1);
619715
Value *C2 = IRB.getInt32(2);

0 commit comments

Comments
 (0)