@@ -916,24 +916,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
916
916
if (It == VL.end())
917
917
return InstructionsState::invalid();
918
918
919
- Value *V = *It;
919
+ Instruction *MainOp = cast<Instruction>( *It) ;
920
920
unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
921
- if ((VL.size() > 2 && !isa<PHINode>(V ) && InstCnt < VL.size() / 2) ||
921
+ if ((VL.size() > 2 && !isa<PHINode>(MainOp ) && InstCnt < VL.size() / 2) ||
922
922
(VL.size() == 2 && InstCnt < 2))
923
923
return InstructionsState::invalid();
924
924
925
- bool IsCastOp = isa<CastInst>(V);
926
- bool IsBinOp = isa<BinaryOperator>(V);
927
- bool IsCmpOp = isa<CmpInst>(V);
928
- CmpInst::Predicate BasePred =
929
- IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
930
- unsigned Opcode = cast<Instruction>(V)->getOpcode();
925
+ bool IsCastOp = isa<CastInst>(MainOp);
926
+ bool IsBinOp = isa<BinaryOperator>(MainOp);
927
+ bool IsCmpOp = isa<CmpInst>(MainOp);
928
+ CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
929
+ : CmpInst::BAD_ICMP_PREDICATE;
930
+ Instruction *AltOp = MainOp;
931
+ unsigned Opcode = MainOp->getOpcode();
931
932
unsigned AltOpcode = Opcode;
932
- unsigned AltIndex = std::distance(VL.begin(), It);
933
933
934
- bool SwappedPredsCompatible = [&]() {
935
- if (!IsCmpOp)
936
- return false;
934
+ bool SwappedPredsCompatible = IsCmpOp && [&]() {
937
935
SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
938
936
UniquePreds.insert(BasePred);
939
937
UniqueNonSwappedPreds.insert(BasePred);
@@ -956,18 +954,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
956
954
}();
957
955
// Check for one alternate opcode from another BinaryOperator.
958
956
// TODO - generalize to support all operators (types, calls etc.).
959
- auto *IBase = cast<Instruction>(V);
960
957
Intrinsic::ID BaseID = 0;
961
958
SmallVector<VFInfo> BaseMappings;
962
- if (auto *CallBase = dyn_cast<CallInst>(IBase )) {
959
+ if (auto *CallBase = dyn_cast<CallInst>(MainOp )) {
963
960
BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
964
961
BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
965
962
if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
966
963
return InstructionsState::invalid();
967
964
}
968
965
bool AnyPoison = InstCnt != VL.size();
969
- for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
970
- auto *I = dyn_cast<Instruction>(VL[Cnt]);
966
+ // Skip MainOp.
967
+ for (Value *V : iterator_range(It + 1, VL.end())) {
968
+ auto *I = dyn_cast<Instruction>(V);
971
969
if (!I)
972
970
continue;
973
971
@@ -983,11 +981,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
983
981
if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
984
982
isValidForAlternation(Opcode)) {
985
983
AltOpcode = InstOpcode;
986
- AltIndex = Cnt ;
984
+ AltOp = I ;
987
985
continue;
988
986
}
989
987
} else if (IsCastOp && isa<CastInst>(I)) {
990
- Value *Op0 = IBase ->getOperand(0);
988
+ Value *Op0 = MainOp ->getOperand(0);
991
989
Type *Ty0 = Op0->getType();
992
990
Value *Op1 = I->getOperand(0);
993
991
Type *Ty1 = Op1->getType();
@@ -999,12 +997,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
999
997
isValidForAlternation(InstOpcode) &&
1000
998
"Cast isn't safe for alternation, logic needs to be updated!");
1001
999
AltOpcode = InstOpcode;
1002
- AltIndex = Cnt ;
1000
+ AltOp = I ;
1003
1001
continue;
1004
1002
}
1005
1003
}
1006
- } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt] ); Inst && IsCmpOp) {
1007
- auto *BaseInst = cast<CmpInst>(V );
1004
+ } else if (auto *Inst = dyn_cast<CmpInst>(I ); Inst && IsCmpOp) {
1005
+ auto *BaseInst = cast<CmpInst>(MainOp );
1008
1006
Type *Ty0 = BaseInst->getOperand(0)->getType();
1009
1007
Type *Ty1 = Inst->getOperand(0)->getType();
1010
1008
if (Ty0 == Ty1) {
@@ -1018,21 +1016,21 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1018
1016
CmpInst::Predicate SwappedCurrentPred =
1019
1017
CmpInst::getSwappedPredicate(CurrentPred);
1020
1018
1021
- if ((E == 2 || SwappedPredsCompatible) &&
1019
+ if ((VL.size() == 2 || SwappedPredsCompatible) &&
1022
1020
(BasePred == CurrentPred || BasePred == SwappedCurrentPred))
1023
1021
continue;
1024
1022
1025
1023
if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
1026
1024
continue;
1027
- auto *AltInst = cast<CmpInst>(VL[AltIndex] );
1028
- if (AltIndex ) {
1025
+ auto *AltInst = cast<CmpInst>(AltOp );
1026
+ if (MainOp != AltOp ) {
1029
1027
if (isCmpSameOrSwapped(AltInst, Inst, TLI))
1030
1028
continue;
1031
1029
} else if (BasePred != CurrentPred) {
1032
1030
assert(
1033
1031
isValidForAlternation(InstOpcode) &&
1034
1032
"CmpInst isn't safe for alternation, logic needs to be updated!");
1035
- AltIndex = Cnt ;
1033
+ AltOp = I ;
1036
1034
continue;
1037
1035
}
1038
1036
CmpInst::Predicate AltPred = AltInst->getPredicate();
@@ -1046,17 +1044,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1046
1044
"CastInst.");
1047
1045
if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
1048
1046
if (Gep->getNumOperands() != 2 ||
1049
- Gep->getOperand(0)->getType() != IBase ->getOperand(0)->getType())
1047
+ Gep->getOperand(0)->getType() != MainOp ->getOperand(0)->getType())
1050
1048
return InstructionsState::invalid();
1051
1049
} else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
1052
1050
if (!isVectorLikeInstWithConstOps(EI))
1053
1051
return InstructionsState::invalid();
1054
1052
} else if (auto *LI = dyn_cast<LoadInst>(I)) {
1055
- auto *BaseLI = cast<LoadInst>(IBase );
1053
+ auto *BaseLI = cast<LoadInst>(MainOp );
1056
1054
if (!LI->isSimple() || !BaseLI->isSimple())
1057
1055
return InstructionsState::invalid();
1058
1056
} else if (auto *Call = dyn_cast<CallInst>(I)) {
1059
- auto *CallBase = cast<CallInst>(IBase );
1057
+ auto *CallBase = cast<CallInst>(MainOp );
1060
1058
if (Call->getCalledFunction() != CallBase->getCalledFunction())
1061
1059
return InstructionsState::invalid();
1062
1060
if (Call->hasOperandBundles() &&
@@ -1086,8 +1084,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
1086
1084
return InstructionsState::invalid();
1087
1085
}
1088
1086
1089
- return InstructionsState(cast<Instruction>(V),
1090
- cast<Instruction>(VL[AltIndex]));
1087
+ return InstructionsState(MainOp, AltOp);
1091
1088
}
1092
1089
1093
1090
/// \returns true if all of the values in \p VL have the same type or false
0 commit comments