Skip to content

Commit 7202f47

Browse files
committed
[SLP] separate min/max matching from its instruction-level implementation; NFC
The motivation is to handle integer min/max reductions independently of whether they are in the current cmp+sel form or the planned intrinsic form. We assumed that min/max included a select instruction, but we can decouple that implementation detail by checking the instructions themselves rather than relying on the recurrence (reduction) type.
1 parent 5d03745 commit 7202f47

File tree

1 file changed

+33
-45
lines changed

1 file changed

+33
-45
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6488,8 +6488,7 @@ class HorizontalReduction {
64886488
// in this case.
64896489
// Do not perform analysis of remaining operands of ParentStackElem.first
64906490
// instruction, this whole instruction is an extra argument.
6491-
RecurKind ParentRdxKind = getRdxKind(ParentStackElem.first);
6492-
ParentStackElem.second = getNumberOfOperands(ParentRdxKind);
6491+
ParentStackElem.second = getNumberOfOperands(ParentStackElem.first);
64936492
} else {
64946493
// We ran into something like:
64956494
// ParentStackElem.first += ... + ExtraArg + ...
@@ -6590,7 +6589,6 @@ class HorizontalReduction {
65906589
if (match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value())))
65916590
return RecurKind::FMin;
65926591

6593-
65946592
if (auto *Select = dyn_cast<SelectInst>(I)) {
65956593
// These would also match llvm.{u,s}{min,max} intrinsic call
65966594
// if were not guarded by the SelectInst check above.
@@ -6660,64 +6658,54 @@ class HorizontalReduction {
66606658
return RecurKind::None;
66616659
}
66626660

6663-
/// Return true if this operation is a cmp+select idiom.
6664-
static bool isCmpSel(RecurKind Kind) {
6665-
return RecurrenceDescriptor::isIntMinMaxRecurrenceKind(Kind);
6666-
}
6667-
66686661
/// Get the index of the first operand.
6669-
static unsigned getFirstOperandIndex(RecurKind Kind) {
6670-
// We allow calling this before 'Kind' is set, so handle that specially.
6671-
if (Kind == RecurKind::None)
6672-
return 0;
6673-
return isCmpSel(Kind) ? 1 : 0;
6662+
static unsigned getFirstOperandIndex(Instruction *I) {
6663+
return isa<SelectInst>(I) ? 1 : 0;
66746664
}
66756665

66766666
/// Total number of operands in the reduction operation.
6677-
static unsigned getNumberOfOperands(RecurKind Kind) {
6678-
return isCmpSel(Kind) ? 3 : 2;
6667+
static unsigned getNumberOfOperands(Instruction *I) {
6668+
return isa<SelectInst>(I) ? 3 : 2;
66796669
}
66806670

66816671
/// Checks if the instruction is in basic block \p BB.
66826672
/// For a min/max reduction check that both compare and select are in \p BB.
6683-
static bool hasSameParent(RecurKind Kind, Instruction *I, BasicBlock *BB,
6684-
bool IsRedOp) {
6685-
if (IsRedOp && isCmpSel(Kind)) {
6686-
auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition());
6687-
return I->getParent() == BB && Cmp->getParent() == BB;
6673+
static bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) {
6674+
auto *Sel = dyn_cast<SelectInst>(I);
6675+
if (IsRedOp && Sel) {
6676+
auto *Cmp = cast<Instruction>(Sel->getCondition());
6677+
return Sel->getParent() == BB && Cmp->getParent() == BB;
66886678
}
66896679
return I->getParent() == BB;
66906680
}
66916681

66926682
/// Expected number of uses for reduction operations/reduced values.
6693-
static bool hasRequiredNumberOfUses(RecurKind Kind, Instruction *I,
6694-
bool IsReductionOp) {
6695-
assert(Kind != RecurKind::None && "Reduction type not set");
6683+
static bool hasRequiredNumberOfUses(bool MatchCmpSel, Instruction *I) {
66966684
// SelectInst must be used twice while the condition op must have single
66976685
// use only.
6698-
if (isCmpSel(Kind))
6699-
return I->hasNUses(2) &&
6700-
(!IsReductionOp ||
6701-
cast<SelectInst>(I)->getCondition()->hasOneUse());
6686+
if (MatchCmpSel) {
6687+
if (auto *Sel = dyn_cast<SelectInst>(I))
6688+
return Sel->hasNUses(2) && Sel->getCondition()->hasOneUse();
6689+
return I->hasNUses(2);
6690+
}
67026691

67036692
// Arithmetic reduction operation must be used once only.
67046693
return I->hasOneUse();
67056694
}
67066695

67076696
/// Initializes the list of reduction operations.
6708-
void initReductionOps(RecurKind Kind) {
6709-
if (isCmpSel(Kind))
6697+
void initReductionOps(Instruction *I) {
6698+
if (isa<SelectInst>(I))
67106699
ReductionOps.assign(2, ReductionOpsType());
67116700
else
67126701
ReductionOps.assign(1, ReductionOpsType());
67136702
}
67146703

67156704
/// Add all reduction operations for the reduction instruction \p I.
6716-
void addReductionOps(RecurKind Kind, Instruction *I) {
6717-
assert(Kind != RecurKind::None && "Expected reduction operation.");
6718-
if (isCmpSel(Kind)) {
6719-
ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
6720-
ReductionOps[1].emplace_back(I);
6705+
void addReductionOps(Instruction *I) {
6706+
if (auto *Sel = dyn_cast<SelectInst>(I)) {
6707+
ReductionOps[0].emplace_back(Sel->getCondition());
6708+
ReductionOps[1].emplace_back(Sel);
67216709
} else {
67226710
ReductionOps[0].emplace_back(I);
67236711
}
@@ -6726,12 +6714,12 @@ class HorizontalReduction {
67266714
static Value *getLHS(RecurKind Kind, Instruction *I) {
67276715
if (Kind == RecurKind::None)
67286716
return nullptr;
6729-
return I->getOperand(getFirstOperandIndex(Kind));
6717+
return I->getOperand(getFirstOperandIndex(I));
67306718
}
67316719
static Value *getRHS(RecurKind Kind, Instruction *I) {
67326720
if (Kind == RecurKind::None)
67336721
return nullptr;
6734-
return I->getOperand(getFirstOperandIndex(Kind) + 1);
6722+
return I->getOperand(getFirstOperandIndex(I) + 1);
67356723
}
67366724

67376725
public:
@@ -6783,16 +6771,16 @@ class HorizontalReduction {
67836771
// Post order traverse the reduction tree starting at B. We only handle true
67846772
// trees containing only binary operators.
67856773
SmallVector<std::pair<Instruction *, unsigned>, 32> Stack;
6786-
Stack.push_back(std::make_pair(B, getFirstOperandIndex(RdxKind)));
6787-
initReductionOps(RdxKind);
6774+
Stack.push_back(std::make_pair(B, getFirstOperandIndex(B)));
6775+
initReductionOps(B);
67886776
while (!Stack.empty()) {
67896777
Instruction *TreeN = Stack.back().first;
67906778
unsigned EdgeToVisit = Stack.back().second++;
67916779
const RecurKind TreeRdxKind = getRdxKind(TreeN);
67926780
bool IsReducedValue = TreeRdxKind != RdxKind;
67936781

67946782
// Postorder visit.
6795-
if (IsReducedValue || EdgeToVisit == getNumberOfOperands(TreeRdxKind)) {
6783+
if (IsReducedValue || EdgeToVisit == getNumberOfOperands(TreeN)) {
67966784
if (IsReducedValue)
67976785
ReducedVals.push_back(TreeN);
67986786
else {
@@ -6810,7 +6798,7 @@ class HorizontalReduction {
68106798
markExtraArg(Stack[Stack.size() - 2], TreeN);
68116799
ExtraArgs.erase(TreeN);
68126800
} else
6813-
addReductionOps(RdxKind, TreeN);
6801+
addReductionOps(TreeN);
68146802
}
68156803
// Retract.
68166804
Stack.pop_back();
@@ -6836,8 +6824,8 @@ class HorizontalReduction {
68366824
// ultimate reduction.
68376825
const bool IsRdxInst = EdgeRdxKind == RdxKind;
68386826
if (EdgeInst != Phi && EdgeInst != B &&
6839-
hasSameParent(RdxKind, EdgeInst, B->getParent(), IsRdxInst) &&
6840-
hasRequiredNumberOfUses(RdxKind, EdgeInst, IsRdxInst) &&
6827+
hasSameParent(EdgeInst, B->getParent(), IsRdxInst) &&
6828+
hasRequiredNumberOfUses(isa<SelectInst>(B), EdgeInst) &&
68416829
(!LeafOpcode || LeafOpcode == EdgeInst->getOpcode() || IsRdxInst)) {
68426830
if (IsRdxInst) {
68436831
// We need to be able to reassociate the reduction operations.
@@ -6850,7 +6838,7 @@ class HorizontalReduction {
68506838
LeafOpcode = EdgeInst->getOpcode();
68516839
}
68526840
Stack.push_back(
6853-
std::make_pair(EdgeInst, getFirstOperandIndex(EdgeRdxKind)));
6841+
std::make_pair(EdgeInst, getFirstOperandIndex(EdgeInst)));
68546842
continue;
68556843
}
68566844
// I is an extra argument for TreeN (its parent operation).
@@ -6997,7 +6985,7 @@ class HorizontalReduction {
69976985
// Emit a reduction. If the root is a select (min/max idiom), the insert
69986986
// point is the compare condition of that select.
69996987
Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
7000-
if (isCmpSel(RdxKind))
6988+
if (isa<SelectInst>(RdxRootInst))
70016989
Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
70026990
else
70036991
Builder.SetInsertPoint(RdxRootInst);
@@ -7039,7 +7027,7 @@ class HorizontalReduction {
70397027
// select, we also have to RAUW for the compare instruction feeding the
70407028
// reduction root. That's because the original compare may have extra uses
70417029
// besides the final select of the reduction.
7042-
if (isCmpSel(RdxKind)) {
7030+
if (isa<SelectInst>(ReductionRoot)) {
70437031
if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
70447032
Instruction *ScalarCmp =
70457033
getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));

0 commit comments

Comments
 (0)