Commit 9dc6551

[SLP][NFC] Extract a check for a SplitVectorize node, NFC
Reviewers: RKSimon, hiraditya
Reviewed By: RKSimon
Pull Request: #134896
Parent: 0c2a6f2
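Summary (as read from the diff below): this is a non-functional refactor. The legality/profitability check that previously lived inline in the TrySplitNode lambda inside BoUpSLP::buildTree_rec is extracted into a new const member function, BoUpSLP::canBuildSplitNode, which fills in the two homogeneous operand groups Op1/Op2 and the reorder indices for its caller. Call sites drop the now-redundant SmallNodeSize argument (the constant moves into the helper), and the only textual change to the moved body is that the locally captured TTI reference becomes the member TTI pointer.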

1 file changed (+130, -109 lines)

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 130 additions & 109 deletions
@@ -3125,6 +3125,18 @@ class BoUpSLP {
                        ArrayRef<Value *> VectorizedVals,
                        SmallPtrSetImpl<Value *> &CheckedExtracts);
 
+  /// Checks if it is legal and profitable to build SplitVectorize node for the
+  /// given \p VL.
+  /// \param Op1 first homogeneous scalars.
+  /// \param Op2 second homogeneous scalars.
+  /// \param ReorderIndices indices to reorder the scalars.
+  /// \returns true if the node was successfully built.
+  bool canBuildSplitNode(ArrayRef<Value *> VL,
+                         const InstructionsState &LocalState,
+                         SmallVectorImpl<Value *> &Op1,
+                         SmallVectorImpl<Value *> &Op2,
+                         OrdersType &ReorderIndices) const;
+
   /// This is the recursive part of buildTree.
   void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth,
                      const EdgeInfo &EI, unsigned InterleaveFactor = 0);
@@ -9169,6 +9181,117 @@ static bool tryToFindDuplicates(SmallVectorImpl<Value *> &VL,
   return true;
 }
 
+bool BoUpSLP::canBuildSplitNode(ArrayRef<Value *> VL,
+                                const InstructionsState &LocalState,
+                                SmallVectorImpl<Value *> &Op1,
+                                SmallVectorImpl<Value *> &Op2,
+                                OrdersType &ReorderIndices) const {
+  constexpr unsigned SmallNodeSize = 4;
+  if (VL.size() <= SmallNodeSize || TTI->preferAlternateOpcodeVectorization() ||
+      !SplitAlternateInstructions)
+    return false;
+
+  ReorderIndices.assign(VL.size(), VL.size());
+  SmallBitVector Op1Indices(VL.size());
+  for (auto [Idx, V] : enumerate(VL)) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I) {
+      Op1.push_back(V);
+      Op1Indices.set(Idx);
+      continue;
+    }
+    if ((LocalState.getAltOpcode() != LocalState.getOpcode() &&
+         I->getOpcode() == LocalState.getOpcode()) ||
+        (LocalState.getAltOpcode() == LocalState.getOpcode() &&
+         !isAlternateInstruction(I, LocalState.getMainOp(),
+                                 LocalState.getAltOp(), *TLI))) {
+      Op1.push_back(V);
+      Op1Indices.set(Idx);
+      continue;
+    }
+    Op2.push_back(V);
+  }
+  Type *ScalarTy = getValueType(VL.front());
+  VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
+  unsigned Opcode0 = LocalState.getOpcode();
+  unsigned Opcode1 = LocalState.getAltOpcode();
+  SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
+  // Enable split node, only if all nodes do not form legal alternate
+  // instruction (like X86 addsub).
+  SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
+  SmallPtrSet<Value *, 4> UOp2(llvm::from_range, Op2);
+  if (UOp1.size() <= 1 || UOp2.size() <= 1 ||
+      TTI->isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask) ||
+      !hasFullVectorsOrPowerOf2(*TTI, Op1.front()->getType(), Op1.size()) ||
+      !hasFullVectorsOrPowerOf2(*TTI, Op2.front()->getType(), Op2.size()))
+    return false;
+  // Enable split node, only if all nodes are power-of-2/full registers.
+  unsigned Op1Cnt = 0, Op2Cnt = Op1.size();
+  for (unsigned Idx : seq<unsigned>(VL.size())) {
+    if (Op1Indices.test(Idx)) {
+      ReorderIndices[Op1Cnt] = Idx;
+      ++Op1Cnt;
+    } else {
+      ReorderIndices[Op2Cnt] = Idx;
+      ++Op2Cnt;
+    }
+  }
+  if (isIdentityOrder(ReorderIndices))
+    ReorderIndices.clear();
+  SmallVector<int> Mask;
+  if (!ReorderIndices.empty())
+    inversePermutation(ReorderIndices, Mask);
+  unsigned NumParts = TTI->getNumberOfParts(VecTy);
+  VectorType *Op1VecTy = getWidenedType(ScalarTy, Op1.size());
+  VectorType *Op2VecTy = getWidenedType(ScalarTy, Op2.size());
+  // Check non-profitable single register ops, which better to be represented
+  // as alternate ops.
+  if (NumParts >= VL.size())
+    return false;
+  if ((LocalState.getMainOp()->isBinaryOp() &&
+       LocalState.getAltOp()->isBinaryOp() &&
+       (LocalState.isShiftOp() || LocalState.isBitwiseLogicOp() ||
+        LocalState.isAddSubLikeOp() || LocalState.isMulDivLikeOp())) ||
+      (LocalState.getMainOp()->isCast() && LocalState.getAltOp()->isCast()) ||
+      (LocalState.getMainOp()->isUnaryOp() &&
+       LocalState.getAltOp()->isUnaryOp())) {
+    constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
+    InstructionCost InsertCost = ::getShuffleCost(
+        *TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
+    FixedVectorType *SubVecTy =
+        getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
+    InstructionCost NewShuffleCost =
+        ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
+    if (NumParts <= 1 && (Mask.empty() || InsertCost >= NewShuffleCost))
+      return false;
+    InstructionCost OriginalVecOpsCost =
+        TTI->getArithmeticInstrCost(Opcode0, VecTy, Kind) +
+        TTI->getArithmeticInstrCost(Opcode1, VecTy, Kind);
+    SmallVector<int> OriginalMask(VL.size(), PoisonMaskElem);
+    for (unsigned Idx : seq<unsigned>(VL.size())) {
+      if (isa<PoisonValue>(VL[Idx]))
+        continue;
+      OriginalMask[Idx] = Idx + (Op1Indices.test(Idx) ? 0 : VL.size());
+    }
+    InstructionCost OriginalCost =
+        OriginalVecOpsCost + ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc,
+                                              VecTy, OriginalMask, Kind);
+    InstructionCost NewVecOpsCost =
+        TTI->getArithmeticInstrCost(Opcode0, Op1VecTy, Kind) +
+        TTI->getArithmeticInstrCost(Opcode1, Op2VecTy, Kind);
+    InstructionCost NewCost =
+        NewVecOpsCost + InsertCost +
+        (!VectorizableTree.empty() && VectorizableTree.front()->hasState() &&
+                 VectorizableTree.front()->getOpcode() == Instruction::Store
+             ? NewShuffleCost
+             : 0);
+    // If not profitable to split - exit.
+    if (NewCost >= OriginalCost)
+      return false;
+  }
+  return true;
+}
+
 void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
                             const EdgeInfo &UserTreeIdx,
                             unsigned InterleaveFactor) {
@@ -9271,11 +9394,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
   }
 
   // Tries to build split node.
-  constexpr unsigned SmallNodeSize = 4;
-  auto TrySplitNode = [&, &TTI = *TTI](unsigned SmallNodeSize,
-                                       const InstructionsState &LocalState) {
-    if (VL.size() <= SmallNodeSize ||
-        TTI.preferAlternateOpcodeVectorization() || !SplitAlternateInstructions)
+  auto TrySplitNode = [&](const InstructionsState &LocalState) {
+    SmallVector<Value *> Op1, Op2;
+    OrdersType ReorderIndices;
+    if (!canBuildSplitNode(VL, LocalState, Op1, Op2, ReorderIndices))
       return false;
 
     // Any value is used in split node already - just gather.
@@ -9289,105 +9411,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       }
       return true;
     }
-    SmallVector<Value *> Op1, Op2;
-    OrdersType ReorderIndices(VL.size(), VL.size());
-    SmallBitVector Op1Indices(VL.size());
-    for (auto [Idx, V] : enumerate(VL)) {
-      auto *I = dyn_cast<Instruction>(V);
-      if (!I) {
-        Op1.push_back(V);
-        Op1Indices.set(Idx);
-        continue;
-      }
-      if ((LocalState.getAltOpcode() != LocalState.getOpcode() &&
-           I->getOpcode() == LocalState.getOpcode()) ||
-          (LocalState.getAltOpcode() == LocalState.getOpcode() &&
-           !isAlternateInstruction(I, LocalState.getMainOp(),
-                                   LocalState.getAltOp(), *TLI))) {
-        Op1.push_back(V);
-        Op1Indices.set(Idx);
-        continue;
-      }
-      Op2.push_back(V);
-    }
-    Type *ScalarTy = getValueType(VL.front());
-    VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
-    unsigned Opcode0 = LocalState.getOpcode();
-    unsigned Opcode1 = LocalState.getAltOpcode();
-    SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
-    // Enable split node, only if all nodes do not form legal alternate
-    // instruction (like X86 addsub).
-    SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
-    SmallPtrSet<Value *, 4> UOp2(llvm::from_range, Op2);
-    if (UOp1.size() <= 1 || UOp2.size() <= 1 ||
-        TTI.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask) ||
-        !hasFullVectorsOrPowerOf2(TTI, Op1.front()->getType(), Op1.size()) ||
-        !hasFullVectorsOrPowerOf2(TTI, Op2.front()->getType(), Op2.size()))
-      return false;
-    // Enable split node, only if all nodes are power-of-2/full registers.
-    unsigned Op1Cnt = 0, Op2Cnt = Op1.size();
-    for (unsigned Idx : seq<unsigned>(VL.size())) {
-      if (Op1Indices.test(Idx)) {
-        ReorderIndices[Op1Cnt] = Idx;
-        ++Op1Cnt;
-      } else {
-        ReorderIndices[Op2Cnt] = Idx;
-        ++Op2Cnt;
-      }
-    }
-    if (isIdentityOrder(ReorderIndices))
-      ReorderIndices.clear();
-    SmallVector<int> Mask;
-    if (!ReorderIndices.empty())
-      inversePermutation(ReorderIndices, Mask);
-    unsigned NumParts = TTI.getNumberOfParts(VecTy);
-    VectorType *Op1VecTy = getWidenedType(ScalarTy, Op1.size());
-    VectorType *Op2VecTy = getWidenedType(ScalarTy, Op2.size());
-    // Check non-profitable single register ops, which better to be represented
-    // as alternate ops.
-    if (NumParts >= VL.size())
-      return false;
-    if ((LocalState.getMainOp()->isBinaryOp() &&
-         LocalState.getAltOp()->isBinaryOp() &&
-         (LocalState.isShiftOp() || LocalState.isBitwiseLogicOp() ||
-          LocalState.isAddSubLikeOp() || LocalState.isMulDivLikeOp())) ||
-        (LocalState.getMainOp()->isCast() && LocalState.getAltOp()->isCast()) ||
-        (LocalState.getMainOp()->isUnaryOp() &&
-         LocalState.getAltOp()->isUnaryOp())) {
-      constexpr TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
-      InstructionCost InsertCost = ::getShuffleCost(
-          TTI, TTI::SK_InsertSubvector, VecTy, {}, Kind, Op1.size(), Op2VecTy);
-      FixedVectorType *SubVecTy =
-          getWidenedType(ScalarTy, std::max(Op1.size(), Op2.size()));
-      InstructionCost NewShuffleCost =
-          ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc, SubVecTy, Mask, Kind);
-      if (NumParts <= 1 && (Mask.empty() || InsertCost >= NewShuffleCost))
-        return false;
-      InstructionCost OriginalVecOpsCost =
-          TTI.getArithmeticInstrCost(Opcode0, VecTy, Kind) +
-          TTI.getArithmeticInstrCost(Opcode1, VecTy, Kind);
-      SmallVector<int> OriginalMask(VL.size(), PoisonMaskElem);
-      for (unsigned Idx : seq<unsigned>(VL.size())) {
-        if (isa<PoisonValue>(VL[Idx]))
-          continue;
-        OriginalMask[Idx] = Idx + (Op1Indices.test(Idx) ? 0 : VL.size());
-      }
-      InstructionCost OriginalCost =
-          OriginalVecOpsCost + ::getShuffleCost(TTI, TTI::SK_PermuteTwoSrc,
-                                                VecTy, OriginalMask, Kind);
-      InstructionCost NewVecOpsCost =
-          TTI.getArithmeticInstrCost(Opcode0, Op1VecTy, Kind) +
-          TTI.getArithmeticInstrCost(Opcode1, Op2VecTy, Kind);
-      InstructionCost NewCost =
-          NewVecOpsCost + InsertCost +
-          (!VectorizableTree.empty() && VectorizableTree.front()->hasState() &&
-                   VectorizableTree.front()->getOpcode() == Instruction::Store
-               ? NewShuffleCost
-               : 0);
-      // If not profitable to split - exit.
-      if (NewCost >= OriginalCost)
-        return false;
-    }
 
     SmallVector<Value *> NewVL(VL.size());
     copy(Op1, NewVL.begin());
@@ -9503,8 +9526,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     if (!S) {
       auto [MainOp, AltOp] = getMainAltOpsNoStateVL(VL);
       // Last chance to try to vectorize alternate node.
-      if (MainOp && AltOp &&
-          TrySplitNode(SmallNodeSize, InstructionsState(MainOp, AltOp)))
+      if (MainOp && AltOp && TrySplitNode(InstructionsState(MainOp, AltOp)))
        return;
    }
    LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O, small shuffle. \n");
@@ -9628,7 +9650,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
   }
 
   // FIXME: investigate if there are profitable cases for VL.size() <= 4.
-  if (S.isAltShuffle() && TrySplitNode(SmallNodeSize, S))
+  if (S.isAltShuffle() && TrySplitNode(S))
     return;
 
   // Check that every instruction appears once in this bundle.
@@ -9663,8 +9685,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
   if (!BundlePtr || (*BundlePtr && !*BundlePtr.value())) {
     LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n");
     // Last chance to try to vectorize alternate node.
-    if (S.isAltShuffle() && ReuseShuffleIndices.empty() &&
-        TrySplitNode(SmallNodeSize, S))
+    if (S.isAltShuffle() && ReuseShuffleIndices.empty() && TrySplitNode(S))
      return;
    auto Invalid = ScheduleBundle::invalid();
    newTreeEntry(VL, Invalid /*not vectorized*/, S, UserTreeIdx,
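For context, here is a minimal standalone sketch (plain C++, not the LLVM sources) of the two decisions canBuildSplitNode makes: partitioning a mixed-opcode bundle into two homogeneous groups plus the reorder indices that map the split layout back to the original lanes, and a toy profitability comparison. The opcode strings and unit costs are illustrative assumptions; the real code classifies instructions via InstructionsState and queries TargetTransformInfo for costs.

#include <cstddef>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

int main() {
  // A bundle of scalar opcodes as buildTree_rec might see it: the main
  // opcode "add" interleaved with the alternate opcode "sub".
  std::vector<std::string> Bundle = {"add", "sub", "add", "sub",
                                     "add", "add", "sub", "sub"};
  const std::string MainOpcode = "add";

  // Partition: lanes matching the main opcode go to Op1, the rest to Op2
  // (mirrors the enumerate(VL) loop in the patch).
  std::vector<std::size_t> Op1Lanes, Op2Lanes;
  for (std::size_t Idx = 0; Idx < Bundle.size(); ++Idx)
    (Bundle[Idx] == MainOpcode ? Op1Lanes : Op2Lanes).push_back(Idx);

  // ReorderIndices[NewPos] = OldLane: all Op1 lanes first, then Op2 lanes.
  std::vector<std::size_t> ReorderIndices;
  ReorderIndices.insert(ReorderIndices.end(), Op1Lanes.begin(), Op1Lanes.end());
  ReorderIndices.insert(ReorderIndices.end(), Op2Lanes.begin(), Op2Lanes.end());

  // An identity order means the bundle already sits in split-friendly form
  // and no shuffle is needed (the patch clears ReorderIndices in that case).
  std::vector<std::size_t> Identity(Bundle.size());
  std::iota(Identity.begin(), Identity.end(), std::size_t{0});
  const bool NeedsShuffle = (ReorderIndices != Identity);

  std::cout << "Op1 lanes:";
  for (std::size_t L : Op1Lanes)
    std::cout << ' ' << L;
  std::cout << "\nOp2 lanes:";
  for (std::size_t L : Op2Lanes)
    std::cout << ' ' << L;
  std::cout << "\nNeeds shuffle: " << std::boolalpha << NeedsShuffle << '\n';

  // Toy profitability check with made-up unit costs. Original form: one wide
  // add + one wide sub over the full bundle, blended by a two-source permute.
  // Split form: one narrow add + one narrow sub + one subvector insert.
  const int WideOp = 2, NarrowOp = 1, Permute = 3, Insert = 1;
  const int OriginalCost = WideOp + WideOp + Permute; // = 7
  const int NewCost = NarrowOp + NarrowOp + Insert;   // = 3
  std::cout << "Split profitable: " << (NewCost < OriginalCost) << '\n';
  return 0;
}

Running the sketch on the bundle above prints Op1 lanes 0 2 4 5 and Op2 lanes 1 3 6 7, so a shuffle is required, and the toy split cost (3) beats the blended cost (7); with real TTI numbers the comparison can of course go either way, which is exactly what the NewCost >= OriginalCost bail-out guards.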