[VPlan] Add getSCEVExprForVPValue util, use to get trip count SCEV (NFC) #94464

Merged
merged 10 commits into from
Sep 18, 2024
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -917,8 +917,9 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) {
return B.CreateElementCount(Ty, VF);
}

-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *OrigLoop) {
+static const SCEV *createTripCountSCEV(Type *IdxTy,
+                                       PredicatedScalarEvolution &PSE,
+                                       Loop *OrigLoop) {
Collaborator

Possibly here or as follow-up:

The call to createTripCountSCEV() in LVP::selectEpilogueVectorizationFactor() has VPlans available, so it could retrieve a TripCount VPValue from one of them and call the new API instead of continuing to call createTripCountSCEV().

The other two calls to createTripCountSCEV() are inside calls to VPlan::createInitialVPlan(), where they set its first parameter; should they be inlined into createInitialVPlan() itself instead?
That would retire createTripCountSCEV altogether, leaving only vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE) as the way to obtain the SCEV of a loop's trip count.

Contributor Author

Thanks, updated the PR to also inline createTripCountSCEV.

const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount();
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");

3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlan.h
@@ -83,9 +83,6 @@ Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF);
Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF,
                       int64_t Step);

-const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
-                                Loop *CurLoop = nullptr);
-
/// A helper function that returns the reciprocal of the block probability of
/// predicated blocks. If we return X, we are assuming the predicated block
/// will execute once for every X iterations of the loop header.
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -11,6 +11,7 @@
#include "VPlanCFG.h"
#include "VPlanDominatorTree.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -67,6 +67,7 @@ class VPTypeAnalysis {
// Collect a VPlan's ephemeral recipes (those used only by an assume).
void collectEphemeralRecipesForVPlan(VPlan &Plan,
                                     DenseSet<VPRecipeBase *> &EphRecipes);
+
Collaborator

Suggested change: remove the added blank line.

Contributor Author

Removed, thanks!

} // end namespace llvm

#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H
5 changes: 2 additions & 3 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -689,10 +689,9 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue())))))
return;

-Type *IdxTy =
-    Plan.getCanonicalIV()->getStartValue()->getLiveInIRValue()->getType();
-const SCEV *TripCount = createTripCountSCEV(IdxTy, PSE);
ScalarEvolution &SE = *PSE.getSE();
+const SCEV *TripCount =
+    vputils::getSCEVExprForVPValue(Plan.getTripCount(), SE);
Collaborator

BTW, independently: the BOC below would be better as an unconditional branch than as a BranchOnCond with a true condition.

Contributor Author

Yes, that will also entail removing the region and header phis.

Collaborator

Worth asserting that the SCEV returned could be computed?

Contributor Author

Thanks, added assert.
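
The assert-before-use pattern discussed in this thread can be sketched outside LLVM as follows. This is a toy model only: `ToyTripCount`, `requireTripCount` and `elementsPerIteration` are hypothetical names, an empty `std::optional` stands in for SCEVCouldNotCompute, and VF and UF are treated as plain fixed integers.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Toy stand-in, NOT an LLVM type: an empty optional plays the role of
// SCEVCouldNotCompute, a held value plays the role of a computed trip count.
using ToyTripCount = std::optional<std::uint64_t>;

// Hypothetical helper: unwrap the trip count, asserting first that it was
// actually computed. With NDEBUG defined the assert compiles away, so callers
// must still only pass computable values.
std::uint64_t requireTripCount(const ToyTripCount &TC) {
  assert(TC.has_value() && "trip count must be computable at this point");
  return *TC;
}

// Elements handled per vector-loop iteration, modelled as the plain product
// of a fixed VF and UF (scalable vectors are ignored in this sketch).
std::uint64_t elementsPerIteration(std::uint64_t VF, std::uint64_t UF) {
  return VF * UF;
}
```

The point of asserting at the call site is that a "could not compute" result is otherwise a valid value and would flow silently into the comparison that follows.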

ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF);
const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements);
if (TripCount->isZero() ||
12 changes: 12 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -8,6 +8,7 @@

#include "VPlanUtils.h"
#include "VPlanPatternMatch.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"

using namespace llvm;
@@ -60,3 +61,14 @@ bool vputils::isHeaderMask(const VPValue *V, VPlan &Plan) {
return match(V, m_Binary<Instruction::ICmp>(m_VPValue(A), m_VPValue(B))) &&
IsWideCanonicalIV(A) && B == Plan.getOrCreateBackedgeTakenCount();
}
+
+const SCEV *vputils::getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE) {
+  if (V->isLiveIn())
+    return SE.getSCEV(V->getLiveInIRValue());
+
Collaborator

This seems to work well for Values at VPlan's boundary: live-ins could be complemented with live-outs/VPIRInstructions, which could also simply return SE.getSCEV() of the instruction they wrap. That can be a follow-up, along with a concrete use.

Contributor Author

Yep

+  // TODO: Support constructing SCEVs for more recipes as needed.
+  return TypeSwitch<const VPRecipeBase *, const SCEV *>(V->getDefiningRecipe())
+      .Case<VPExpandSCEVRecipe>(
+          [](const VPExpandSCEVRecipe *R) { return R->getSCEV(); })
+      .Default([&SE](const VPRecipeBase *) { return SE.getCouldNotCompute(); });
+}
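
The three-way dispatch in the new utility (live-in, SCEV-expanding recipe, anything else) can be modelled in a few lines of standalone C++. This is a sketch under loud assumptions, not the LLVM implementation: `ToySCEV`, `LiveIn`, `ExpandSCEV`, `OtherRecipe`, `ToyVPValue` and `getToySCEVExpr` are all hypothetical names, a string stands in for a SCEV, and an empty string stands in for SCEVCouldNotCompute.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Toy stand-ins, NOT LLVM types. A "SCEV" is just a string here, and the
// empty string plays the role of SCEVCouldNotCompute.
using ToySCEV = std::string;
const ToySCEV CouldNotCompute = "";

struct LiveIn { std::string IRValueName; }; // models a live-in wrapping an IR value
struct ExpandSCEV { ToySCEV Expr; };        // models a recipe that already holds a SCEV
struct OtherRecipe {};                      // models any recipe not handled yet

using ToyVPValue = std::variant<LiveIn, ExpandSCEV, OtherRecipe>;

// Mirrors the shape of the dispatch: live-ins ask the analysis for an
// expression, SCEV-holding recipes return the one they carry, and everything
// else yields the "could not compute" sentinel.
ToySCEV getToySCEVExpr(const ToyVPValue &V) {
  if (const auto *LI = std::get_if<LiveIn>(&V))
    return "scev(" + LI->IRValueName + ")";
  if (const auto *E = std::get_if<ExpandSCEV>(&V))
    return E->Expr;
  return CouldNotCompute;
}
```

Returning a sentinel rather than failing keeps the utility usable as a query; supporting another recipe kind then only means adding one more case before the fallback.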
10 changes: 10 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanUtils.h
@@ -11,6 +11,11 @@

#include "VPlan.h"

+namespace llvm {
+class ScalarEvolution;
+class SCEV;
+} // namespace llvm
+
namespace llvm::vputils {
/// Returns true if only the first lane of \p Def is used.
bool onlyFirstLaneUsed(const VPValue *Def);
@@ -45,6 +50,11 @@ inline bool isUniformAfterVectorization(const VPValue *VPV) {

/// Return true if \p V is a header mask in \p Plan.
bool isHeaderMask(const VPValue *V, VPlan &Plan);
+
+/// Return the SCEV expression for \p V. Returns SCEVCouldNotCompute if no
+/// SCEV expression could be constructed.
+const SCEV *getSCEVExprForVPValue(VPValue *V, ScalarEvolution &SE);
+
Collaborator

This complements getOrCreateVPValueForSCEVExpr() above and should be placed next to it.
(Otherwise it would have sufficed to call it getSCEVFor() or getSCEV(), as it currently delegates to SE.getSCEV() or VPExpandSCEVRecipe::getSCEV(). It could even be named getOrCreateSCEVExprForVPValue() for consistency, though getSCEV() already typically creates expressions, in contrast to getExistingSCEV().)

Contributor Author

Moved, thanks!
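
The naming point in this thread rests on the difference between a lookup that creates on a miss and one that only reports what already exists. A minimal standalone sketch of that distinction, assuming a hypothetical `ToyExprCache` class with string stand-ins for values and expressions (none of this is LLVM API):

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy cache, NOT ScalarEvolution: keys are value names and mapped values are
// string "expressions".
class ToyExprCache {
  std::map<std::string, std::string> Cache;

public:
  // Get-or-create: builds and stores an expression on a miss, so it always
  // returns something.
  const std::string &get(const std::string &V) {
    return Cache.try_emplace(V, "scev(" + V + ")").first->second;
  }

  // Get-existing: pure lookup, returns nullptr when nothing was created yet.
  const std::string *getExisting(const std::string &V) const {
    auto It = Cache.find(V);
    return It == Cache.end() ? nullptr : &It->second;
  }
};
```

In this model a name such as getOrCreate would describe `get`, while `getExisting` never mutates the cache and can therefore be `const`.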

} // end namespace llvm::vputils

#endif