Commit aa24029

[VPlan] Unroll VPReplicateRecipe by VF. (#142433)
Explicitly unroll VPReplicateRecipes outside replicate regions by VF, replacing each of them with VF single-scalar recipes. Extracts for the operands are added as needed, and the scalar results are combined into a vector using a new BuildVector VPInstruction. The patch also adds a few folds to simplify unnecessary extracts/BuildVectors, as well as a BuildStructVector opcode to handle calls with struct return types. VPReplicateRecipes inside replicate regions will be unrolled in a follow-up, turning non-single-scalar VPReplicateRecipes into 'abstract' recipes, i.e. not directly executable ones.

PR: #142433
1 parent 696c0f9 commit aa24029
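
To illustrate the intended effect at the IR level: after this change, a replicated call at VF = 4 is generated as per-lane extracts, four scalar calls, and an insertelement chain (the BuildVector lowering). The sketch below builds that pattern directly with IRBuilder; it is a minimal, self-contained illustration of the resulting pattern only, not the VPlan-based implementation, and the names emitUnrolledCall, Foo and Vec are made up for the example.

// Hedged sketch: emit the "unrolled by VF" pattern for a replicated call foo(x)
// at VF = 4 -- per-lane extracts, scalar calls, and a BuildVector-style
// insertelement chain.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *emitUnrolledCall(IRBuilder<> &B, Function *Foo, Value *Vec,
                               unsigned VF) {
  Type *ScalarTy = Foo->getReturnType();
  // Start from poison and insert one scalar result per lane, mirroring the
  // BuildVector lowering added in VPlanRecipes.cpp below.
  Value *Res = PoisonValue::get(FixedVectorType::get(ScalarTy, VF));
  for (unsigned Lane = 0; Lane != VF; ++Lane) {
    Value *LaneOp = B.CreateExtractElement(Vec, B.getInt32(Lane)); // operand for this lane
    Value *Scalar = B.CreateCall(Foo, {LaneOp});                   // single-scalar clone
    Res = B.CreateInsertElement(Res, Scalar, B.getInt32(Lane));    // pack into the result
  }
  return Res;
}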

18 files changed: +286, -146 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 0 deletions
@@ -7328,6 +7328,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   // cost model is complete for better cost estimates.
   VPlanTransforms::runPass(VPlanTransforms::unrollByUF, BestVPlan, BestUF,
                            OrigLoop->getHeader()->getContext());
+  VPlanTransforms::runPass(VPlanTransforms::replicateByVF, BestVPlan, BestVF);
   VPlanTransforms::runPass(VPlanTransforms::materializeBroadcasts, BestVPlan);
   if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
     VPlanTransforms::runPass(VPlanTransforms::addBranchWeightToMiddleTerminator,

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 7 additions & 0 deletions
@@ -261,6 +261,13 @@ Value *VPTransformState::get(const VPValue *Def, const VPLane &Lane) {
     return Data.VPV2Scalars[Def][0];
   }

+  // Look through BuildVector to avoid redundant extracts.
+  // TODO: Remove once replicate regions are unrolled explicitly.
+  if (Lane.getKind() == VPLane::Kind::First && match(Def, m_BuildVector())) {
+    auto *BuildVector = cast<VPInstruction>(Def);
+    return get(BuildVector->getOperand(Lane.getKnownLane()), true);
+  }
+
   assert(hasVectorValue(Def));
   auto *VecPart = Data.VPV2Vector[Def];
   if (!VecPart->getType()->isVectorTy()) {

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 7 additions & 0 deletions
@@ -936,6 +936,13 @@ class VPInstruction : public VPRecipeWithIRFlags,
     BranchOnCount,
     BranchOnCond,
     Broadcast,
+    /// Given operands of (the same) struct type, creates a struct of fixed-
+    /// width vectors each containing a struct field of all operands. The
+    /// number of operands matches the element count of every vector.
+    BuildStructVector,
+    /// Creates a fixed-width vector containing all operands. The number of
+    /// operands matches the vector element count.
+    BuildVector,
     ComputeAnyOfResult,
     ComputeFindLastIVResult,
     ComputeReductionResult,

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::AnyOf:
+  case VPInstruction::BuildStructVector:
+  case VPInstruction::BuildVector:
     return SetResultTyFromOp();
   case VPInstruction::FirstActiveLane:
     return Type::getIntNTy(Ctx, 64);

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 20 additions & 0 deletions
@@ -221,6 +221,13 @@ struct Recipe_match {
   }

   bool match(const VPRecipeBase *R) const {
+    if (std::tuple_size<Ops_t>::value == 0) {
+      assert(Opcode == VPInstruction::BuildVector &&
+             "can only match BuildVector with empty ops");
+      auto *VPI = dyn_cast<VPInstruction>(R);
+      return VPI && VPI->getOpcode() == VPInstruction::BuildVector;
+    }
+
     if ((!matchRecipeAndOpcode<RecipeTys>(R) && ...))
       return false;

@@ -263,6 +270,10 @@
   }
 };

+template <unsigned Opcode, typename... RecipeTys>
+using ZeroOpRecipe_match =
+    Recipe_match<std::tuple<>, Opcode, false, RecipeTys...>;
+
 template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
 using UnaryRecipe_match =
     Recipe_match<std::tuple<Op0_t>, Opcode, false, RecipeTys...>;
@@ -271,6 +282,9 @@ template <typename Op0_t, unsigned Opcode>
 using UnaryVPInstruction_match =
     UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;

+template <unsigned Opcode>
+using ZeroOpVPInstruction_match = ZeroOpRecipe_match<Opcode, VPInstruction>;
+
 template <typename Op0_t, unsigned Opcode>
 using AllUnaryRecipe_match =
     UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
@@ -302,6 +316,12 @@ using AllBinaryRecipe_match =
     BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe,
                        VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>;

+/// BuildVector is matches only its opcode, w/o matching its operands as the
+/// number of operands is not fixed.
+inline ZeroOpVPInstruction_match<VPInstruction::BuildVector> m_BuildVector() {
+  return ZeroOpVPInstruction_match<VPInstruction::BuildVector>();
+}
+
 template <unsigned Opcode, typename Op0_t>
 inline UnaryVPInstruction_match<Op0_t, Opcode>
 m_VPInstruction(const Op0_t &Op0) {
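
For context, the new m_BuildVector matcher is consumed by the folds added to simplifyRecipe in VPlanTransforms.cpp further down in this commit, e.g.:

  if (match(&R, m_VPInstruction<VPInstruction::ExtractLastElement>(
                    m_BuildVector()))) {
    auto *BuildVector = cast<VPInstruction>(R.getOperand(0));
    Def->replaceAllUsesWith(
        BuildVector->getOperand(BuildVector->getNumOperands() - 1));
    return;
  }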

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 61 additions & 37 deletions
@@ -551,6 +551,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
   }
   case Instruction::ExtractElement: {
     assert(State.VF.isVector() && "Only extract elements from vectors");
+    if (getOperand(1)->isLiveIn()) {
+      unsigned IdxToExtract =
+          cast<ConstantInt>(getOperand(1)->getLiveInIRValue())->getZExtValue();
+      return State.get(getOperand(0), VPLane(IdxToExtract));
+    }
     Value *Vec = State.get(getOperand(0));
     Value *Idx = State.get(getOperand(1), /*IsScalar=*/true);
     return Builder.CreateExtractElement(Vec, Idx, Name);
@@ -664,6 +669,34 @@ Value *VPInstruction::generate(VPTransformState &State) {
     return Builder.CreateVectorSplat(
         State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
   }
+  case VPInstruction::BuildStructVector: {
+    // For struct types, we need to build a new 'wide' struct type, where each
+    // element is widened, i.e., we create a struct of vectors.
+    auto *StructTy =
+        cast<StructType>(State.TypeAnalysis.inferScalarType(getOperand(0)));
+    Value *Res = PoisonValue::get(toVectorizedTy(StructTy, State.VF));
+    for (const auto &[LaneIndex, Op] : enumerate(operands())) {
+      for (unsigned FieldIndex = 0; FieldIndex != StructTy->getNumElements();
+           FieldIndex++) {
+        Value *ScalarValue =
+            Builder.CreateExtractValue(State.get(Op, true), FieldIndex);
+        Value *VectorValue = Builder.CreateExtractValue(Res, FieldIndex);
+        VectorValue =
+            Builder.CreateInsertElement(VectorValue, ScalarValue, LaneIndex);
+        Res = Builder.CreateInsertValue(Res, VectorValue, FieldIndex);
+      }
+    }
+    return Res;
+  }
+  case VPInstruction::BuildVector: {
+    auto *ScalarTy = State.TypeAnalysis.inferScalarType(getOperand(0));
+    auto NumOfElements = ElementCount::getFixed(getNumOperands());
+    Value *Res = PoisonValue::get(toVectorizedTy(ScalarTy, NumOfElements));
+    for (const auto &[Idx, Op] : enumerate(operands()))
+      Res = State.Builder.CreateInsertElement(Res, State.get(Op, true),
+                                              State.Builder.getInt32(Idx));
+    return Res;
+  }
   case VPInstruction::ReductionStartVector: {
     if (State.VF.isScalar())
       return State.get(getOperand(0), true);
@@ -953,10 +986,11 @@ void VPInstruction::execute(VPTransformState &State) {
   if (!hasResult())
     return;
   assert(GeneratedValue && "generate must produce a value");
-  assert(
-      (GeneratedValue->getType()->isVectorTy() == !GeneratesPerFirstLaneOnly ||
-       State.VF.isScalar()) &&
-      "scalar value but not only first lane defined");
+  assert((((GeneratedValue->getType()->isVectorTy() ||
+            GeneratedValue->getType()->isStructTy()) ==
+           !GeneratesPerFirstLaneOnly) ||
+          State.VF.isScalar()) &&
+         "scalar value but not only first lane defined");
   State.set(this, GeneratedValue,
             /*IsScalar*/ GeneratesPerFirstLaneOnly);
 }
@@ -970,6 +1004,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
   case Instruction::ICmp:
   case Instruction::Select:
   case VPInstruction::AnyOf:
+  case VPInstruction::BuildStructVector:
+  case VPInstruction::BuildVector:
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::ExtractLastElement:
@@ -1092,6 +1128,12 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::Broadcast:
     O << "broadcast";
     break;
+  case VPInstruction::BuildStructVector:
+    O << "buildstructvector";
+    break;
+  case VPInstruction::BuildVector:
+    O << "buildvector";
+    break;
   case VPInstruction::ExtractLastElement:
     O << "extract-last-element";
     break;
@@ -2686,45 +2728,27 @@ static void scalarizeInstruction(const Instruction *Instr,

 void VPReplicateRecipe::execute(VPTransformState &State) {
   Instruction *UI = getUnderlyingInstr();
-  if (State.Lane) { // Generate a single instance.
-    assert((State.VF.isScalar() || !isSingleScalar()) &&
-           "uniform recipe shouldn't be predicated");
-    assert(!State.VF.isScalable() && "Can't scalarize a scalable vector");
-    scalarizeInstruction(UI, this, *State.Lane, State);
-    // Insert scalar instance packing it into a vector.
-    if (State.VF.isVector() && shouldPack()) {
-      Value *WideValue;
-      // If we're constructing lane 0, initialize to start from poison.
-      if (State.Lane->isFirstLane()) {
-        assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
-        WideValue = PoisonValue::get(VectorType::get(UI->getType(), State.VF));
-      } else {
-        WideValue = State.get(this);
-      }
-      State.set(this, State.packScalarIntoVectorizedValue(this, WideValue,
-                                                          *State.Lane));
-    }
-    return;
-  }

-  if (IsSingleScalar) {
-    // Uniform within VL means we need to generate lane 0.
+  if (!State.Lane) {
+    assert(IsSingleScalar && "VPReplicateRecipes outside replicate regions "
+                             "must have already been unrolled");
     scalarizeInstruction(UI, this, VPLane(0), State);
     return;
   }

-  // A store of a loop varying value to a uniform address only needs the last
-  // copy of the store.
-  if (isa<StoreInst>(UI) && vputils::isSingleScalar(getOperand(1))) {
-    auto Lane = VPLane::getLastLaneForVF(State.VF);
-    scalarizeInstruction(UI, this, VPLane(Lane), State);
-    return;
+  assert((State.VF.isScalar() || !isSingleScalar()) &&
+         "uniform recipe shouldn't be predicated");
+  assert(!State.VF.isScalable() && "Can't scalarize a scalable vector");
+  scalarizeInstruction(UI, this, *State.Lane, State);
+  // Insert scalar instance packing it into a vector.
+  if (State.VF.isVector() && shouldPack()) {
+    Value *WideValue =
+        State.Lane->isFirstLane()
+            ? PoisonValue::get(VectorType::get(UI->getType(), State.VF))
+            : State.get(this);
+    State.set(this, State.packScalarIntoVectorizedValue(this, WideValue,
+                                                        *State.Lane));
   }
-
-  // Generate scalar instances for all VF lanes.
-  const unsigned EndLane = State.VF.getFixedValue();
-  for (unsigned Lane = 0; Lane < EndLane; ++Lane)
-    scalarizeInstruction(UI, this, VPLane(Lane), State);
 }

 bool VPReplicateRecipe::shouldPack() const {
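
The BuildStructVector lowering above produces a struct of vectors rather than a vector of structs. The following is a hedged, standalone illustration of the same extractvalue/insertelement/insertvalue pattern with IRBuilder; the name buildStructOfVectors and the example element types are made up, and this is a sketch of the pattern, not the VPlan code itself.

// Hedged sketch: combine VF scalar struct values (e.g. {float, i32}) into a
// widened struct of vectors, e.g. {<VF x float>, <VF x i32>}.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *buildStructOfVectors(IRBuilder<> &B, ArrayRef<Value *> Scalars) {
  auto *StructTy = cast<StructType>(Scalars.front()->getType());
  unsigned VF = Scalars.size();
  // Widen each field of the scalar struct type into a fixed-width vector.
  SmallVector<Type *> FieldVecTys;
  for (Type *FieldTy : StructTy->elements())
    FieldVecTys.push_back(FixedVectorType::get(FieldTy, VF));
  Value *Res = PoisonValue::get(StructType::get(B.getContext(), FieldVecTys));
  for (unsigned Lane = 0; Lane != VF; ++Lane) {
    for (unsigned Field = 0; Field != StructTy->getNumElements(); ++Field) {
      Value *Scalar = B.CreateExtractValue(Scalars[Lane], Field); // this lane's field value
      Value *FieldVec = B.CreateExtractValue(Res, Field);         // current field vector
      FieldVec = B.CreateInsertElement(FieldVec, Scalar, B.getInt32(Lane));
      Res = B.CreateInsertValue(Res, FieldVec, Field);
    }
  }
  return Res;
}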

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 18 additions & 0 deletions
@@ -1140,6 +1140,24 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
     return;
   }

+  // Look through ExtractLastElement (BuildVector ....).
+  if (match(&R, m_VPInstruction<VPInstruction::ExtractLastElement>(
+                    m_BuildVector()))) {
+    auto *BuildVector = cast<VPInstruction>(R.getOperand(0));
+    Def->replaceAllUsesWith(
+        BuildVector->getOperand(BuildVector->getNumOperands() - 1));
+    return;
+  }
+
+  // Look through ExtractPenultimateElement (BuildVector ....).
+  if (match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
+                    m_BuildVector()))) {
+    auto *BuildVector = cast<VPInstruction>(R.getOperand(0));
+    Def->replaceAllUsesWith(
+        BuildVector->getOperand(BuildVector->getNumOperands() - 2));
+    return;
+  }
+
   // Some simplifications can only be applied after unrolling. Perform them
   // below.
   if (!Plan->isUnrolled())
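
Concretely, these two folds act as follows (a schematic illustration; the value names are made up):

  // Given %bv = BuildVector(%s0, %s1, %s2, %s3), i.e. VF = 4:
  //   ExtractLastElement(%bv)        -> %s3   (operand NumOperands - 1)
  //   ExtractPenultimateElement(%bv) -> %s2   (operand NumOperands - 2)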

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ struct VPlanTransforms {
   /// Explicitly unroll \p Plan by \p UF.
   static void unrollByUF(VPlan &Plan, unsigned UF, LLVMContext &Ctx);

+  /// Replace each VPReplicateRecipe outside on any replicate region in \p Plan
+  /// with \p VF single-scalar recipes.
+  /// TODO: Also replicate VPReplicateRecipes inside replicate regions, thereby
+  /// dissolving the latter.
+  static void replicateByVF(VPlan &Plan, ElementCount VF);
+
   /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
   /// resulting plan to \p BestVF and \p BestUF.
   static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
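
For reference, the new entry point is invoked once per VPlan right after unrolling by UF; the LoopVectorize.cpp hunk at the top of this commit adds the call:

  VPlanTransforms::runPass(VPlanTransforms::replicateByVF, BestVPlan, BestVF);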

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 85 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "VPlan.h"
 #include "VPlanAnalysis.h"
 #include "VPlanCFG.h"
+#include "VPlanHelpers.h"
 #include "VPlanPatternMatch.h"
 #include "VPlanTransforms.h"
 #include "VPlanUtils.h"
@@ -450,3 +451,87 @@ void VPlanTransforms::unrollByUF(VPlan &Plan, unsigned UF, LLVMContext &Ctx) {

   VPlanTransforms::removeDeadRecipes(Plan);
 }
+
+/// Create a single-scalar clone of \p RepR for lane \p Lane.
+static VPReplicateRecipe *cloneForLane(VPlan &Plan, VPBuilder &Builder,
+                                       Type *IdxTy, VPReplicateRecipe *RepR,
+                                       VPLane Lane) {
+  // Collect the operands at Lane, creating extracts as needed.
+  SmallVector<VPValue *> NewOps;
+  for (VPValue *Op : RepR->operands()) {
+    if (vputils::isSingleScalar(Op)) {
+      NewOps.push_back(Op);
+      continue;
+    }
+    if (Lane.getKind() == VPLane::Kind::ScalableLast) {
+      NewOps.push_back(
+          Builder.createNaryOp(VPInstruction::ExtractLastElement, {Op}));
+      continue;
+    }
+    // Look through buildvector to avoid unnecessary extracts.
+    if (match(Op, m_BuildVector())) {
+      NewOps.push_back(
+          cast<VPInstruction>(Op)->getOperand(Lane.getKnownLane()));
+      continue;
+    }
+    VPValue *Idx =
+        Plan.getOrAddLiveIn(ConstantInt::get(IdxTy, Lane.getKnownLane()));
+    VPValue *Ext = Builder.createNaryOp(Instruction::ExtractElement, {Op, Idx});
+    NewOps.push_back(Ext);
+  }
+
+  auto *New =
+      new VPReplicateRecipe(RepR->getUnderlyingInstr(), NewOps,
+                            /*IsSingleScalar=*/true, /*Mask=*/nullptr, *RepR);
+  New->insertBefore(RepR);
+  return New;
+}
+
+void VPlanTransforms::replicateByVF(VPlan &Plan, ElementCount VF) {
+  Type *IdxTy = IntegerType::get(
+      Plan.getScalarHeader()->getIRBasicBlock()->getContext(), 32);
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      auto *RepR = dyn_cast<VPReplicateRecipe>(&R);
+      if (!RepR || RepR->isSingleScalar())
+        continue;
+
+      VPBuilder Builder(RepR);
+      if (RepR->getNumUsers() == 0) {
+        if (isa<StoreInst>(RepR->getUnderlyingInstr()) &&
+            vputils::isSingleScalar(RepR->getOperand(1))) {
+          // Stores to invariant addresses need to store the last lane only.
+          cloneForLane(Plan, Builder, IdxTy, RepR,
+                       VPLane::getLastLaneForVF(VF));
+        } else {
+          // Create single-scalar version of RepR for all lanes.
+          for (unsigned I = 0; I != VF.getKnownMinValue(); ++I)
+            cloneForLane(Plan, Builder, IdxTy, RepR, VPLane(I));
+        }
+        RepR->eraseFromParent();
+        continue;
+      }
+      /// Create single-scalar version of RepR for all lanes.
+      SmallVector<VPValue *> LaneDefs;
+      for (unsigned I = 0; I != VF.getKnownMinValue(); ++I)
+        LaneDefs.push_back(cloneForLane(Plan, Builder, IdxTy, RepR, VPLane(I)));
+
+      /// Users that only demand the first lane can use the definition for lane
+      /// 0.
+      RepR->replaceUsesWithIf(LaneDefs[0], [RepR](VPUser &U, unsigned) {
+        return U.onlyFirstLaneUsed(RepR);
+      });
+
+      // If needed, create a Build(Struct)Vector recipe to insert the scalar
+      // lane values into a vector.
+      Type *ResTy = RepR->getUnderlyingInstr()->getType();
+      VPValue *VecRes = Builder.createNaryOp(
+          ResTy->isStructTy() ? VPInstruction::BuildStructVector
+                              : VPInstruction::BuildVector,
+          LaneDefs);
+      RepR->replaceAllUsesWith(VecRes);
+      RepR->eraseFromParent();
+    }
+  }
+}
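
To make the flow above concrete, here is a schematic before/after for a non-single-scalar replicated load at VF = 2 (a hedged sketch; the notation is illustrative and not exact VPlan printer output):

  Before replicateByVF:
    REPLICATE %l = load %addr                        one recipe; codegen would emit a copy per lane

  After replicateByVF (VF = 2):
    REPLICATE %l0 = load extractelement(%addr, 0)    single-scalar clone for lane 0
    REPLICATE %l1 = load extractelement(%addr, 1)    single-scalar clone for lane 1
    EMIT %l.vec = buildvector %l0, %l1               only if some user still needs the whole vector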

llvm/test/Transforms/LoopVectorize/X86/fixed-order-recurrence.ll

Lines changed: 0 additions & 6 deletions
@@ -393,12 +393,6 @@ define void @test_for_tried_to_force_scalar(ptr noalias %A, ptr noalias %B, ptr
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <12 x float> [[WIDE_VEC]], <12 x float> poison, <4 x i32> <i32 0, i32 3, i32 6, i32 9>
 ; CHECK-NEXT:    [[TMP30:%.*]] = extractelement <4 x float> [[STRIDED_VEC]], i32 3
 ; CHECK-NEXT:    store float [[TMP30]], ptr [[C:%.*]], align 4
-; CHECK-NEXT:    [[TMP31:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 0
-; CHECK-NEXT:    [[TMP38:%.*]] = load float, ptr [[TMP31]], align 4
-; CHECK-NEXT:    [[TMP33:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 1
-; CHECK-NEXT:    [[TMP32:%.*]] = load float, ptr [[TMP33]], align 4
-; CHECK-NEXT:    [[TMP35:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 2
-; CHECK-NEXT:    [[TMP34:%.*]] = load float, ptr [[TMP35]], align 4
 ; CHECK-NEXT:    [[TMP37:%.*]] = extractelement <4 x ptr> [[TMP29]], i32 3
 ; CHECK-NEXT:    [[TMP36:%.*]] = load float, ptr [[TMP37]], align 4
 ; CHECK-NEXT:    store float [[TMP36]], ptr [[B:%.*]], align 4
