@@ -1007,29 +1007,37 @@ void VPlanTransforms::simplifyRecipes(VPlan &Plan, Type &CanonicalIVTy) {
1007
1007
1008
1008
// / Return true if \p Cond is known to be true for given \p BestVF and \p
1009
1009
// / BestUF.
1010
- static bool isConditionKnown (VPValue *Cond, VPlan &Plan, ElementCount BestVF,
1011
- unsigned BestUF, ScalarEvolution &SE) {
1010
+ static bool isConditionTrueViaVFAndUF (VPValue *Cond, VPlan &Plan,
1011
+ ElementCount BestVF, unsigned BestUF,
1012
+ ScalarEvolution &SE) {
1012
1013
using namespace llvm ::VPlanPatternMatch;
1013
1014
if (match (Cond, m_Binary<Instruction::Or>(m_VPValue (), m_VPValue ())))
1014
- return any_of (Cond->getDefiningRecipe ()->operands (),
1015
- [&Plan, BestVF, BestUF, &SE](VPValue *C) {
1016
- return isConditionKnown (C, Plan, BestVF, BestUF, SE);
1017
- });
1015
+ return any_of (Cond->getDefiningRecipe ()->operands (), [&Plan, BestVF, BestUF,
1016
+ &SE](VPValue *C) {
1017
+ return isConditionTrueViaVFAndUF (C, Plan, BestVF, BestUF, SE);
1018
+ });
1018
1019
1019
- VPValue *TripCount = Plan.getTripCount ();
1020
1020
auto *CanIV = Plan.getCanonicalIV ();
1021
- if (!match (Cond, m_Binary<Instruction::ICmp>(m_Specific (CanIV),
1022
- m_VPValue (TripCount))) ||
1021
+ if (!match (Cond, m_Binary<Instruction::ICmp>(
1022
+ m_Specific (CanIV->getBackedgeValue ()),
1023
+ m_Specific (&Plan.getVectorTripCount ()))) ||
1023
1024
cast<VPRecipeWithIRFlags>(Cond->getDefiningRecipe ())->getPredicate () !=
1024
1025
CmpInst::ICMP_EQ)
1025
1026
return false ;
1026
1027
1027
- const SCEV *TripCountSCEV = vputils::getSCEVExprForVPValue (TripCount, SE);
1028
- assert (!isa<SCEVCouldNotCompute>(TripCountSCEV) &&
1028
+ // The compare checks CanIV + VFxUF == vector trip count. The vector trip
1029
+ // count is not conveniently available as SCEV so far, so we compare directly
1030
+ // against the original trip count. This is stricter than necessary, as we
1031
+ // will only return true if the trip count == vector trip count.
1032
+ // TODO: Use SCEV for vector trip count once available, to cover cases where
1033
+ // vector trip count == UF * VF, but original trip count != UF * VF.
1034
+ const SCEV *TripCount =
1035
+ vputils::getSCEVExprForVPValue (Plan.getTripCount (), SE);
1036
+ assert (!isa<SCEVCouldNotCompute>(TripCount) &&
1029
1037
" Trip count SCEV must be computable" );
1030
1038
ElementCount NumElements = BestVF.multiplyCoefficientBy (BestUF);
1031
- const SCEV *C = SE.getElementCount (TripCountSCEV ->getType (), NumElements);
1032
- return SE.isKnownPredicate (CmpInst::ICMP_EQ, TripCountSCEV , C);
1039
+ const SCEV *C = SE.getElementCount (TripCount ->getType (), NumElements);
1040
+ return SE.isKnownPredicate (CmpInst::ICMP_EQ, TripCount , C);
1033
1041
}
1034
1042
1035
1043
void VPlanTransforms::optimizeForVFAndUF (VPlan &Plan, ElementCount BestVF,
@@ -1040,30 +1048,32 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
1040
1048
VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion ();
1041
1049
VPBasicBlock *ExitingVPBB = VectorRegion->getExitingBasicBlock ();
1042
1050
auto *Term = &ExitingVPBB->back ();
1043
- // Try to simplify the branch condition if TC <= VF * UF when preparing to
1044
- // execute the plan for the main vector loop. We only do this if the
1045
- // terminator is:
1046
- // 1. BranchOnCount, or
1047
- // 2. BranchOnCond where the input is Not(ActiveLaneMask).
1048
- using namespace llvm ::VPlanPatternMatch;
1049
1051
VPValue *Cond;
1050
- if (!match (Term, m_BranchOnCount (m_VPValue (), m_VPValue ())) &&
1051
- !match (Term, m_BranchOnCond (
1052
- m_Not (m_ActiveLaneMask (m_VPValue (), m_VPValue ())))) &&
1053
- (!match (Term, m_BranchOnCond (m_VPValue (Cond))) ||
1054
- isConditionKnown (Cond, Plan, BestVF, BestUF, *PSE.getSE ())))
1055
- return ;
1056
-
1057
1052
ScalarEvolution &SE = *PSE.getSE ();
1058
- const SCEV *TripCount =
1059
- vputils::getSCEVExprForVPValue (Plan.getTripCount (), SE);
1060
- assert (!isa<SCEVCouldNotCompute>(TripCount) &&
1061
- " Trip count SCEV must be computable" );
1062
- ElementCount NumElements = BestVF.multiplyCoefficientBy (BestUF);
1063
- const SCEV *C = SE.getElementCount (TripCount->getType (), NumElements);
1064
- if (TripCount->isZero () ||
1065
- !SE.isKnownPredicate (CmpInst::ICMP_ULE, TripCount, C))
1053
+ using namespace llvm ::VPlanPatternMatch;
1054
+ if (match (Term, m_BranchOnCount (m_VPValue (), m_VPValue ())) ||
1055
+ match (Term, m_BranchOnCond (
1056
+ m_Not (m_ActiveLaneMask (m_VPValue (), m_VPValue ()))))) {
1057
+ // Try to simplify the branch condition if TC <= VF * UF when the latch
1058
+ // terminator is BranchOnCount or BranchOnCond where the input is
1059
+ // Not(ActiveLaneMask).
1060
+ const SCEV *TripCount =
1061
+ vputils::getSCEVExprForVPValue (Plan.getTripCount (), SE);
1062
+ assert (!isa<SCEVCouldNotCompute>(TripCount) &&
1063
+ " Trip count SCEV must be computable" );
1064
+ ElementCount NumElements = BestVF.multiplyCoefficientBy (BestUF);
1065
+ const SCEV *C = SE.getElementCount (TripCount->getType (), NumElements);
1066
+ if (TripCount->isZero () ||
1067
+ !SE.isKnownPredicate (CmpInst::ICMP_ULE, TripCount, C))
1068
+ return ;
1069
+ } else if (match (Term, m_BranchOnCond (m_VPValue (Cond)))) {
1070
+ // For BranchOnCond, check if we can prove the condition to be true using VF
1071
+ // and UF.
1072
+ if (!isConditionTrueViaVFAndUF (Cond, Plan, BestVF, BestUF, SE))
1073
+ return ;
1074
+ } else {
1066
1075
return ;
1076
+ }
1067
1077
1068
1078
// The vector loop region only executes once. If possible, completely remove
1069
1079
// the region, otherwise replace the terminator controlling the latch with
0 commit comments