[VPlan] Add commutative binary OR matcher, use in transform. #92539


Merged
merged 7 commits on May 20, 2024
43 changes: 30 additions & 13 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -157,7 +157,7 @@ using AllUnaryRecipe_match =
UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
VPWidenCastRecipe, VPInstruction>;

template <typename Op0_t, typename Op1_t, unsigned Opcode,
template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative,
typename... RecipeTys>
struct BinaryRecipe_match {
Op0_t Op0;
@@ -179,18 +179,23 @@ struct BinaryRecipe_match {
return false;
assert(R->getNumOperands() == 2 &&
"recipe with matched opcode does not have 2 operands");
return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)))
return true;
return Commutative && Op0.match(R->getOperand(1)) &&
Op1.match(R->getOperand(0));
}
};

template <typename Op0_t, typename Op1_t, unsigned Opcode>
using BinaryVPInstruction_match =
BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPInstruction>;
BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false,
VPInstruction>;

template <typename Op0_t, typename Op1_t, unsigned Opcode>
template <typename Op0_t, typename Op1_t, unsigned Opcode,
bool Commutative = false>
using AllBinaryRecipe_match =
BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
VPWidenCastRecipe, VPInstruction>;
BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;

template <unsigned Opcode, typename Op0_t>
inline UnaryVPInstruction_match<Op0_t, Opcode>
@@ -256,10 +261,11 @@ m_ZExtOrSExt(const Op0_t &Op0) {
return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
}

template <unsigned Opcode, typename Op0_t, typename Op1_t>
inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode> m_Binary(const Op0_t &Op0,
const Op1_t &Op1) {
return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
template <unsigned Opcode, typename Op0_t, typename Op1_t,
bool Commutative = false>
inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>
m_Binary(const Op0_t &Op0, const Op1_t &Op1) {
return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
@@ -268,10 +274,21 @@ m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or>
/// Match a binary OR operation. Note that while conceptually the operands can
/// be matched commutatively, \p Commutative defaults to false in line with the
/// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative
/// version of the matcher.
template <typename Op0_t, typename Op1_t, bool Commutative = false>
Collaborator: Should the default for Or be Commutative = true?

Contributor Author: The default (= false) is in line with the IR-based pattern matchers, I think. It only has the argument to allow for a slightly simpler m_c_BinaryOr implementation.

Collaborator: OK, worth a comment, as conceptually matching ORs should be commutative by default, so it is unclear when one would use the non-commutative version, which could still be provided, albeit as opt-in.

Contributor Author: Added in the latest version, thanks!

inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative>
m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or,
/*Commutative*/ true>
m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
Collaborator: "m_c_" stands for commutative?

Contributor Author: Yes, this is like in the IR-based pattern matcher.

return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
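To illustrate the behaviour discussed in the review threads above, here is a sketch for exposition only, not part of the patch: the recipe R and the VPValue pointers are hypothetical, and the snippet assumes the VPlanPatternMatch namespace is in scope, as in VPlanTransforms.cpp. m_BinaryOr matches the OR operands only in the written order, whereas m_c_BinaryOr also retries with the operands swapped, following the "c for commutative" naming of the IR-level llvm::PatternMatch::m_c_Or.

// Hypothetical sketch. R is a VPRecipeBase computing:
//   or (logical-and %x, (not %y)), (logical-and %x, %y)
// i.e. the OR operands appear in the opposite order to the written pattern.
VPValue *A, *B, *A1, *B1;
// Strict matcher: fails for this operand order (assuming %y is not itself a
// negation), because the second OR operand contains no `not`.
bool Strict =
    match(&R, m_BinaryOr(m_LogicalAnd(m_VPValue(A), m_VPValue(B)),
                         m_LogicalAnd(m_VPValue(A1), m_Not(m_VPValue(B1)))));
// Commutative matcher: retries with the OR operands swapped and succeeds,
// binding A and A1 to %x, and B and B1 to %y.
bool Commutes =
    match(&R, m_c_BinaryOr(m_LogicalAnd(m_VPValue(A), m_VPValue(B)),
                           m_LogicalAnd(m_VPValue(A1), m_Not(m_VPValue(B1)))));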
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -941,8 +941,8 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
// recipes to be visited during simplification.
VPValue *X, *Y, *X1, *Y1;
if (match(&R,
m_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) &&
m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) &&
X == X1 && Y == Y1) {
R.getVPSingleValue()->replaceAllUsesWith(X);
return;
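For context (not part of the diff): the rewrite above relies on the Boolean identity

    (X && Y) || (X && !Y) == X && (Y || !Y) == X

so once m_c_BinaryOr has bound X == X1 and Y == Y1, the whole OR can be replaced by X regardless of which side of the OR carried the negated mask.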
9 changes: 3 additions & 6 deletions llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
@@ -402,10 +402,9 @@ define void @test_widen_if_then_else(ptr noalias %a, ptr readnone %b) #4 {
; TFCOMMON-NEXT: [[TMP11:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> zeroinitializer, <vscale x 2 x i1> [[TMP10]])
; TFCOMMON-NEXT: [[TMP12:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x i1> [[TMP8]], <vscale x 2 x i1> zeroinitializer
; TFCOMMON-NEXT: [[TMP13:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD]], <vscale x 2 x i1> [[TMP12]])
; TFCOMMON-NEXT: [[TMP14:%.*]] = or <vscale x 2 x i1> [[TMP10]], [[TMP12]]
Collaborator: This and below match X&&!Y || X&&Y --> X, exercising commutativity.

; TFCOMMON-NEXT: [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP10]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[TMP13]]
; TFCOMMON-NEXT: [[TMP15:%.*]] = getelementptr inbounds i64, ptr [[B:%.*]], i64 [[INDEX]]
; TFCOMMON-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP15]], i32 8, <vscale x 2 x i1> [[TMP14]])
; TFCOMMON-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP15]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
; TFCOMMON-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP6]]
; TFCOMMON-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX_NEXT]], i64 1025)
; TFCOMMON-NEXT: [[TMP16:%.*]] = xor <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer)
@@ -453,16 +452,14 @@ define void @test_widen_if_then_else(ptr noalias %a, ptr readnone %b) #4 {
; TFA_INTERLEAVE-NEXT: [[TMP22:%.*]] = select <vscale x 2 x i1> [[ACTIVE_LANE_MASK2]], <vscale x 2 x i1> [[TMP14]], <vscale x 2 x i1> zeroinitializer
; TFA_INTERLEAVE-NEXT: [[TMP23:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD]], <vscale x 2 x i1> [[TMP21]])
; TFA_INTERLEAVE-NEXT: [[TMP24:%.*]] = call <vscale x 2 x i64> @foo_vector(<vscale x 2 x i64> [[WIDE_MASKED_LOAD3]], <vscale x 2 x i1> [[TMP22]])
; TFA_INTERLEAVE-NEXT: [[TMP25:%.*]] = or <vscale x 2 x i1> [[TMP17]], [[TMP21]]
; TFA_INTERLEAVE-NEXT: [[TMP26:%.*]] = or <vscale x 2 x i1> [[TMP18]], [[TMP22]]
; TFA_INTERLEAVE-NEXT: [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP17]], <vscale x 2 x i64> [[TMP19]], <vscale x 2 x i64> [[TMP23]]
; TFA_INTERLEAVE-NEXT: [[PREDPHI4:%.*]] = select <vscale x 2 x i1> [[TMP18]], <vscale x 2 x i64> [[TMP20]], <vscale x 2 x i64> [[TMP24]]
; TFA_INTERLEAVE-NEXT: [[TMP27:%.*]] = getelementptr inbounds i64, ptr [[B:%.*]], i64 [[INDEX]]
; TFA_INTERLEAVE-NEXT: [[TMP28:%.*]] = call i64 @llvm.vscale.i64()
; TFA_INTERLEAVE-NEXT: [[TMP29:%.*]] = mul i64 [[TMP28]], 2
; TFA_INTERLEAVE-NEXT: [[TMP30:%.*]] = getelementptr inbounds i64, ptr [[TMP27]], i64 [[TMP29]]
; TFA_INTERLEAVE-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP27]], i32 8, <vscale x 2 x i1> [[TMP25]])
; TFA_INTERLEAVE-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI4]], ptr [[TMP30]], i32 8, <vscale x 2 x i1> [[TMP26]])
; TFA_INTERLEAVE-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr [[TMP27]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
; TFA_INTERLEAVE-NEXT: call void @llvm.masked.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI4]], ptr [[TMP30]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK2]])
; TFA_INTERLEAVE-NEXT: [[INDEX_NEXT:%.*]] = add i64 [[INDEX]], [[TMP6]]
; TFA_INTERLEAVE-NEXT: [[TMP31:%.*]] = call i64 @llvm.vscale.i64()
; TFA_INTERLEAVE-NEXT: [[TMP32:%.*]] = mul i64 [[TMP31]], 2
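A concrete instance from the TFCOMMON block above, reading the visible lines together with the inline review note: [[TMP12]] is [[ACTIVE_LANE_MASK]] && [[TMP8]], and [[TMP10]] (defined just above the excerpt) is the complementary [[ACTIVE_LANE_MASK]] && ![[TMP8]], so

    ([[TMP10]] || [[TMP12]]) == ([[ACTIVE_LANE_MASK]] && ![[TMP8]]) || ([[ACTIVE_LANE_MASK]] && [[TMP8]]) == [[ACTIVE_LANE_MASK]]

which is why the or ([[TMP14]] here, [[TMP25]]/[[TMP26]] in the interleaved version) disappears and the masked stores now use the active lane mask directly.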
@@ -480,11 +480,10 @@ define void @cond_uniform_load(ptr noalias %dst, ptr noalias readonly %src, ptr
; CHECK-NEXT: [[TMP15:%.*]] = select <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i1> [[TMP14]], <vscale x 4 x i1> zeroinitializer
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[BROADCAST_SPLAT]], i32 4, <vscale x 4 x i1> [[TMP15]], <vscale x 4 x i32> poison)
; CHECK-NEXT: [[TMP16:%.*]] = select <vscale x 4 x i1> [[ACTIVE_LANE_MASK]], <vscale x 4 x i1> [[TMP13]], <vscale x 4 x i1> zeroinitializer
; CHECK-NEXT: [[TMP18:%.*]] = or <vscale x 4 x i1> [[TMP15]], [[TMP16]]
; CHECK-NEXT: [[PREDPHI:%.*]] = select <vscale x 4 x i1> [[TMP16]], <vscale x 4 x i32> zeroinitializer, <vscale x 4 x i32> [[WIDE_MASKED_GATHER]]
; CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[DST:%.*]], i64 [[TMP10]]
; CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 0
; CHECK-NEXT: call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[PREDPHI]], ptr [[TMP19]], i32 4, <vscale x 4 x i1> [[TMP18]])
; CHECK-NEXT: call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> [[PREDPHI]], ptr [[TMP19]], i32 4, <vscale x 4 x i1> [[ACTIVE_LANE_MASK]])
; CHECK-NEXT: [[INDEX_NEXT2]] = add i64 [[INDEX1]], [[TMP21]]
; CHECK-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 [[INDEX1]], i64 [[TMP9]])
; CHECK-NEXT: [[TMP22:%.*]] = xor <vscale x 4 x i1> [[ACTIVE_LANE_MASK_NEXT]], shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer)