27
27
#include " llvm/IR/Dominators.h"
28
28
#include " llvm/IR/Function.h"
29
29
#include " llvm/IR/IRBuilder.h"
30
+ #include " llvm/IR/IntrinsicInst.h"
30
31
#include " llvm/IR/PatternMatch.h"
31
32
#include " llvm/Support/CommandLine.h"
32
33
#include " llvm/Transforms/Utils/Local.h"
@@ -1021,24 +1022,17 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1021
1022
// / inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
1022
1023
// / by insertelement.
1023
1024
bool VectorCombine::scalarizeBinopOrCmp (Instruction &I) {
1024
- CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
1025
- Value *Ins0, *Ins1;
1026
- if (!match (&I, m_BinOp (m_Value (Ins0), m_Value (Ins1))) &&
1027
- !match (&I, m_Cmp (Pred, m_Value (Ins0), m_Value (Ins1)))) {
1028
- // TODO: Allow unary and ternary intrinsics
1029
- // TODO: Allow intrinsics with different argument types
1030
- // TODO: Allow intrinsics with scalar arguments
1031
- if (auto *II = dyn_cast<IntrinsicInst>(&I);
1032
- II && II->arg_size () == 2 &&
1033
- isTriviallyVectorizable (II->getIntrinsicID ()) &&
1034
- all_of (II->args (),
1035
- [&II](Value *Arg) { return Arg->getType () == II->getType (); })) {
1036
- Ins0 = II->getArgOperand (0 );
1037
- Ins1 = II->getArgOperand (1 );
1038
- } else {
1025
+ // TODO: Allow unary operators
1026
+ if (!isa<BinaryOperator, CmpInst, IntrinsicInst>(I))
1027
+ return false ;
1028
+
1029
+ // TODO: Allow intrinsics with different argument types
1030
+ // TODO: Allow intrinsics with scalar arguments
1031
+ if (auto *II = dyn_cast<IntrinsicInst>(&I))
1032
+ if (!isTriviallyVectorizable (II->getIntrinsicID ()) ||
1033
+ !all_of (II->args (),
1034
+ [&II](Value *Arg) { return Arg->getType () == II->getType (); }))
1039
1035
return false ;
1040
- }
1041
- }
1042
1036
1043
1037
// Do not convert the vector condition of a vector select into a scalar
1044
1038
// condition. That may cause problems for codegen because of differences in
@@ -1055,36 +1049,43 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
1055
1049
// vec_op (inselt VecC0, V0, Index), VecC1
1056
1050
// vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
1057
1051
// TODO: Deal with mismatched index constants and variable indexes?
1058
- Constant *VecC0 = nullptr , *VecC1 = nullptr ;
1059
- Value *V0 = nullptr , *V1 = nullptr ;
1060
- uint64_t Index0 = 0 , Index1 = 0 ;
1061
- if (!match (Ins0, m_InsertElt (m_Constant (VecC0), m_Value (V0),
1062
- m_ConstantInt (Index0))) &&
1063
- !match (Ins0, m_Constant (VecC0)))
1064
- return false ;
1065
- if (!match (Ins1, m_InsertElt (m_Constant (VecC1), m_Value (V1),
1066
- m_ConstantInt (Index1))) &&
1067
- !match (Ins1, m_Constant (VecC1)))
1068
- return false ;
1069
-
1070
- bool IsConst0 = !V0;
1071
- bool IsConst1 = !V1;
1072
- if (IsConst0 && IsConst1)
1073
- return false ;
1074
- if (!IsConst0 && !IsConst1 && Index0 != Index1)
1075
- return false ;
1052
+ SmallVector<Value *> VecCs, ScalarOps;
1053
+ std::optional<uint64_t > Index;
1054
+
1055
+ auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args ()
1056
+ : I.operand_values ();
1057
+ for (Value *Op : Ops) {
1058
+ Constant *VecC;
1059
+ Value *V;
1060
+ uint64_t InsIdx = 0 ;
1061
+ VectorType *OpTy = cast<VectorType>(Op->getType ());
1062
+ if (match (Op, m_InsertElt (m_Constant (VecC), m_Value (V),
1063
+ m_ConstantInt (InsIdx)))) {
1064
+ // Bail if any inserts are out of bounds.
1065
+ if (OpTy->getElementCount ().getKnownMinValue () <= InsIdx)
1066
+ return false ;
1067
+ // All inserts must have the same index.
1068
+ if (!Index)
1069
+ Index = InsIdx;
1070
+ else if (InsIdx != *Index)
1071
+ return false ;
1072
+ VecCs.push_back (VecC);
1073
+ ScalarOps.push_back (V);
1074
+ } else if (match (Op, m_Constant (VecC))) {
1075
+ VecCs.push_back (VecC);
1076
+ ScalarOps.push_back (nullptr );
1077
+ } else {
1078
+ return false ;
1079
+ }
1080
+ }
1076
1081
1077
- auto *VecTy0 = cast<VectorType>(Ins0->getType ());
1078
- auto *VecTy1 = cast<VectorType>(Ins1->getType ());
1079
- if (VecTy0->getElementCount ().getKnownMinValue () <= Index0 ||
1080
- VecTy1->getElementCount ().getKnownMinValue () <= Index1)
1082
+ // Bail if all operands are constant.
1083
+ if (!Index.has_value ())
1081
1084
return false ;
1082
1085
1083
- uint64_t Index = IsConst0 ? Index1 : Index0;
1084
- Type *ScalarTy = IsConst0 ? V1->getType () : V0->getType ();
1085
- Type *VecTy = I.getType ();
1086
+ VectorType *VecTy = cast<VectorType>(I.getType ());
1087
+ Type *ScalarTy = VecTy->getScalarType ();
1086
1088
assert (VecTy->isVectorTy () &&
1087
- (IsConst0 || IsConst1 || V0->getType () == V1->getType ()) &&
1088
1089
(ScalarTy->isIntegerTy () || ScalarTy->isFloatingPointTy () ||
1089
1090
ScalarTy->isPointerTy ()) &&
1090
1091
" Unexpected types for insert element into binop or cmp" );
@@ -1114,17 +1115,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
1114
1115
1115
1116
// Get cost estimate for the insert element. This cost will factor into
1116
1117
// both sequences.
1117
- InstructionCost InsertCost = TTI.getVectorInstrCost (
1118
- Instruction::InsertElement, VecTy, CostKind, Index);
1119
- InstructionCost InsertCostV0 = TTI.getVectorInstrCost (
1120
- Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
1121
- InstructionCost InsertCostV1 = TTI.getVectorInstrCost (
1122
- Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
1123
- InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
1124
- (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
1125
- InstructionCost NewCost = ScalarOpCost + InsertCost +
1126
- (IsConst0 ? 0 : !Ins0->hasOneUse () * InsertCostV0) +
1127
- (IsConst1 ? 0 : !Ins1->hasOneUse () * InsertCostV1);
1118
+ InstructionCost OldCost = VectorOpCost;
1119
+ InstructionCost NewCost =
1120
+ ScalarOpCost + TTI.getVectorInstrCost (Instruction::InsertElement, VecTy,
1121
+ CostKind, *Index);
1122
+ for (auto [Op, VecC, Scalar] : zip (Ops, VecCs, ScalarOps)) {
1123
+ if (!Scalar)
1124
+ continue ;
1125
+ InstructionCost InsertCost = TTI.getVectorInstrCost (
1126
+ Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1127
+ OldCost += InsertCost;
1128
+ NewCost += !Op->hasOneUse () * InsertCost;
1129
+ }
1128
1130
1129
1131
// We want to scalarize unless the vector variant actually has lower cost.
1130
1132
if (OldCost < NewCost || !NewCost.isValid ())
@@ -1140,19 +1142,20 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
1140
1142
++NumScalarIntrinsic;
1141
1143
1142
1144
// For constant cases, extract the scalar element, this should constant fold.
1143
- if (IsConst0 )
1144
- V0 = ConstantExpr::getExtractElement (VecC0, Builder. getInt64 (Index));
1145
- if (IsConst1)
1146
- V1 = ConstantExpr::getExtractElement (VecC1 , Builder.getInt64 (Index));
1145
+ for ( auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs) )
1146
+ if (!Scalar)
1147
+ ScalarOps[OpIdx] = ConstantExpr::getExtractElement (
1148
+ cast<Constant>(VecC) , Builder.getInt64 (* Index));
1147
1149
1148
1150
Value *Scalar;
1149
- if (isa <CmpInst>(I))
1150
- Scalar = Builder.CreateCmp (Pred, V0, V1 );
1151
+ if (auto *CI = dyn_cast <CmpInst>(& I))
1152
+ Scalar = Builder.CreateCmp (CI-> getPredicate (), ScalarOps[ 0 ], ScalarOps[ 1 ] );
1151
1153
else if (isa<BinaryOperator>(I))
1152
- Scalar = Builder.CreateBinOp ((Instruction::BinaryOps)Opcode, V0, V1);
1154
+ Scalar = Builder.CreateBinOp ((Instruction::BinaryOps)Opcode, ScalarOps[0 ],
1155
+ ScalarOps[1 ]);
1153
1156
else
1154
1157
Scalar = Builder.CreateIntrinsic (
1155
- ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID (), {V0, V1} );
1158
+ ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID (), ScalarOps );
1156
1159
1157
1160
Scalar->setName (I.getName () + " .scalar" );
1158
1161
@@ -1163,14 +1166,14 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
1163
1166
1164
1167
// Fold the vector constants in the original vectors into a new base vector.
1165
1168
Value *NewVecC;
1166
- if (isa <CmpInst>(I))
1167
- NewVecC = Builder.CreateCmp (Pred, VecC0, VecC1 );
1169
+ if (auto *CI = dyn_cast <CmpInst>(& I))
1170
+ NewVecC = Builder.CreateCmp (CI-> getPredicate (), VecCs[ 0 ], VecCs[ 1 ] );
1168
1171
else if (isa<BinaryOperator>(I))
1169
- NewVecC = Builder.CreateBinOp ((Instruction::BinaryOps) Opcode, VecC0, VecC1 );
1172
+ NewVecC = Builder.CreateNAryOp ( Opcode, VecCs );
1170
1173
else
1171
1174
NewVecC = Builder.CreateIntrinsic (
1172
- VecTy, cast<IntrinsicInst>(I).getIntrinsicID (), {VecC0, VecC1} );
1173
- Value *Insert = Builder.CreateInsertElement (NewVecC, Scalar, Index);
1175
+ VecTy, cast<IntrinsicInst>(I).getIntrinsicID (), VecCs );
1176
+ Value *Insert = Builder.CreateInsertElement (NewVecC, Scalar, * Index);
1174
1177
replaceValue (I, *Insert);
1175
1178
return true ;
1176
1179
}
0 commit comments