Skip to content

Commit bb017c1

Browse files
committed
[VPlan] Add support for in-loop AnyOf reductions
1 parent ad9630d commit bb017c1

File tree

7 files changed

+1766
-27
lines changed

7 files changed

+1766
-27
lines changed

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1210,6 +1210,8 @@ RecurrenceDescriptor::getReductionOpChain(PHINode *Phi, Loop *L) const {
12101210
return SelectPatternResult::isMinOrMax(
12111211
matchSelectPattern(Cur, LHS, RHS).Flavor);
12121212
}
1213+
if (isAnyOfRecurrenceKind(getRecurrenceKind()))
1214+
return isa<SelectInst>(Cur);
12131215
// Recognize a call to the llvm.fmuladd intrinsic.
12141216
if (isFMulAddIntrinsic(Cur))
12151217
return true;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5834,6 +5834,14 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
58345834
Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(RK);
58355835
BaseCost = TTI.getMinMaxReductionCost(MinMaxID, VectorTy,
58365836
RdxDesc.getFastMathFlags(), CostKind);
5837+
} else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
5838+
VectorType *BoolTy = VectorType::get(
5839+
Type::getInt1Ty(VectorTy->getContext()), VectorTy->getElementCount());
5840+
BaseCost =
5841+
TTI.getArithmeticReductionCost(Instruction::Or, BoolTy,
5842+
RdxDesc.getFastMathFlags(), CostKind) +
5843+
TTI.getArithmeticInstrCost(Instruction::Or, BoolTy->getScalarType(),
5844+
CostKind);
58375845
} else {
58385846
BaseCost = TTI.getArithmeticReductionCost(
58395847
RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind);
@@ -9666,10 +9674,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96669674

96679675
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
96689676
RecurKind Kind = RdxDesc.getRecurrenceKind();
9669-
assert(
9670-
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
9671-
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9672-
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
9677+
assert(!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9678+
"FindLast reductions are not allowed for in-loop reductions");
96739679

96749680
// Collect the chain of "link" recipes for the reduction starting at PhiR.
96759681
SetVector<VPSingleDefRecipe *> Worklist;
@@ -9738,6 +9744,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
97389744
CurrentLinkI->getFastMathFlags());
97399745
LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
97409746
VecOp = FMulRecipe;
9747+
} else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind)) {
9748+
assert(isa<VPWidenSelectRecipe>(CurrentLink) &&
9749+
"must be a select recipe");
9750+
VecOp = CurrentLink->getOperand(0);
9751+
Kind = RecurKind::Or;
97419752
} else {
97429753
if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
97439754
if (isa<VPWidenRecipe>(CurrentLink)) {
@@ -9902,10 +9913,17 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
99029913
// selected if the negated condition is true in any iteration.
99039914
if (Select->getOperand(1) == PhiR)
99049915
Cmp = Builder.createNot(Cmp);
9905-
VPValue *Or = Builder.createOr(PhiR, Cmp);
9906-
Select->getVPSingleValue()->replaceAllUsesWith(Or);
9907-
// Delete Select now that it has invalid types.
9908-
ToDelete.push_back(Select);
9916+
9917+
if (PhiR->isInLoop() && MinVF.isVector()) {
9918+
auto *Reduction = cast<VPReductionRecipe>(
9919+
*find_if(PhiR->users(), IsaPred<VPReductionRecipe>));
9920+
Reduction->setOperand(1, Cmp);
9921+
} else {
9922+
VPValue *Or = Builder.createOr(PhiR, Cmp);
9923+
Select->getVPSingleValue()->replaceAllUsesWith(Or);
9924+
// Delete Select now that it has invalid types.
9925+
ToDelete.push_back(Select);
9926+
}
99099927

99109928
// Convert the reduction phi to operate on bools.
99119929
PhiR->setOperand(0, Plan->getOrAddLiveIn(ConstantInt::getFalse(

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -668,10 +668,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
668668

669669
// Create the reduction after the loop. Note that inloop reductions create
670670
// the target reduction in the loop using a Reduction recipe.
671-
if ((State.VF.isVector() ||
672-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
673-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
674-
!PhiR->isInLoop()) {
671+
if (((State.VF.isVector() ||
672+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) &&
673+
!PhiR->isInLoop()) ||
674+
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
675675
// TODO: Support in-order reductions based on the recurrence descriptor.
676676
// All ops in the reduction inherit fast-math-flags from the recurrence
677677
// descriptor.
@@ -2302,7 +2302,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
23022302
Value *PrevInChain = State.get(getChainOp(), /*IsScalar*/ true);
23032303
RecurKind Kind = getRecurrenceKind();
23042304
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
2305-
"In-loop AnyOf reductions aren't currently supported");
2305+
"In-loop AnyOf reduction should use Or reduction recipe");
23062306
// Propagate the fast-math flags carried by the underlying instruction.
23072307
IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
23082308
State.Builder.setFastMathFlags(getFastMathFlags());

llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-inloop-reduction.ll

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1924,23 +1924,22 @@ define i32 @anyof_icmp(ptr %a, i64 %n, i32 %start, i32 %inv) {
19241924
; IF-EVL: vector.body:
19251925
; IF-EVL-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
19261926
; IF-EVL-NEXT: [[EVL_BASED_IV:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
1927-
; IF-EVL-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i1> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
1927+
; IF-EVL-NEXT: [[VEC_PHI:%.*]] = phi i1 [ false, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
19281928
; IF-EVL-NEXT: [[TMP9:%.*]] = sub i64 [[N]], [[EVL_BASED_IV]]
19291929
; IF-EVL-NEXT: [[TMP10:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[TMP9]], i32 4, i1 true)
19301930
; IF-EVL-NEXT: [[TMP11:%.*]] = add i64 [[EVL_BASED_IV]], 0
19311931
; IF-EVL-NEXT: [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP11]]
19321932
; IF-EVL-NEXT: [[TMP13:%.*]] = getelementptr inbounds i32, ptr [[TMP12]], i32 0
19331933
; IF-EVL-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 4 x i32> @llvm.vp.load.nxv4i32.p0(ptr align 4 [[TMP13]], <vscale x 4 x i1> splat (i1 true), i32 [[TMP10]])
19341934
; IF-EVL-NEXT: [[TMP14:%.*]] = icmp slt <vscale x 4 x i32> [[VP_OP_LOAD]], splat (i32 3)
1935-
; IF-EVL-NEXT: [[TMP15:%.*]] = or <vscale x 4 x i1> [[VEC_PHI]], [[TMP14]]
1936-
; IF-EVL-NEXT: [[TMP16]] = call <vscale x 4 x i1> @llvm.vp.merge.nxv4i1(<vscale x 4 x i1> splat (i1 true), <vscale x 4 x i1> [[TMP15]], <vscale x 4 x i1> [[VEC_PHI]], i32 [[TMP10]])
1935+
; IF-EVL-NEXT: [[TMP15:%.*]] = call i1 @llvm.vp.reduce.or.nxv4i1(i1 false, <vscale x 4 x i1> [[TMP14]], <vscale x 4 x i1> splat (i1 true), i32 [[TMP10]])
1936+
; IF-EVL-NEXT: [[TMP19]] = or i1 [[TMP15]], [[VEC_PHI]]
19371937
; IF-EVL-NEXT: [[TMP17:%.*]] = zext i32 [[TMP10]] to i64
19381938
; IF-EVL-NEXT: [[INDEX_EVL_NEXT]] = add i64 [[TMP17]], [[EVL_BASED_IV]]
19391939
; IF-EVL-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP8]]
19401940
; IF-EVL-NEXT: [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
19411941
; IF-EVL-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP34:![0-9]+]]
19421942
; IF-EVL: middle.block:
1943-
; IF-EVL-NEXT: [[TMP19:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP16]])
19441943
; IF-EVL-NEXT: [[TMP20:%.*]] = freeze i1 [[TMP19]]
19451944
; IF-EVL-NEXT: [[RDX_SELECT:%.*]] = select i1 [[TMP20]], i32 [[INV:%.*]], i32 [[START:%.*]]
19461945
; IF-EVL-NEXT: br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -1978,18 +1977,18 @@ define i32 @anyof_icmp(ptr %a, i64 %n, i32 %start, i32 %inv) {
19781977
; NO-VP-NEXT: br label [[VECTOR_BODY:%.*]]
19791978
; NO-VP: vector.body:
19801979
; NO-VP-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
1981-
; NO-VP-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i1> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
1980+
; NO-VP-NEXT: [[VEC_PHI:%.*]] = phi i1 [ false, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
19821981
; NO-VP-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
19831982
; NO-VP-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP6]]
19841983
; NO-VP-NEXT: [[TMP8:%.*]] = getelementptr inbounds i32, ptr [[TMP7]], i32 0
19851984
; NO-VP-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i32>, ptr [[TMP8]], align 4
19861985
; NO-VP-NEXT: [[TMP9:%.*]] = icmp slt <vscale x 4 x i32> [[WIDE_LOAD]], splat (i32 3)
1987-
; NO-VP-NEXT: [[TMP10]] = or <vscale x 4 x i1> [[VEC_PHI]], [[TMP9]]
1986+
; NO-VP-NEXT: [[TMP10:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP9]])
1987+
; NO-VP-NEXT: [[TMP12]] = or i1 [[TMP10]], [[VEC_PHI]]
19881988
; NO-VP-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
19891989
; NO-VP-NEXT: [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
19901990
; NO-VP-NEXT: br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP34:![0-9]+]]
19911991
; NO-VP: middle.block:
1992-
; NO-VP-NEXT: [[TMP12:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP10]])
19931992
; NO-VP-NEXT: [[TMP13:%.*]] = freeze i1 [[TMP12]]
19941993
; NO-VP-NEXT: [[RDX_SELECT:%.*]] = select i1 [[TMP13]], i32 [[INV:%.*]], i32 [[START:%.*]]
19951994
; NO-VP-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[N]], [[N_VEC]]
@@ -2051,23 +2050,22 @@ define i32 @anyof_fcmp(ptr %a, i64 %n, i32 %start, i32 %inv) {
20512050
; IF-EVL: vector.body:
20522051
; IF-EVL-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
20532052
; IF-EVL-NEXT: [[EVL_BASED_IV:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
2054-
; IF-EVL-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i1> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
2053+
; IF-EVL-NEXT: [[VEC_PHI:%.*]] = phi i1 [ false, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
20552054
; IF-EVL-NEXT: [[TMP9:%.*]] = sub i64 [[N]], [[EVL_BASED_IV]]
20562055
; IF-EVL-NEXT: [[TMP10:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[TMP9]], i32 4, i1 true)
20572056
; IF-EVL-NEXT: [[TMP11:%.*]] = add i64 [[EVL_BASED_IV]], 0
20582057
; IF-EVL-NEXT: [[TMP12:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP11]]
20592058
; IF-EVL-NEXT: [[TMP13:%.*]] = getelementptr inbounds float, ptr [[TMP12]], i32 0
20602059
; IF-EVL-NEXT: [[VP_OP_LOAD:%.*]] = call <vscale x 4 x float> @llvm.vp.load.nxv4f32.p0(ptr align 4 [[TMP13]], <vscale x 4 x i1> splat (i1 true), i32 [[TMP10]])
20612060
; IF-EVL-NEXT: [[TMP14:%.*]] = fcmp fast olt <vscale x 4 x float> [[VP_OP_LOAD]], splat (float 3.000000e+00)
2062-
; IF-EVL-NEXT: [[TMP15:%.*]] = or <vscale x 4 x i1> [[VEC_PHI]], [[TMP14]]
2063-
; IF-EVL-NEXT: [[TMP16]] = call <vscale x 4 x i1> @llvm.vp.merge.nxv4i1(<vscale x 4 x i1> splat (i1 true), <vscale x 4 x i1> [[TMP15]], <vscale x 4 x i1> [[VEC_PHI]], i32 [[TMP10]])
2061+
; IF-EVL-NEXT: [[TMP15:%.*]] = call i1 @llvm.vp.reduce.or.nxv4i1(i1 false, <vscale x 4 x i1> [[TMP14]], <vscale x 4 x i1> splat (i1 true), i32 [[TMP10]])
2062+
; IF-EVL-NEXT: [[TMP19]] = or i1 [[TMP15]], [[VEC_PHI]]
20642063
; IF-EVL-NEXT: [[TMP17:%.*]] = zext i32 [[TMP10]] to i64
20652064
; IF-EVL-NEXT: [[INDEX_EVL_NEXT]] = add i64 [[TMP17]], [[EVL_BASED_IV]]
20662065
; IF-EVL-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP8]]
20672066
; IF-EVL-NEXT: [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
20682067
; IF-EVL-NEXT: br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP36:![0-9]+]]
20692068
; IF-EVL: middle.block:
2070-
; IF-EVL-NEXT: [[TMP19:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP16]])
20712069
; IF-EVL-NEXT: [[TMP20:%.*]] = freeze i1 [[TMP19]]
20722070
; IF-EVL-NEXT: [[RDX_SELECT:%.*]] = select i1 [[TMP20]], i32 [[INV:%.*]], i32 [[START:%.*]]
20732071
; IF-EVL-NEXT: br i1 true, label [[FOR_END:%.*]], label [[SCALAR_PH]]
@@ -2105,18 +2103,18 @@ define i32 @anyof_fcmp(ptr %a, i64 %n, i32 %start, i32 %inv) {
21052103
; NO-VP-NEXT: br label [[VECTOR_BODY:%.*]]
21062104
; NO-VP: vector.body:
21072105
; NO-VP-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
2108-
; NO-VP-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i1> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP10:%.*]], [[VECTOR_BODY]] ]
2106+
; NO-VP-NEXT: [[VEC_PHI:%.*]] = phi i1 [ false, [[VECTOR_PH]] ], [ [[TMP12:%.*]], [[VECTOR_BODY]] ]
21092107
; NO-VP-NEXT: [[TMP6:%.*]] = add i64 [[INDEX]], 0
21102108
; NO-VP-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[TMP6]]
21112109
; NO-VP-NEXT: [[TMP8:%.*]] = getelementptr inbounds float, ptr [[TMP7]], i32 0
21122110
; NO-VP-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP8]], align 4
21132111
; NO-VP-NEXT: [[TMP9:%.*]] = fcmp fast olt <vscale x 4 x float> [[WIDE_LOAD]], splat (float 3.000000e+00)
2114-
; NO-VP-NEXT: [[TMP10]] = or <vscale x 4 x i1> [[VEC_PHI]], [[TMP9]]
2112+
; NO-VP-NEXT: [[TMP10:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP9]])
2113+
; NO-VP-NEXT: [[TMP12]] = or i1 [[TMP10]], [[VEC_PHI]]
21152114
; NO-VP-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
21162115
; NO-VP-NEXT: [[TMP11:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
21172116
; NO-VP-NEXT: br i1 [[TMP11]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP36:![0-9]+]]
21182117
; NO-VP: middle.block:
2119-
; NO-VP-NEXT: [[TMP12:%.*]] = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> [[TMP10]])
21202118
; NO-VP-NEXT: [[TMP13:%.*]] = freeze i1 [[TMP12]]
21212119
; NO-VP-NEXT: [[RDX_SELECT:%.*]] = select i1 [[TMP13]], i32 [[INV:%.*]], i32 [[START:%.*]]
21222120
; NO-VP-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[N]], [[N_VEC]]

0 commit comments

Comments
 (0)