@@ -832,6 +832,7 @@ struct InstructionsState {
832
832
InstructionsState() = delete;
833
833
InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp)
834
834
: OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
835
+ static InstructionsState invalid() { return {nullptr, nullptr, nullptr}; }
835
836
};
836
837
837
838
} // end anonymous namespace
@@ -891,20 +892,19 @@ static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI,
891
892
/// could be vectorized even if its structure is diverse.
892
893
static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
893
894
const TargetLibraryInfo &TLI) {
894
- constexpr unsigned BaseIndex = 0;
895
895
// Make sure these are all Instructions.
896
- if (llvm::any_of (VL, [](Value *V) { return !isa <Instruction>(V); } ))
897
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
896
+ if (!all_of (VL, IsaPred <Instruction>))
897
+ return InstructionsState::invalid( );
898
898
899
- bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
900
- bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
901
- bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
899
+ Value *V = VL.front();
900
+ bool IsCastOp = isa<CastInst>(V);
901
+ bool IsBinOp = isa<BinaryOperator>(V);
902
+ bool IsCmpOp = isa<CmpInst>(V);
902
903
CmpInst::Predicate BasePred =
903
- IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
904
- : CmpInst::BAD_ICMP_PREDICATE;
905
- unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
904
+ IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
905
+ unsigned Opcode = cast<Instruction>(V)->getOpcode();
906
906
unsigned AltOpcode = Opcode;
907
- unsigned AltIndex = BaseIndex ;
907
+ unsigned AltIndex = 0 ;
908
908
909
909
bool SwappedPredsCompatible = [&]() {
910
910
if (!IsCmpOp)
@@ -931,14 +931,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
931
931
}();
932
932
// Check for one alternate opcode from another BinaryOperator.
933
933
// TODO - generalize to support all operators (types, calls etc.).
934
- auto *IBase = cast<Instruction>(VL[BaseIndex] );
934
+ auto *IBase = cast<Instruction>(V );
935
935
Intrinsic::ID BaseID = 0;
936
936
SmallVector<VFInfo> BaseMappings;
937
937
if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
938
938
BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
939
939
BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
940
940
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
941
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
941
+ return InstructionsState::invalid( );
942
942
}
943
943
for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
944
944
auto *I = cast<Instruction>(VL[Cnt]);
@@ -970,7 +970,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
970
970
}
971
971
}
972
972
} else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
973
- auto *BaseInst = cast<CmpInst>(VL[BaseIndex] );
973
+ auto *BaseInst = cast<CmpInst>(V );
974
974
Type *Ty0 = BaseInst->getOperand(0)->getType();
975
975
Type *Ty1 = Inst->getOperand(0)->getType();
976
976
if (Ty0 == Ty1) {
@@ -988,7 +988,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
988
988
if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
989
989
continue;
990
990
auto *AltInst = cast<CmpInst>(VL[AltIndex]);
991
- if (AltIndex != BaseIndex ) {
991
+ if (AltIndex) {
992
992
if (isCmpSameOrSwapped(AltInst, Inst, TLI))
993
993
continue;
994
994
} else if (BasePred != CurrentPred) {
@@ -1007,27 +1007,28 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1007
1007
if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
1008
1008
if (Gep->getNumOperands() != 2 ||
1009
1009
Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
1010
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1010
+ return InstructionsState::invalid( );
1011
1011
} else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
1012
1012
if (!isVectorLikeInstWithConstOps(EI))
1013
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1013
+ return InstructionsState::invalid( );
1014
1014
} else if (auto *LI = dyn_cast<LoadInst>(I)) {
1015
1015
auto *BaseLI = cast<LoadInst>(IBase);
1016
1016
if (!LI->isSimple() || !BaseLI->isSimple())
1017
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1017
+ return InstructionsState::invalid( );
1018
1018
} else if (auto *Call = dyn_cast<CallInst>(I)) {
1019
1019
auto *CallBase = cast<CallInst>(IBase);
1020
1020
if (Call->getCalledFunction() != CallBase->getCalledFunction())
1021
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1022
- if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() ||
1023
- !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1024
- Call->op_begin() + Call->getBundleOperandsEndIndex(),
1025
- CallBase->op_begin() +
1026
- CallBase->getBundleOperandsStartIndex())))
1027
- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1021
+ return InstructionsState::invalid();
1022
+ if (Call->hasOperandBundles() &&
1023
+ (!CallBase->hasOperandBundles() ||
1024
+ !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1025
+ Call->op_begin() + Call->getBundleOperandsEndIndex(),
1026
+ CallBase->op_begin() +
1027
+ CallBase->getBundleOperandsStartIndex())))
1028
+ return InstructionsState::invalid();
1028
1029
Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI);
1029
1030
if (ID != BaseID)
1030
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1031
+ return InstructionsState::invalid( );
1031
1032
if (!ID) {
1032
1033
SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
1033
1034
if (Mappings.size() != BaseMappings.size() ||
@@ -1037,15 +1038,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1037
1038
Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
1038
1039
Mappings.front().Shape.Parameters !=
1039
1040
BaseMappings.front().Shape.Parameters)
1040
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1041
+ return InstructionsState::invalid( );
1041
1042
}
1042
1043
}
1043
1044
continue;
1044
1045
}
1045
- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1046
+ return InstructionsState::invalid( );
1046
1047
}
1047
1048
1048
- return InstructionsState(VL[BaseIndex] , cast<Instruction>(VL[BaseIndex] ),
1049
+ return InstructionsState(V , cast<Instruction>(V ),
1049
1050
cast<Instruction>(VL[AltIndex]));
1050
1051
}
1051
1052
@@ -8019,7 +8020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8019
8020
}
8020
8021
8021
8022
// Don't handle vectors.
8022
- if (!SLPReVec && getValueType(S.OpValue )->isVectorTy()) {
8023
+ if (!SLPReVec && getValueType(VL.front() )->isVectorTy()) {
8023
8024
LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
8024
8025
newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
8025
8026
return;
@@ -8088,7 +8089,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8088
8089
UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
8089
8090
bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL);
8090
8091
bool AreScatterAllGEPSameBlock =
8091
- (IsScatterVectorizeUserTE && S.OpValue ->getType()->isPointerTy() &&
8092
+ (IsScatterVectorizeUserTE && VL.front() ->getType()->isPointerTy() &&
8092
8093
VL.size() > 2 &&
8093
8094
all_of(VL,
8094
8095
[&BB](Value *V) {
@@ -8104,7 +8105,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8104
8105
SortedIndices));
8105
8106
bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock;
8106
8107
if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) ||
8107
- (isa <InsertElementInst, ExtractValueInst, ExtractElementInst>(
8108
+ (isa_and_present <InsertElementInst, ExtractValueInst, ExtractElementInst>(
8108
8109
S.OpValue) &&
8109
8110
!all_of(VL, isVectorLikeInstWithConstOps)) ||
8110
8111
NotProfitableForVectorization(VL)) {
@@ -8161,7 +8162,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
8161
8162
// Special processing for sorted pointers for ScatterVectorize node with
8162
8163
// constant indices only.
8163
8164
if (!AreAllSameBlock && AreScatterAllGEPSameBlock) {
8164
- assert(S.OpValue ->getType()->isPointerTy() &&
8165
+ assert(VL.front() ->getType()->isPointerTy() &&
8165
8166
count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
8166
8167
"Expected pointers only.");
8167
8168
// Reset S to make it GetElementPtr kind of node.
0 commit comments