Skip to content

Commit 26a9f3f

Browse files
committed
[SLP][NFC]Cleanup getSameOpcode, return InstructionsState::invalid() for non-valid inputs
Just a cleanup and related changes
1 parent 8a7a7b5 commit 26a9f3f

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,7 @@ struct InstructionsState {
832832
InstructionsState() = delete;
833833
InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp)
834834
: OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
835+
static InstructionsState invalid() { return {nullptr, nullptr, nullptr}; }
835836
};
836837

837838
} // end anonymous namespace
@@ -891,20 +892,19 @@ static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI,
891892
/// could be vectorized even if its structure is diverse.
892893
static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
893894
const TargetLibraryInfo &TLI) {
894-
constexpr unsigned BaseIndex = 0;
895895
// Make sure these are all Instructions.
896-
if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); }))
897-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
896+
if (!all_of(VL, IsaPred<Instruction>))
897+
return InstructionsState::invalid();
898898

899-
bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
900-
bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
901-
bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
899+
Value *V = VL.front();
900+
bool IsCastOp = isa<CastInst>(V);
901+
bool IsBinOp = isa<BinaryOperator>(V);
902+
bool IsCmpOp = isa<CmpInst>(V);
902903
CmpInst::Predicate BasePred =
903-
IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
904-
: CmpInst::BAD_ICMP_PREDICATE;
905-
unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
904+
IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
905+
unsigned Opcode = cast<Instruction>(V)->getOpcode();
906906
unsigned AltOpcode = Opcode;
907-
unsigned AltIndex = BaseIndex;
907+
unsigned AltIndex = 0;
908908

909909
bool SwappedPredsCompatible = [&]() {
910910
if (!IsCmpOp)
@@ -931,14 +931,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
931931
}();
932932
// Check for one alternate opcode from another BinaryOperator.
933933
// TODO - generalize to support all operators (types, calls etc.).
934-
auto *IBase = cast<Instruction>(VL[BaseIndex]);
934+
auto *IBase = cast<Instruction>(V);
935935
Intrinsic::ID BaseID = 0;
936936
SmallVector<VFInfo> BaseMappings;
937937
if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
938938
BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
939939
BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
940940
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
941-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
941+
return InstructionsState::invalid();
942942
}
943943
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
944944
auto *I = cast<Instruction>(VL[Cnt]);
@@ -970,7 +970,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
970970
}
971971
}
972972
} else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
973-
auto *BaseInst = cast<CmpInst>(VL[BaseIndex]);
973+
auto *BaseInst = cast<CmpInst>(V);
974974
Type *Ty0 = BaseInst->getOperand(0)->getType();
975975
Type *Ty1 = Inst->getOperand(0)->getType();
976976
if (Ty0 == Ty1) {
@@ -988,7 +988,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
988988
if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
989989
continue;
990990
auto *AltInst = cast<CmpInst>(VL[AltIndex]);
991-
if (AltIndex != BaseIndex) {
991+
if (AltIndex) {
992992
if (isCmpSameOrSwapped(AltInst, Inst, TLI))
993993
continue;
994994
} else if (BasePred != CurrentPred) {
@@ -1007,27 +1007,28 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
10071007
if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
10081008
if (Gep->getNumOperands() != 2 ||
10091009
Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
1010-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1010+
return InstructionsState::invalid();
10111011
} else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
10121012
if (!isVectorLikeInstWithConstOps(EI))
1013-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1013+
return InstructionsState::invalid();
10141014
} else if (auto *LI = dyn_cast<LoadInst>(I)) {
10151015
auto *BaseLI = cast<LoadInst>(IBase);
10161016
if (!LI->isSimple() || !BaseLI->isSimple())
1017-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1017+
return InstructionsState::invalid();
10181018
} else if (auto *Call = dyn_cast<CallInst>(I)) {
10191019
auto *CallBase = cast<CallInst>(IBase);
10201020
if (Call->getCalledFunction() != CallBase->getCalledFunction())
1021-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1022-
if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() ||
1023-
!std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1024-
Call->op_begin() + Call->getBundleOperandsEndIndex(),
1025-
CallBase->op_begin() +
1026-
CallBase->getBundleOperandsStartIndex())))
1027-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1021+
return InstructionsState::invalid();
1022+
if (Call->hasOperandBundles() &&
1023+
(!CallBase->hasOperandBundles() ||
1024+
!std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1025+
Call->op_begin() + Call->getBundleOperandsEndIndex(),
1026+
CallBase->op_begin() +
1027+
CallBase->getBundleOperandsStartIndex())))
1028+
return InstructionsState::invalid();
10281029
Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI);
10291030
if (ID != BaseID)
1030-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1031+
return InstructionsState::invalid();
10311032
if (!ID) {
10321033
SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
10331034
if (Mappings.size() != BaseMappings.size() ||
@@ -1037,15 +1038,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
10371038
Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
10381039
Mappings.front().Shape.Parameters !=
10391040
BaseMappings.front().Shape.Parameters)
1040-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1041+
return InstructionsState::invalid();
10411042
}
10421043
}
10431044
continue;
10441045
}
1045-
return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1046+
return InstructionsState::invalid();
10461047
}
10471048

1048-
return InstructionsState(VL[BaseIndex], cast<Instruction>(VL[BaseIndex]),
1049+
return InstructionsState(V, cast<Instruction>(V),
10491050
cast<Instruction>(VL[AltIndex]));
10501051
}
10511052

@@ -8019,7 +8020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
80198020
}
80208021

80218022
// Don't handle vectors.
8022-
if (!SLPReVec && getValueType(S.OpValue)->isVectorTy()) {
8023+
if (!SLPReVec && getValueType(VL.front())->isVectorTy()) {
80238024
LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
80248025
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
80258026
return;
@@ -8088,7 +8089,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
80888089
UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
80898090
bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL);
80908091
bool AreScatterAllGEPSameBlock =
8091-
(IsScatterVectorizeUserTE && S.OpValue->getType()->isPointerTy() &&
8092+
(IsScatterVectorizeUserTE && VL.front()->getType()->isPointerTy() &&
80928093
VL.size() > 2 &&
80938094
all_of(VL,
80948095
[&BB](Value *V) {
@@ -8104,7 +8105,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
81048105
SortedIndices));
81058106
bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock;
81068107
if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) ||
8107-
(isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(
8108+
(isa_and_present<InsertElementInst, ExtractValueInst, ExtractElementInst>(
81088109
S.OpValue) &&
81098110
!all_of(VL, isVectorLikeInstWithConstOps)) ||
81108111
NotProfitableForVectorization(VL)) {
@@ -8161,7 +8162,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
81618162
// Special processing for sorted pointers for ScatterVectorize node with
81628163
// constant indeces only.
81638164
if (!AreAllSameBlock && AreScatterAllGEPSameBlock) {
8164-
assert(S.OpValue->getType()->isPointerTy() &&
8165+
assert(VL.front()->getType()->isPointerTy() &&
81658166
count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
81668167
"Expected pointers only.");
81678168
// Reset S to make it GetElementPtr kind of node.

0 commit comments

Comments
 (0)