Skip to content

[InstCombine] Fix poison propagation in select of bitwise fold #89701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,9 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {

/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, poison elements are ignored.
template <typename Predicate, typename ConstantVal>
/// For fixed width vector constants, poison elements are ignored if AllowPoison
/// is true.
template <typename Predicate, typename ConstantVal, bool AllowPoison>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
Expand All @@ -374,7 +375,7 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
if (isa<PoisonValue>(Elt))
if (AllowPoison && isa<PoisonValue>(Elt))
continue;
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
Expand All @@ -389,12 +390,13 @@ struct cstval_pred_ty : public Predicate {
};

/// specialization of cstval_pred_ty for ConstantInt
template <typename Predicate>
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
template <typename Predicate, bool AllowPoison = true>
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowPoison>;

/// specialization of cstval_pred_ty for ConstantFP
template <typename Predicate>
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP,
/*AllowPoison=*/true>;

/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
Expand Down Expand Up @@ -484,6 +486,10 @@ inline cst_pred_ty<is_all_ones> m_AllOnes() {
return cst_pred_ty<is_all_ones>();
}

inline cst_pred_ty<is_all_ones, false> m_AllOnesForbidPoison() {
return cst_pred_ty<is_all_ones, false>();
}

struct is_maxsignedvalue {
bool isValue(const APInt &C) { return C.isMaxSignedValue(); }
};
Expand Down Expand Up @@ -2596,6 +2602,13 @@ m_Not(const ValTy &V) {
return m_c_Xor(m_AllOnes(), V);
}

template <typename ValTy>
inline BinaryOp_match<cst_pred_ty<is_all_ones, false>, ValTy, Instruction::Xor,
true>
m_NotForbidPoison(const ValTy &V) {
return m_c_Xor(m_AllOnesForbidPoison(), V);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing unit tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer not to have unit tests for anything that has lit test coverage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Except for exhaustive unit tests.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err I was confusing, not the m_NotForbidPoison, but for the AllowPoison template arg in the cstval_pred_ty unittests.

Either way though.


/// Matches an SMin with LHS and RHS in either order.
template <typename LHS, typename RHS>
inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1722,11 +1722,11 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);

if (NotMask == NotInner) {
return match(FalseVal,
m_c_BinOp(OuterOpc, m_Not(matchInner), m_Specific(CmpRHS)));
return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
m_Specific(CmpRHS)));
} else if (NotMask == NotRHS) {
return match(FalseVal,
m_c_BinOp(OuterOpc, matchInner, m_Not(m_Specific(CmpRHS))));
return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
m_NotForbidPoison(m_Specific(CmpRHS))));
} else {
return match(FalseVal,
m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
Expand Down
11 changes: 7 additions & 4 deletions llvm/test/Transforms/InstCombine/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3830,14 +3830,17 @@ entry:
ret i32 %cond
}

; FIXME: This is a miscompile.
define <2 x i32> @src_and_eq_C_xor_OrAndNotC_vec_poison(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2) {
; CHECK-LABEL: @src_and_eq_C_xor_OrAndNotC_vec_poison(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[OR:%.*]] = or <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[NOT:%.*]] = xor <2 x i32> [[TMP2:%.*]], <i32 -1, i32 poison>
; CHECK-NEXT: [[AND:%.*]] = and <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[AND]], [[TMP2:%.*]]
; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i32> [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[OR:%.*]] = or <2 x i32> [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[NOT:%.*]] = xor <2 x i32> [[TMP2]], <i32 -1, i32 poison>
; CHECK-NEXT: [[AND1:%.*]] = and <2 x i32> [[OR]], [[NOT]]
; CHECK-NEXT: ret <2 x i32> [[AND1]]
; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[CMP]], <2 x i32> [[XOR]], <2 x i32> [[AND1]]
; CHECK-NEXT: ret <2 x i32> [[COND]]
;
entry:
%and = and <2 x i32> %1, %0
Expand Down
12 changes: 11 additions & 1 deletion llvm/unittests/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1995,7 +1995,7 @@ TEST_F(PatternMatchTest, VScale) {
EXPECT_TRUE(match(PtrToInt2, m_VScale()));
}

TEST_F(PatternMatchTest, NotForbidUndef) {
TEST_F(PatternMatchTest, NotForbidPoison) {
Type *ScalarTy = IRB.getInt8Ty();
Type *VectorTy = FixedVectorType::get(ScalarTy, 3);
Constant *ScalarUndef = UndefValue::get(ScalarTy);
Expand All @@ -2020,23 +2020,33 @@ TEST_F(PatternMatchTest, NotForbidUndef) {
Value *X;
EXPECT_TRUE(match(Not, m_Not(m_Value(X))));
EXPECT_TRUE(match(X, m_Zero()));
X = nullptr;
EXPECT_TRUE(match(Not, m_NotForbidPoison(m_Value(X))));
EXPECT_TRUE(match(X, m_Zero()));

Value *NotCommute = IRB.CreateXor(VectorOnes, VectorZero);
Value *Y;
EXPECT_TRUE(match(NotCommute, m_Not(m_Value(Y))));
EXPECT_TRUE(match(Y, m_Zero()));
Y = nullptr;
EXPECT_TRUE(match(NotCommute, m_NotForbidPoison(m_Value(Y))));
EXPECT_TRUE(match(Y, m_Zero()));

Value *NotWithUndefs = IRB.CreateXor(VectorZero, VectorMixedUndef);
EXPECT_FALSE(match(NotWithUndefs, m_Not(m_Value())));
EXPECT_FALSE(match(NotWithUndefs, m_NotForbidPoison(m_Value())));

Value *NotWithPoisons = IRB.CreateXor(VectorZero, VectorMixedPoison);
EXPECT_TRUE(match(NotWithPoisons, m_Not(m_Value())));
EXPECT_FALSE(match(NotWithPoisons, m_NotForbidPoison(m_Value())));

Value *NotWithUndefsCommute = IRB.CreateXor(VectorMixedUndef, VectorZero);
EXPECT_FALSE(match(NotWithUndefsCommute, m_Not(m_Value())));
EXPECT_FALSE(match(NotWithUndefsCommute, m_NotForbidPoison(m_Value())));

Value *NotWithPoisonsCommute = IRB.CreateXor(VectorMixedPoison, VectorZero);
EXPECT_TRUE(match(NotWithPoisonsCommute, m_Not(m_Value())));
EXPECT_FALSE(match(NotWithPoisonsCommute, m_NotForbidPoison(m_Value())));
}

template <typename T> struct MutableConstTest : PatternMatchTest { };
Expand Down