Skip to content

Commit fd93a5e

Browse files
committed
[VPlan] Support match unary and binary recipes in pattern matcher (NFC).
Generalize pattern matchers to take recipe types to match as template arguments and use it to provide matchers for unary and binary recipes with specific opcodes and a list of recipe types (VPWidenRecipe, VPReplicateRecipe, VPWidenCastRecipe, VPInstruction) The new matchers are used to simplify and generalize the code in simplifyRecipes.
1 parent bfd1d95 commit fd93a5e

File tree

3 files changed

+158
-57
lines changed

3 files changed

+158
-57
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,8 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags {
21362136
assert(isPredicated() && "Trying to get the mask of a unpredicated recipe");
21372137
return getOperand(getNumOperands() - 1);
21382138
}
2139+
2140+
unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); }
21392141
};
21402142

21412143
/// A recipe for generating conditional branches on the bits of a mask.

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 144 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,51 +50,141 @@ template <typename Class> struct bind_ty {
5050
}
5151
};
5252

53+
/// Match a specified integer value or vector of all elements of that
54+
/// value.
55+
struct specific_intval {
56+
APInt Val;
57+
58+
specific_intval(APInt V) : Val(std::move(V)) {}
59+
60+
bool match(VPValue *VPV) {
61+
if (!VPV->isLiveIn())
62+
return false;
63+
Value *V = VPV->getLiveInIRValue();
64+
const auto *CI = dyn_cast<ConstantInt>(V);
65+
if (!CI && V->getType()->isVectorTy())
66+
if (const auto *C = dyn_cast<Constant>(V))
67+
CI = dyn_cast_or_null<ConstantInt>(
68+
C->getSplatValue(/*UndefsAllowed=*/false));
69+
70+
return CI && APInt::isSameValue(CI->getValue(), Val);
71+
}
72+
};
73+
74+
inline specific_intval m_SpecificInt(uint64_t V) {
75+
return specific_intval(APInt(64, V));
76+
}
77+
78+
/// Matching combinators
79+
template <typename LTy, typename RTy> struct match_combine_or {
80+
LTy L;
81+
RTy R;
82+
83+
match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}
84+
85+
template <typename ITy> bool match(ITy *V) {
86+
if (L.match(V))
87+
return true;
88+
if (R.match(V))
89+
return true;
90+
return false;
91+
}
92+
};
93+
94+
template <typename LTy, typename RTy>
95+
inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
96+
return match_combine_or<LTy, RTy>(L, R);
97+
}
98+
5399
/// Match a VPValue, capturing it if we match.
54100
inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
55101

56-
template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
102+
namespace detail {
103+
104+
/// A helper to match an opcode against multiple recipe types.
105+
template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
106+
107+
template <unsigned Opcode, typename RecipeTy>
108+
struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
109+
static bool match(const VPRecipeBase *R) {
110+
auto *DefR = dyn_cast<RecipeTy>(R);
111+
return DefR && DefR->getOpcode() == Opcode;
112+
}
113+
};
114+
115+
template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
116+
struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
117+
static bool match(const VPRecipeBase *R) {
118+
return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
119+
MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
120+
}
121+
};
122+
} // namespace detail
123+
124+
template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
125+
struct UnaryRecipe_match {
57126
Op0_t Op0;
58127

59-
UnaryVPInstruction_match(Op0_t Op0) : Op0(Op0) {}
128+
UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {}
60129

61130
bool match(const VPValue *V) {
62131
auto *DefR = V->getDefiningRecipe();
63132
return DefR && match(DefR);
64133
}
65134

66135
bool match(const VPRecipeBase *R) {
67-
auto *DefR = dyn_cast<VPInstruction>(R);
68-
if (!DefR || DefR->getOpcode() != Opcode)
136+
if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
69137
return false;
70-
assert(DefR->getNumOperands() == 1 &&
138+
assert(R->getNumOperands() == 1 &&
71139
"recipe with matched opcode does not have 1 operands");
72-
return Op0.match(DefR->getOperand(0));
140+
return Op0.match(R->getOperand(0));
73141
}
74142
};
75143

76-
template <typename Op0_t, typename Op1_t, unsigned Opcode>
77-
struct BinaryVPInstruction_match {
144+
template <typename Op0_t, unsigned Opcode>
145+
using UnaryVPInstruction_match =
146+
UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
147+
148+
template <typename Op0_t, unsigned Opcode>
149+
using AllUnaryRecipe_match =
150+
UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
151+
VPWidenCastRecipe, VPInstruction>;
152+
153+
template <typename Op0_t, typename Op1_t, unsigned Opcode,
154+
typename... RecipeTys>
155+
struct BinaryRecipe_match {
78156
Op0_t Op0;
79157
Op1_t Op1;
80158

81-
BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
159+
BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
82160

83161
bool match(const VPValue *V) {
84162
auto *DefR = V->getDefiningRecipe();
85163
return DefR && match(DefR);
86164
}
87165

166+
bool match(const VPSingleDefRecipe *R) {
167+
return match(static_cast<const VPRecipeBase *>(R));
168+
}
169+
88170
bool match(const VPRecipeBase *R) {
89-
auto *DefR = dyn_cast<VPInstruction>(R);
90-
if (!DefR || DefR->getOpcode() != Opcode)
171+
if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
91172
return false;
92-
assert(DefR->getNumOperands() == 2 &&
173+
assert(R->getNumOperands() == 2 &&
93174
"recipe with matched opcode does not have 2 operands");
94-
return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1));
175+
return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
95176
}
96177
};
97178

179+
template <typename Op0_t, typename Op1_t, unsigned Opcode>
180+
using BinaryVPInstruction_match =
181+
BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPInstruction>;
182+
183+
template <typename Op0_t, typename Op1_t, unsigned Opcode>
184+
using AllBinaryRecipe_match =
185+
BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
186+
VPWidenCastRecipe, VPInstruction>;
187+
98188
template <unsigned Opcode, typename Op0_t>
99189
inline UnaryVPInstruction_match<Op0_t, Opcode>
100190
m_VPInstruction(const Op0_t &Op0) {
@@ -130,6 +220,47 @@ inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
130220
m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
131221
return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
132222
}
223+
224+
template <unsigned Opcode, typename Op0_t>
225+
inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
226+
return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
227+
}
228+
229+
template <typename Op0_t>
230+
inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
231+
m_Trunc(const Op0_t &Op0) {
232+
return m_Unary<Instruction::Trunc, Op0_t>(Op0);
233+
}
234+
235+
template <typename Op0_t>
236+
inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
237+
return m_Unary<Instruction::ZExt, Op0_t>(Op0);
238+
}
239+
240+
template <typename Op0_t>
241+
inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
242+
return m_Unary<Instruction::SExt, Op0_t>(Op0);
243+
}
244+
245+
template <typename Op0_t>
246+
inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
247+
AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
248+
m_ZExtOrSExt(const Op0_t &Op0) {
249+
return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
250+
}
251+
252+
template <unsigned Opcode, typename Op0_t, typename Op1_t>
253+
inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode> m_Binary(const Op0_t &Op0,
254+
const Op1_t &Op1) {
255+
return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
256+
}
257+
258+
template <typename Op0_t, typename Op1_t>
259+
inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
260+
m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
261+
return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
262+
}
263+
133264
} // namespace VPlanPatternMatch
134265
} // namespace llvm
135266

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -814,27 +814,6 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
814814
}
815815
}
816816

817-
/// Returns true is \p V is constant one.
818-
static bool isConstantOne(VPValue *V) {
819-
if (!V->isLiveIn())
820-
return false;
821-
auto *C = dyn_cast<ConstantInt>(V->getLiveInIRValue());
822-
return C && C->isOne();
823-
}
824-
825-
/// Returns the llvm::Instruction opcode for \p R.
826-
static unsigned getOpcodeForRecipe(VPRecipeBase &R) {
827-
if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R))
828-
return WidenR->getUnderlyingInstr()->getOpcode();
829-
if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R))
830-
return WidenC->getOpcode();
831-
if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R))
832-
return RepR->getUnderlyingInstr()->getOpcode();
833-
if (auto *VPI = dyn_cast<VPInstruction>(&R))
834-
return VPI->getOpcode();
835-
return 0;
836-
}
837-
838817
/// Try to simplify recipe \p R.
839818
static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
840819
// Try to remove redundant blend recipes.
@@ -848,24 +827,9 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
848827
return;
849828
}
850829

851-
switch (getOpcodeForRecipe(R)) {
852-
case Instruction::Mul: {
853-
VPValue *A = R.getOperand(0);
854-
VPValue *B = R.getOperand(1);
855-
if (isConstantOne(A))
856-
return R.getVPSingleValue()->replaceAllUsesWith(B);
857-
if (isConstantOne(B))
858-
return R.getVPSingleValue()->replaceAllUsesWith(A);
859-
break;
860-
}
861-
case Instruction::Trunc: {
862-
VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe();
863-
if (!Ext)
864-
break;
865-
unsigned ExtOpcode = getOpcodeForRecipe(*Ext);
866-
if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt)
867-
break;
868-
VPValue *A = Ext->getOperand(0);
830+
using namespace llvm::VPlanPatternMatch;
831+
VPValue *A;
832+
if (match(&R, m_Trunc(m_ZExtOrSExt(m_VPValue(A))))) {
869833
VPValue *Trunc = R.getVPSingleValue();
870834
Type *TruncTy = TypeInfo.inferScalarType(Trunc);
871835
Type *ATy = TypeInfo.inferScalarType(A);
@@ -874,8 +838,12 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
874838
} else {
875839
// Don't replace a scalarizing recipe with a widened cast.
876840
if (isa<VPReplicateRecipe>(&R))
877-
break;
841+
return;
878842
if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) {
843+
844+
unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue()))
845+
? Instruction::SExt
846+
: Instruction::ZExt;
879847
auto *VPC =
880848
new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy);
881849
VPC->insertBefore(&R);
@@ -901,11 +869,11 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
901869
assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV));
902870
}
903871
#endif
904-
break;
905-
}
906-
default:
907-
break;
908872
}
873+
874+
if (match(&R, m_CombineOr(m_Mul(m_VPValue(A), m_SpecificInt(1)),
875+
m_Mul(m_SpecificInt(1), m_VPValue(A)))))
876+
return R.getVPSingleValue()->replaceAllUsesWith(A);
909877
}
910878

911879
/// Try to simplify the recipes in \p Plan.

0 commit comments

Comments
 (0)