Skip to content

Commit e15a937

Browse files
committed
[LoopVectorizer] Bundle partial reductions with different extensions
This PR adds support for extensions of different signedness to VPMulAccumulateReductionRecipe and allows such partial reductions to be bundled into that class.
1 parent f47a24c commit e15a937

File tree

5 files changed

+173
-158
lines changed

5 files changed

+173
-158
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,11 +2689,13 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
26892689
/// and needs to be lowered to concrete recipes before codegen. The operands are
26902690
/// {ChainOp, VecOp1, VecOp2, [Condition]}.
26912691
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2692-
/// Opcode of the extend for VecOp1 and VecOp2.
2693-
Instruction::CastOps ExtOp;
2692+
/// Opcodes of the extend recipes.
2693+
Instruction::CastOps ExtOp0;
2694+
Instruction::CastOps ExtOp1;
26942695

2695-
/// Non-neg flag of the extend recipe.
2696-
bool IsNonNeg = false;
2696+
/// Non-neg flags of the extend recipe.
2697+
bool IsNonNeg0 = false;
2698+
bool IsNonNeg1 = false;
26972699

26982700
/// The scalar type after extending.
26992701
Type *ResultTy = nullptr;
@@ -2710,7 +2712,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27102712
MulAcc->getCondOp(), MulAcc->isOrdered(),
27112713
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
27122714
MulAcc->getDebugLoc()),
2713-
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2715+
ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
2716+
IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
27142717
ResultTy(MulAcc->getResultType()),
27152718
VFScaleFactor(MulAcc->getVFScaleFactor()) {
27162719
transferFlags(*MulAcc);
@@ -2728,19 +2731,23 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27282731
R->getCondOp(), R->isOrdered(),
27292732
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27302733
R->getDebugLoc()),
2731-
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy),
2734+
ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
2735+
IsNonNeg0(Ext0->hasNonNegFlag() && Ext0->isNonNeg()), IsNonNeg1(Ext1->hasNonNegFlag() && Ext1->isNonNeg()),
2736+
ResultTy(ResultTy),
27322737
VFScaleFactor(ScaleFactor) {
27332738
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27342739
Instruction::Add &&
27352740
"The reduction instruction in MulAccumulateteReductionRecipe must "
27362741
"be Add");
2737-
assert((ExtOp == Instruction::CastOps::ZExt ||
2738-
ExtOp == Instruction::CastOps::SExt) &&
2742+
assert(((ExtOp0 == Instruction::CastOps::ZExt ||
2743+
ExtOp0 == Instruction::CastOps::SExt) && (ExtOp1 == Instruction::CastOps::ZExt || ExtOp1 == Instruction::CastOps::SExt)) &&
27392744
"VPMulAccumulateReductionRecipe only supports zext and sext.");
27402745
setUnderlyingValue(R->getUnderlyingValue());
27412746
// Only set the non-negative flag if the original recipe contains.
27422747
if (Ext0->hasNonNegFlag())
2743-
IsNonNeg = Ext0->isNonNeg();
2748+
IsNonNeg0 = Ext0->isNonNeg();
2749+
if (Ext1->hasNonNegFlag())
2750+
IsNonNeg1 = Ext1->isNonNeg();
27442751
}
27452752

27462753
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2751,7 +2758,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27512758
R->getCondOp(), R->isOrdered(),
27522759
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27532760
R->getDebugLoc()),
2754-
ExtOp(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) {
2761+
ExtOp0(Instruction::CastOps::CastOpsEnd),
2762+
ExtOp1(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) {
27552763
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27562764
Instruction::Add &&
27572765
"The reduction instruction in MulAccumulateReductionRecipe must be "
@@ -2792,16 +2800,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27922800
VPValue *getVecOp1() const { return getOperand(2); }
27932801

27942802
/// Return true if this recipe contains extended operands.
2795-
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2803+
bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
2804+
2805+
/// Return if the operands of mul instruction come from same extend.
2806+
bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
27962807

27972808
/// Return the opcode of the extends for the operands.
2798-
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2809+
Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
2810+
Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
2811+
2812+
/// Return if the first extend's opcode is ZExt.
2813+
bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
2814+
2815+
/// Return if the second extend's opcode is ZExt.
2816+
bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
27992817

2800-
/// Return if the operands are zero-extended.
2801-
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2818+
/// Return true if the first operand extend has the non-negative flag.
2819+
bool isNonNeg0() const { return IsNonNeg0; }
28022820

2803-
/// Return true if the operand extends have the non-negative flag.
2804-
bool isNonNeg() const { return IsNonNeg; }
2821+
/// Return true if the second operand extend has the non-negative flag.
2822+
bool isNonNeg1() const { return IsNonNeg1; }
28052823

28062824
/// Return the scaling factor that the VF is divided by to form the recipe's
28072825
/// output

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,14 +2576,14 @@ VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
25762576
return Ctx.TTI.getPartialReductionCost(
25772577
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
25782578
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
2579-
TTI::getPartialReductionExtendKind(getExtOpcode()),
2580-
TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
2579+
TTI::getPartialReductionExtendKind(getExt0Opcode()),
2580+
TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
25812581
}
25822582

25832583
Type *RedTy = Ctx.Types.inferScalarType(this);
25842584
auto *SrcVecTy =
25852585
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2586-
return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
2586+
return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
25872587
Ctx.CostKind);
25882588
}
25892589

@@ -2669,15 +2669,16 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26692669
if (isExtended())
26702670
O << "(";
26712671
getVecOp0()->printAsOperand(O, SlotTracker);
2672-
if (isExtended())
2673-
O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType()
2672+
if (isExtended()) {
2673+
O << " " << Instruction::getOpcodeName(ExtOp0) << " to " << *getResultType()
26742674
<< "), (";
2675-
else
2675+
} else
26762676
O << ", ";
26772677
getVecOp1()->printAsOperand(O, SlotTracker);
2678-
if (isExtended())
2679-
O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType()
2678+
if (isExtended()) {
2679+
O << " " << Instruction::getOpcodeName(ExtOp1) << " to " << *getResultType()
26802680
<< ")";
2681+
}
26812682
if (isConditional()) {
26822683
O << ", ";
26832684
getCondOp()->printAsOperand(O, SlotTracker);

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,27 +2546,28 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
25462546
VPValue *Op0, *Op1;
25472547
if (MulAcc->isExtended()) {
25482548
Type *RedTy = MulAcc->getResultType();
2549-
if (MulAcc->isZExt())
2550-
Op0 = new VPWidenCastRecipe(
2551-
MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
2552-
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()), MulAcc->getDebugLoc());
2549+
if (MulAcc->isZExt0())
2550+
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
2551+
RedTy, VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg0()),
2552+
MulAcc->getDebugLoc());
25532553
else
2554-
Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
2554+
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
25552555
RedTy, {}, MulAcc->getDebugLoc());
25562556
Op0->getDefiningRecipe()->insertBefore(MulAcc);
25572557
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
25582558
// VPWidenCastRecipe.
25592559
if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
25602560
Op1 = Op0;
25612561
} else {
2562-
if (MulAcc->isZExt())
2563-
Op1 = new VPWidenCastRecipe(
2564-
MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
2565-
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()),
2566-
MulAcc->getDebugLoc());
2562+
if (MulAcc->isZExt1())
2563+
Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
2564+
MulAcc->getVecOp1(), RedTy,
2565+
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg1()),
2566+
MulAcc->getDebugLoc());
25672567
else
2568-
Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
2569-
RedTy, {}, MulAcc->getDebugLoc());
2568+
Op1 =
2569+
new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
2570+
RedTy, {}, MulAcc->getDebugLoc());
25702571
Op1->getDefiningRecipe()->insertBefore(MulAcc);
25712572
}
25722573
} else {
@@ -2933,10 +2934,8 @@ tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
29332934
auto *BinOpR = cast<VPWidenRecipe>(BinOp->getDefiningRecipe());
29342935
VPWidenCastRecipe *Ext0R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(0));
29352936
VPWidenCastRecipe *Ext1R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(1));
2936-
2937-
// TODO: Make work with extends of different signedness
2938-
if (Ext0R->hasMoreThanOneUniqueUser() || Ext1R->hasMoreThanOneUniqueUser() ||
2939-
Ext0R->getOpcode() != Ext1R->getOpcode())
2937+
if (!Ext0R || Ext0R->hasMoreThanOneUniqueUser() || !Ext1R ||
2938+
Ext1R->hasMoreThanOneUniqueUser())
29402939
return;
29412940

29422941
auto *AbstractR = new VPMulAccumulateReductionRecipe(

0 commit comments

Comments
 (0)