Skip to content

VPlan: increase simplification power of simplifyRecipe #93998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,33 +161,34 @@ class VPBuilder {
return tryInsertInstruction(
new VPInstruction(Opcode, Operands, WrapFlags, DL, Name));
}
VPValue *createNot(VPValue *Operand, DebugLoc DL = {},
const Twine &Name = "") {
VPInstruction *createNot(VPValue *Operand, DebugLoc DL = {},
const Twine &Name = "") {
return createInstruction(VPInstruction::Not, {Operand}, DL, Name);
}

VPValue *createAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {
VPInstruction *createAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {
return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}, DL, Name);
}

VPValue *createOr(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {
VPInstruction *createOr(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {

return tryInsertInstruction(new VPInstruction(
Instruction::BinaryOps::Or, {LHS, RHS},
VPRecipeWithIRFlags::DisjointFlagsTy(false), DL, Name));
}

VPValue *createLogicalAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {
VPInstruction *createLogicalAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {},
const Twine &Name = "") {
return tryInsertInstruction(
new VPInstruction(VPInstruction::LogicalAnd, {LHS, RHS}, DL, Name));
}

VPValue *createSelect(VPValue *Cond, VPValue *TrueVal, VPValue *FalseVal,
DebugLoc DL = {}, const Twine &Name = "",
std::optional<FastMathFlags> FMFs = std::nullopt) {
VPInstruction *
createSelect(VPValue *Cond, VPValue *TrueVal, VPValue *FalseVal,
DebugLoc DL = {}, const Twine &Name = "",
std::optional<FastMathFlags> FMFs = std::nullopt) {
auto *Select =
FMFs ? new VPInstruction(Instruction::Select, {Cond, TrueVal, FalseVal},
*FMFs, DL, Name)
Expand All @@ -199,8 +200,8 @@ class VPBuilder {
/// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
/// and \p B.
/// TODO: add createFCmp when needed.
VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
DebugLoc DL = {}, const Twine &Name = "");
VPInstruction *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
DebugLoc DL = {}, const Twine &Name = "");

//===--------------------------------------------------------------------===//
// RAII helpers.
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6931,8 +6931,9 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
}
}

VPValue *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
DebugLoc DL, const Twine &Name) {
VPInstruction *VPBuilder::createICmp(CmpInst::Predicate Pred, VPValue *A,
VPValue *B, DebugLoc DL,
const Twine &Name) {
assert(Pred >= CmpInst::FIRST_ICMP_PREDICATE &&
Pred <= CmpInst::LAST_ICMP_PREDICATE && "invalid predicate");
return tryInsertInstruction(
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ template <unsigned BitWidth = 0> struct specific_intval {
if (!CI)
return false;

assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
"Trying the match constant with unexpected bitwidth.");
if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
return false;

return APInt::isSameValue(CI->getValue(), Val);
}
};
Expand All @@ -87,6 +88,8 @@ inline specific_intval<0> m_SpecificInt(uint64_t V) {

inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }

inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }

/// Matching combinators
template <typename LTy, typename RTy> struct match_combine_or {
LTy L;
Expand Down
101 changes: 81 additions & 20 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include <deque>

using namespace llvm;

Expand Down Expand Up @@ -852,8 +853,10 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
}
}

/// Try to simplify recipe \p R.
static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
/// Try to simplify recipe \p R. Returns any new recipes introduced during
/// simplification, as candidates for further simplification.
static SmallVector<VPRecipeBase *>
simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo, VPlan &Plan) {
using namespace llvm::VPlanPatternMatch;

if (auto *Blend = dyn_cast<VPBlendRecipe>(&R)) {
Expand All @@ -868,11 +871,11 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
if (UniqueValues.size() == 1) {
Blend->replaceAllUsesWith(*UniqueValues.begin());
Blend->eraseFromParent();
return;
return {};
}

if (Blend->isNormalized())
return;
return {};

// Normalize the blend so its first incoming value is used as the initial
// value with the others blended into it.
Expand Down Expand Up @@ -907,7 +910,7 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
Blend->replaceAllUsesWith(NewBlend);
Blend->eraseFromParent();
recursivelyDeleteDeadRecipes(DeadMask);
return;
return {};
}

VPValue *A;
Expand All @@ -920,7 +923,7 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
} else {
// Don't replace a scalarizing recipe with a widened cast.
if (isa<VPReplicateRecipe>(&R))
return;
return {};
if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) {

unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue()))
Expand Down Expand Up @@ -955,24 +958,73 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV));
}
#endif
return {};
}

VPValue *X, *X1, *Y, *Z;
LLVMContext &Ctx = TypeInfo.getContext();

// (X || !X) -> true.
if (match(&R, m_c_BinaryOr(m_VPValue(X), m_Not(m_VPValue(X1)))) && X == X1) {
VPValue *VPV = Plan.getOrAddLiveIn(ConstantInt::getTrue(Ctx));
R.getVPSingleValue()->replaceAllUsesWith(VPV);
return {};
}

// Simplify (X && Y) || (X && !Y) -> X.
// TODO: Split up into simpler, modular combines: (X && Y) || (X && Z) into X
// && (Y || Z) and (X || !X) into true. This requires queuing newly created
// recipes to be visited during simplification.
VPValue *X, *Y, *X1, *Y1;
if (match(&R,
m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) &&
X == X1 && Y == Y1) {
// (X || true) -> true.
if (match(&R, m_c_BinaryOr(m_VPValue(X), m_True()))) {
VPValue *VPV = Plan.getOrAddLiveIn(ConstantInt::getTrue(Ctx));
R.getVPSingleValue()->replaceAllUsesWith(VPV);
return {};
}

// (X || false) -> X.
if (match(&R, m_c_BinaryOr(m_VPValue(X), m_False()))) {
R.getVPSingleValue()->replaceAllUsesWith(X);
return {};
}

// (X && !X) -> false.
if (match(&R, m_LogicalAnd(m_VPValue(X), m_Not(m_VPValue(X1)))) && X == X1) {
VPValue *VPV = Plan.getOrAddLiveIn(ConstantInt::getFalse(Ctx));
R.getVPSingleValue()->replaceAllUsesWith(VPV);
return {};
}

// (X && true) -> X.
if (match(&R, m_LogicalAnd(m_VPValue(X), m_True()))) {
R.getVPSingleValue()->replaceAllUsesWith(X);
return {};
}

// (X && false) -> false.
if (match(&R, m_LogicalAnd(m_VPValue(X), m_False()))) {
VPValue *VPV = Plan.getOrAddLiveIn(ConstantInt::getFalse(Ctx));
R.getVPSingleValue()->replaceAllUsesWith(VPV);
return {};
}

// (X * 1) -> X.
if (match(&R, m_c_Mul(m_VPValue(X), m_SpecificInt(1)))) {
R.getVPSingleValue()->replaceAllUsesWith(X);
return {};
}

// (X && Y) || (X && Z) -> X && (Y || Z).
if (match(&R, m_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)),
m_LogicalAnd(m_VPValue(X1), m_VPValue(Z)))) &&
X == X1) {
VPBuilder Builder(&R);
VPInstruction *YorZ = Builder.createOr(Y, Z, R.getDebugLoc());
VPInstruction *VPI = Builder.createLogicalAnd(X, YorZ, R.getDebugLoc());
R.getVPSingleValue()->replaceAllUsesWith(VPI);
R.eraseFromParent();
return;
// Order of simplification matters: simplify sub-recipes before root
// recipes.
return {YorZ, VPI};
}

if (match(&R, m_c_Mul(m_VPValue(A), m_SpecificInt(1))))
return R.getVPSingleValue()->replaceAllUsesWith(A);
return {};
}

/// Try to simplify the recipes in \p Plan.
Expand All @@ -981,8 +1033,17 @@ static void simplifyRecipes(VPlan &Plan, LLVMContext &Ctx) {
Plan.getEntry());
VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType(), Ctx);
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) {
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
simplifyRecipe(R, TypeInfo);
// Order of simplification matters: add new candidates for simplification to
// the back of the Worklist, while the Worklist processes recipes from the
// front.
std::deque<VPRecipeBase *> Worklist;
for (auto &R : make_early_inc_range(*VPBB)) {
Worklist.emplace_front(&R);
while (!Worklist.empty()) {
VPRecipeBase *R = Worklist.front();
Worklist.pop_front();
append_range(Worklist, simplifyRecipe(*R, TypeInfo, Plan));
}
}
}
}
Expand Down
Loading
Loading