Skip to content

Commit d83028e

Browse files
NickGuy-ArmSamTebbs33
authored andcommitted
[LoopVectorizer] Add support for partial reductions
1 parent 8e1b49c commit d83028e

16 files changed

+3812
-31
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
211211
/// for IR-level transformations.
212212
class TargetTransformInfo {
213213
public:
214+
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };
215+
216+
/// Get the kind of extension that an instruction represents.
217+
static PartialReductionExtendKind
218+
getPartialReductionExtendKind(Instruction *I);
219+
214220
/// Construct a TTI object using a type implementing the \c Concept
215221
/// API below.
216222
///
@@ -1280,6 +1286,18 @@ class TargetTransformInfo {
12801286
/// \return if target want to issue a prefetch in address space \p AS.
12811287
bool shouldPrefetchAddressSpace(unsigned AS) const;
12821288

1289+
/// \return The cost of a partial reduction, which is a reduction from a
1290+
/// vector to another vector with fewer elements of larger size. They are
1291+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1292+
/// takes an accumulator and a binary operation operand that itself is fed by
1293+
/// two extends. An example of an operation that uses a partial reduction is a
1294+
/// dot product, which reduces a vector to another of 4 times fewer elements.
1295+
InstructionCost
1296+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
1297+
ElementCount VF, PartialReductionExtendKind OpAExtend,
1298+
PartialReductionExtendKind OpBExtend,
1299+
std::optional<unsigned> BinOp = std::nullopt) const;
1300+
12831301
/// \return The maximum interleave factor that any transform should try to
12841302
/// perform for this target. This number depends on the level of parallelism
12851303
/// and the number of execution units in the CPU.
@@ -2107,6 +2125,18 @@ class TargetTransformInfo::Concept {
21072125
/// \return if target want to issue a prefetch in address space \p AS.
21082126
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
21092127

2128+
/// \return The cost of a partial reduction, which is a reduction from a
2129+
/// vector to another vector with fewer elements of larger size. They are
2130+
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
2131+
/// takes an accumulator and a binary operation operand that itself is fed by
2132+
/// two extends. An example of an operation that uses a partial reduction is a
2133+
/// dot product, which reduces a vector to another of 4 times fewer elements.
2134+
virtual InstructionCost
2135+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
2136+
ElementCount VF, PartialReductionExtendKind OpAExtend,
2137+
PartialReductionExtendKind OpBExtend,
2138+
std::optional<unsigned> BinOp) const = 0;
2139+
21102140
virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
21112141
virtual InstructionCost getArithmeticInstrCost(
21122142
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2786,6 +2816,15 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
27862816
return Impl.shouldPrefetchAddressSpace(AS);
27872817
}
27882818

2819+
InstructionCost getPartialReductionCost(
2820+
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
2821+
PartialReductionExtendKind OpAExtend,
2822+
PartialReductionExtendKind OpBExtend,
2823+
std::optional<unsigned> BinOp = std::nullopt) const override {
2824+
return Impl.getPartialReductionCost(Opcode, InputType, AccumType, VF,
2825+
OpAExtend, OpBExtend, BinOp);
2826+
}
2827+
27892828
unsigned getMaxInterleaveFactor(ElementCount VF) override {
27902829
return Impl.getMaxInterleaveFactor(VF);
27912830
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
585585
bool enableWritePrefetching() const { return false; }
586586
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
587587

588+
InstructionCost
589+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
590+
ElementCount VF,
591+
TTI::PartialReductionExtendKind OpAExtend,
592+
TTI::PartialReductionExtendKind OpBExtend,
593+
std::optional<unsigned> BinOp = std::nullopt) const {
594+
return InstructionCost::getInvalid();
595+
}
596+
588597
unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
589598

590599
InstructionCost getArithmeticInstrCost(

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,14 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
863863
return TTIImpl->shouldPrefetchAddressSpace(AS);
864864
}
865865

866+
InstructionCost TargetTransformInfo::getPartialReductionCost(
867+
unsigned Opcode, Type *InputType, Type *AccumType, ElementCount VF,
868+
PartialReductionExtendKind OpAExtend, PartialReductionExtendKind OpBExtend,
869+
std::optional<unsigned> BinOp) const {
870+
return TTIImpl->getPartialReductionCost(Opcode, InputType, AccumType, VF,
871+
OpAExtend, OpBExtend, BinOp);
872+
}
873+
866874
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
867875
return TTIImpl->getMaxInterleaveFactor(VF);
868876
}
@@ -974,6 +982,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
974982
return Cost;
975983
}
976984

985+
TargetTransformInfo::PartialReductionExtendKind
986+
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
987+
if (isa<SExtInst>(I))
988+
return PR_SignExtend;
989+
if (isa<ZExtInst>(I))
990+
return PR_ZeroExtend;
991+
return PR_None;
992+
}
993+
977994
TTI::CastContextHint
978995
TargetTransformInfo::getCastContextHint(const Instruction *I) {
979996
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/CodeGen/BasicTTIImpl.h"
2424
#include "llvm/IR/Function.h"
2525
#include "llvm/IR/Intrinsics.h"
26+
#include "llvm/Support/InstructionCost.h"
2627
#include <cstdint>
2728
#include <optional>
2829

@@ -357,6 +358,61 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
357358
return BaseT::isLegalNTLoad(DataType, Alignment);
358359
}
359360

361+
InstructionCost
362+
getPartialReductionCost(unsigned Opcode, Type *InputType, Type *AccumType,
363+
ElementCount VF,
364+
TTI::PartialReductionExtendKind OpAExtend,
365+
TTI::PartialReductionExtendKind OpBExtend,
366+
std::optional<unsigned> BinOp) const {
367+
368+
InstructionCost Invalid = InstructionCost::getInvalid();
369+
InstructionCost Cost(TTI::TCC_Basic);
370+
371+
if (Opcode != Instruction::Add)
372+
return Invalid;
373+
374+
EVT InputEVT = EVT::getEVT(InputType);
375+
EVT AccumEVT = EVT::getEVT(AccumType);
376+
377+
if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
378+
return Invalid;
379+
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
380+
return Invalid;
381+
382+
if (InputEVT == MVT::i8) {
383+
switch (VF.getKnownMinValue()) {
384+
default:
385+
return Invalid;
386+
case 8:
387+
if (AccumEVT == MVT::i32)
388+
Cost *= 2;
389+
else if (AccumEVT != MVT::i64)
390+
return Invalid;
391+
break;
392+
case 16:
393+
if (AccumEVT == MVT::i64)
394+
Cost *= 2;
395+
else if (AccumEVT != MVT::i32)
396+
return Invalid;
397+
break;
398+
}
399+
} else if (InputEVT == MVT::i16) {
400+
// FIXME: Allow i32 accumulator but increase cost, as we would extend
401+
// it to i64.
402+
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
403+
return Invalid;
404+
} else
405+
return Invalid;
406+
407+
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None)
408+
return Invalid;
409+
410+
if (!BinOp || (*BinOp) != Instruction::Mul)
411+
return Invalid;
412+
413+
return Cost;
414+
}
415+
360416
bool enableOrderedReductions() const { return true; }
361417

362418
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7532,6 +7532,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
75327532
}
75337533
continue;
75347534
}
7535+
// The VPlan-based cost model is more accurate for partial reduction and
7536+
// comparing against the legacy cost isn't desirable.
7537+
if (isa<VPPartialReductionRecipe>(&R))
7538+
return true;
75357539
if (Instruction *UI = GetInstructionForCost(&R))
75367540
SeenInstrs.insert(UI);
75377541
}
@@ -8746,6 +8750,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87468750
return Recipe;
87478751
}
87488752

8753+
/// Find all possible partial reductions in the loop and track all of those that
8754+
/// are valid so recipes can be formed later.
8755+
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8756+
// Find all possible partial reductions.
8757+
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8758+
PartialReductionChains;
8759+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8760+
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8761+
getScaledReduction(Phi, RdxDesc, Range))
8762+
PartialReductionChains.push_back(*Pair);
8763+
8764+
// A partial reduction is invalid if any of its extends are used by
8765+
// something that isn't another partial reduction. This is because the
8766+
// extends are intended to be lowered along with the reduction itself.
8767+
8768+
// Build up a set of partial reduction bin ops for efficient use checking.
8769+
SmallSet<User *, 4> PartialReductionBinOps;
8770+
for (const auto &[PartialRdx, _] : PartialReductionChains)
8771+
PartialReductionBinOps.insert(PartialRdx.BinOp);
8772+
8773+
auto ExtendIsOnlyUsedByPartialReductions =
8774+
[&PartialReductionBinOps](Instruction *Extend) {
8775+
return all_of(Extend->users(), [&](const User *U) {
8776+
return PartialReductionBinOps.contains(U);
8777+
});
8778+
};
8779+
8780+
// Check if each use of a chain's two extends is a partial reduction
8781+
// and only add those that don't have non-partial reduction users.
8782+
for (auto Pair : PartialReductionChains) {
8783+
PartialReductionChain Chain = Pair.first;
8784+
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8785+
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8786+
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
8787+
}
8788+
}
8789+
8790+
std::optional<std::pair<PartialReductionChain, unsigned>>
8791+
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8792+
const RecurrenceDescriptor &Rdx,
8793+
VFRange &Range) {
8794+
// TODO: Allow scaling reductions when predicating. The select at
8795+
// the end of the loop chooses between the phi value and most recent
8796+
// reduction result, both of which have different VFs to the active lane
8797+
// mask when scaling.
8798+
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8799+
return std::nullopt;
8800+
8801+
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8802+
if (!Update)
8803+
return std::nullopt;
8804+
8805+
Value *Op = Update->getOperand(0);
8806+
if (Op == PHI)
8807+
Op = Update->getOperand(1);
8808+
8809+
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8810+
if (!BinOp || !BinOp->hasOneUse())
8811+
return std::nullopt;
8812+
8813+
using namespace llvm::PatternMatch;
8814+
Value *A, *B;
8815+
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8816+
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8817+
return std::nullopt;
8818+
8819+
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8820+
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8821+
8822+
// Check that the extends extend from the same type.
8823+
if (A->getType() != B->getType())
8824+
return std::nullopt;
8825+
8826+
TTI::PartialReductionExtendKind OpAExtend =
8827+
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8828+
TTI::PartialReductionExtendKind OpBExtend =
8829+
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8830+
8831+
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8832+
8833+
unsigned TargetScaleFactor =
8834+
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8835+
A->getType()->getPrimitiveSizeInBits());
8836+
8837+
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8838+
[&](ElementCount VF) {
8839+
InstructionCost Cost = TTI->getPartialReductionCost(
8840+
Update->getOpcode(), A->getType(), PHI->getType(), VF,
8841+
OpAExtend, OpBExtend, std::make_optional(BinOp->getOpcode()));
8842+
return Cost.isValid();
8843+
},
8844+
Range))
8845+
return std::make_pair(Chain, TargetScaleFactor);
8846+
8847+
return std::nullopt;
8848+
}
8849+
87498850
VPRecipeBase *
87508851
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87518852
ArrayRef<VPValue *> Operands,
@@ -8770,9 +8871,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87708871
Legal->getReductionVars().find(Phi)->second;
87718872
assert(RdxDesc.getRecurrenceStartValue() ==
87728873
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
8773-
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
8774-
CM.isInLoopReduction(Phi),
8775-
CM.useOrderedReductions(RdxDesc));
8874+
8875+
// If the PHI is used by a partial reduction, set the scale factor.
8876+
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8877+
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
8878+
unsigned ScaleFactor = Pair ? Pair->second : 1;
8879+
PhiRecipe = new VPReductionPHIRecipe(
8880+
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
8881+
CM.useOrderedReductions(RdxDesc), ScaleFactor);
87768882
} else {
87778883
// TODO: Currently fixed-order recurrences are modeled as chains of
87788884
// first-order recurrences. If there are no users of the intermediate
@@ -8804,6 +8910,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88048910
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88058911
return tryToWidenMemory(Instr, Operands, Range);
88068912

8913+
if (getScaledReductionForInstr(Instr))
8914+
return tryToCreatePartialReduction(Instr, Operands);
8915+
88078916
if (!shouldWiden(Instr, Range))
88088917
return nullptr;
88098918

@@ -8824,6 +8933,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88248933
return tryToWiden(Instr, Operands, VPBB);
88258934
}
88268935

8936+
VPRecipeBase *
8937+
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
8938+
ArrayRef<VPValue *> Operands) {
8939+
assert(Operands.size() == 2 &&
8940+
"Unexpected number of operands for partial reduction");
8941+
8942+
VPValue *BinOp = Operands[0];
8943+
VPValue *Phi = Operands[1];
8944+
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
8945+
std::swap(BinOp, Phi);
8946+
8947+
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
8948+
Reduction);
8949+
}
8950+
88278951
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
88288952
ElementCount MaxVF) {
88298953
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9247,7 +9371,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92479371
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92489372
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);
92499373

9250-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9374+
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9375+
Builder);
92519376

92529377
// ---------------------------------------------------------------------------
92539378
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9293,6 +9418,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92939418
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
92949419
return Legal->blockNeedsPredication(BB) || NeedsBlends;
92959420
});
9421+
9422+
RecipeBuilder.collectScaledReductions(Range);
9423+
92969424
auto *MiddleVPBB = Plan->getMiddleBlock();
92979425
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
92989426
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {

0 commit comments

Comments
 (0)