-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[LV] Reduce register usage for scaled reductions #133090
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a994bf2
82151a8
f5d5ed6
54af7d2
b6b9063
a695124
f05ebee
ff432c9
0f68427
8427362
b905806
5f14165
296e3ce
c38fcab
e6061f2
f1e7e9b
bde39b4
f58b5e1
aefea41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4874,6 +4874,20 @@ void LoopVectorizationCostModel::collectElementTypesForWidening() { | |||||
} | ||||||
} | ||||||
|
||||||
/// Get the VF scaling factor applied to the recipe's output, if the recipe has | ||||||
/// one. | ||||||
static unsigned getVFScaleFactor(VPRecipeBase *R) { | ||||||
if (isa<VPPartialReductionRecipe, VPReductionPHIRecipe>(R)) { | ||||||
auto *ReductionR = dyn_cast<VPReductionPHIRecipe>(R); | ||||||
auto *PartialReductionR = | ||||||
ReductionR ? nullptr : dyn_cast<VPPartialReductionRecipe>(R); | ||||||
unsigned ScaleFactor = ReductionR ? ReductionR->getVFScaleFactor() | ||||||
: PartialReductionR->getVFScaleFactor(); | ||||||
return ScaleFactor; | ||||||
} | ||||||
return 1; | ||||||
} | ||||||
|
||||||
/// Estimate the register usage for \p Plan and vectorization factors in \p VFs | ||||||
/// by calculating the highest number of values that are live at a single | ||||||
/// location as a rough estimate. Returns the register usage for each VF in \p | ||||||
|
@@ -5014,7 +5028,6 @@ calculateRegisterUsage(VPlan &Plan, ArrayRef<ElementCount> VFs, | |||||
if (isa<VPVectorPointerRecipe, VPVectorEndPointerRecipe, | ||||||
VPBranchOnMaskRecipe>(R)) | ||||||
continue; | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unrelated change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
if (VFs[J].isScalar() || | ||||||
isa<VPCanonicalIVPHIRecipe, VPReplicateRecipe, VPDerivedIVRecipe, | ||||||
VPScalarIVStepsRecipe>(R) || | ||||||
|
@@ -5028,10 +5041,19 @@ calculateRegisterUsage(VPlan &Plan, ArrayRef<ElementCount> VFs, | |||||
// even in the scalar case. | ||||||
RegUsage[ClassID] += 1; | ||||||
} else { | ||||||
// The output from scaled phis and scaled reductions actually has | ||||||
// fewer lanes than the VF. | ||||||
unsigned ScaleFactor = getVFScaleFactor(R); | ||||||
ElementCount VF = VFs[J].divideCoefficientBy(ScaleFactor); | ||||||
LLVM_DEBUG(if (VF != VFs[J]) { | ||||||
dbgs() << "LV(REG): Scaled down VF from " << VFs[J] << " to " << VF | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: for the tests that check the debug output, can a check for this line be added? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately the VF that has partial reductions is pruned before the register usage is calculated. I can add one as part of #132190 once this is merged, though. |
||||||
<< " for " << *R << "\n"; | ||||||
}); | ||||||
|
||||||
for (VPValue *DefV : R->definedValues()) { | ||||||
Type *ScalarTy = TypeInfo.inferScalarType(DefV); | ||||||
unsigned ClassID = TTI.getRegisterClassForType(true, ScalarTy); | ||||||
RegUsage[ClassID] += GetRegUsage(ScalarTy, VFs[J]); | ||||||
RegUsage[ClassID] += GetRegUsage(ScalarTy, VF); | ||||||
} | ||||||
} | ||||||
} | ||||||
|
@@ -9137,8 +9159,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe( | |||||
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) | ||||||
return tryToWidenMemory(Instr, Operands, Range); | ||||||
|
||||||
if (getScalingForReduction(Instr)) | ||||||
return tryToCreatePartialReduction(Instr, Operands); | ||||||
if (auto ScaleFactor = getScalingForReduction(Instr)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value()); | ||||||
|
||||||
if (!shouldWiden(Instr, Range)) | ||||||
return nullptr; | ||||||
|
@@ -9162,7 +9184,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe( | |||||
|
||||||
VPRecipeBase * | ||||||
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction, | ||||||
ArrayRef<VPValue *> Operands) { | ||||||
ArrayRef<VPValue *> Operands, | ||||||
unsigned ScaleFactor) { | ||||||
assert(Operands.size() == 2 && | ||||||
"Unexpected number of operands for partial reduction"); | ||||||
|
||||||
|
@@ -9195,7 +9218,7 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction, | |||||
BinOp = Builder.createSelect(Mask, BinOp, Zero, Reduction->getDebugLoc()); | ||||||
} | ||||||
return new VPPartialReductionRecipe(ReductionOpcode, BinOp, Accumulator, | ||||||
Reduction); | ||||||
ScaleFactor, Reduction); | ||||||
} | ||||||
|
||||||
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2033,6 +2033,9 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe, | |||||
/// Generate the phi/select nodes. | ||||||
void execute(VPTransformState &State) override; | ||||||
|
||||||
/// Get the factor that the VF of this recipe's output should be scaled by | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
unsigned getVFScaleFactor() const { return VFScaleFactor; } | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps good to have comments on both new functions added? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
|
||||||
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) | ||||||
/// Print the recipe. | ||||||
void print(raw_ostream &O, const Twine &Indent, | ||||||
|
@@ -2063,17 +2066,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe, | |||||
/// scalar value. | ||||||
class VPPartialReductionRecipe : public VPSingleDefRecipe { | ||||||
unsigned Opcode; | ||||||
unsigned VFScaleFactor; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wuld be good to add a comment here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||||||
|
||||||
public: | ||||||
VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0, | ||||||
VPValue *Op1) | ||||||
VPValue *Op1, unsigned VFScaleFactor) | ||||||
: VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1, | ||||||
ReductionInst) {} | ||||||
VFScaleFactor, ReductionInst) {} | ||||||
VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1, | ||||||
unsigned VFScaleFactor, | ||||||
Instruction *ReductionInst = nullptr) | ||||||
: VPSingleDefRecipe(VPDef::VPPartialReductionSC, | ||||||
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst), | ||||||
Opcode(Opcode) { | ||||||
Opcode(Opcode), VFScaleFactor(VFScaleFactor) { | ||||||
[[maybe_unused]] auto *AccumulatorRecipe = | ||||||
getOperand(1)->getDefiningRecipe(); | ||||||
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) || | ||||||
|
@@ -2084,7 +2089,7 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe { | |||||
|
||||||
VPPartialReductionRecipe *clone() override { | ||||||
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1), | ||||||
getUnderlyingInstr()); | ||||||
VFScaleFactor, getUnderlyingInstr()); | ||||||
} | ||||||
|
||||||
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC) | ||||||
|
@@ -2099,6 +2104,9 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe { | |||||
/// Get the binary op's opcode. | ||||||
unsigned getOpcode() const { return Opcode; } | ||||||
|
||||||
/// Get the factor that the VF of this recipe's output should be scaled by | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
unsigned getVFScaleFactor() const { return VFScaleFactor; } | ||||||
|
||||||
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) | ||||||
/// Print the recipe. | ||||||
void print(raw_ostream &O, const Twine &Indent, | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -516,8 +516,8 @@ VPInstruction *VPlanSlp::buildGraph(ArrayRef<VPValue *> Values) { | |
auto *Inst = cast<VPInstruction>(Values[0])->getUnderlyingInstr(); | ||
auto *VPI = new VPInstruction(Opcode, CombinedOperands, Inst->getDebugLoc()); | ||
|
||
LLVM_DEBUG(dbgs() << "Create VPInstruction " << *VPI << " " | ||
<< *cast<VPInstruction>(Values[0]) << "\n"); | ||
LLVM_DEBUG(dbgs() << "Create VPInstruction " << cast<VPValue>(*VPI) << " " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems unfortunate, due to VPI being both a VPValue & VPInstruction? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly, it was an ambiguous call to |
||
<< Values[0] << "\n"); | ||
addCombined(Values, VPI); | ||
return VPI; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I missing something, or can this just be:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.