
Commit 2698f8a

WIP support n-ary intrinsics
1 parent e6c1549 commit 2698f8a

File tree

2 files changed: +126 -67 lines


llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 70 additions & 67 deletions
@@ -27,6 +27,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -1021,24 +1022,17 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
 /// inserted scalar operand and convert to scalar binop/cmp/intrinsic followed
 /// by insertelement.
 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  Value *Ins0, *Ins1;
-  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
-      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1)))) {
-    // TODO: Allow unary and ternary intrinsics
-    // TODO: Allow intrinsics with different argument types
-    // TODO: Allow intrinsics with scalar arguments
-    if (auto *II = dyn_cast<IntrinsicInst>(&I);
-        II && II->arg_size() == 2 &&
-        isTriviallyVectorizable(II->getIntrinsicID()) &&
-        all_of(II->args(),
-               [&II](Value *Arg) { return Arg->getType() == II->getType(); })) {
-      Ins0 = II->getArgOperand(0);
-      Ins1 = II->getArgOperand(1);
-    } else {
+  // TODO: Allow unary operators
+  if (!isa<BinaryOperator, CmpInst, IntrinsicInst>(I))
+    return false;
+
+  // TODO: Allow intrinsics with different argument types
+  // TODO: Allow intrinsics with scalar arguments
+  if (auto *II = dyn_cast<IntrinsicInst>(&I))
+    if (!isTriviallyVectorizable(II->getIntrinsicID()) ||
+        !all_of(II->args(),
+                [&II](Value *Arg) { return Arg->getType() == II->getType(); }))
       return false;
-    }
-  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1055,36 +1049,43 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
   // vec_op (inselt VecC0, V0, Index), VecC1
   // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
   // TODO: Deal with mismatched index constants and variable indexes?
-  Constant *VecC0 = nullptr, *VecC1 = nullptr;
-  Value *V0 = nullptr, *V1 = nullptr;
-  uint64_t Index0 = 0, Index1 = 0;
-  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
-                               m_ConstantInt(Index0))) &&
-      !match(Ins0, m_Constant(VecC0)))
-    return false;
-  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
-                               m_ConstantInt(Index1))) &&
-      !match(Ins1, m_Constant(VecC1)))
-    return false;
-
-  bool IsConst0 = !V0;
-  bool IsConst1 = !V1;
-  if (IsConst0 && IsConst1)
-    return false;
-  if (!IsConst0 && !IsConst1 && Index0 != Index1)
-    return false;
+  SmallVector<Value *> VecCs, ScalarOps;
+  std::optional<uint64_t> Index;
+
+  auto Ops = isa<IntrinsicInst>(I) ? cast<IntrinsicInst>(I).args()
+                                   : I.operand_values();
+  for (Value *Op : Ops) {
+    Constant *VecC;
+    Value *V;
+    uint64_t InsIdx = 0;
+    VectorType *OpTy = cast<VectorType>(Op->getType());
+    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
+                              m_ConstantInt(InsIdx)))) {
+      // Bail if any inserts are out of bounds.
+      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
+        return false;
+      // All inserts must have the same index.
+      if (!Index)
+        Index = InsIdx;
+      else if (InsIdx != *Index)
+        return false;
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(V);
+    } else if (match(Op, m_Constant(VecC))) {
+      VecCs.push_back(VecC);
+      ScalarOps.push_back(nullptr);
+    } else {
+      return false;
+    }
+  }
 
-  auto *VecTy0 = cast<VectorType>(Ins0->getType());
-  auto *VecTy1 = cast<VectorType>(Ins1->getType());
-  if (VecTy0->getElementCount().getKnownMinValue() <= Index0 ||
-      VecTy1->getElementCount().getKnownMinValue() <= Index1)
+  // Bail if all operands are constant.
+  if (!Index.has_value())
     return false;
 
-  uint64_t Index = IsConst0 ? Index1 : Index0;
-  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
-  Type *VecTy = I.getType();
+  VectorType *VecTy = cast<VectorType>(I.getType());
+  Type *ScalarTy = VecTy->getScalarType();
   assert(VecTy->isVectorTy() &&
-         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
          (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
          "Unexpected types for insert element into binop or cmp");
@@ -1114,17 +1115,18 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
-  InstructionCost InsertCost = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index);
-  InstructionCost InsertCostV0 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC0, V0);
-  InstructionCost InsertCostV1 = TTI.getVectorInstrCost(
-      Instruction::InsertElement, VecTy, CostKind, Index, VecC1, V1);
-  InstructionCost OldCost = (IsConst0 ? 0 : InsertCostV0) +
-                            (IsConst1 ? 0 : InsertCostV1) + VectorOpCost;
-  InstructionCost NewCost = ScalarOpCost + InsertCost +
-                            (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCostV0) +
-                            (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCostV1);
+  InstructionCost OldCost = VectorOpCost;
+  InstructionCost NewCost =
+      ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
+                                            CostKind, *Index);
+  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
+    if (!Scalar)
+      continue;
+    InstructionCost InsertCost = TTI.getVectorInstrCost(
+        Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
+    OldCost += InsertCost;
+    NewCost += !Op->hasOneUse() * InsertCost;
+  }
 
   // We want to scalarize unless the vector variant actually has lower cost.
   if (OldCost < NewCost || !NewCost.isValid())
@@ -1140,19 +1142,20 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
     ++NumScalarIntrinsic;
 
   // For constant cases, extract the scalar element, this should constant fold.
-  if (IsConst0)
-    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
-  if (IsConst1)
-    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));
+  for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
+    if (!Scalar)
+      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
+          cast<Constant>(VecC), Builder.getInt64(*Index));
 
   Value *Scalar;
-  if (isa<CmpInst>(I))
-    Scalar = Builder.CreateCmp(Pred, V0, V1);
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
   else if (isa<BinaryOperator>(I))
-    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);
+    Scalar = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, ScalarOps[0],
+                                 ScalarOps[1]);
   else
     Scalar = Builder.CreateIntrinsic(
-        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), {V0, V1});
+        ScalarTy, cast<IntrinsicInst>(I).getIntrinsicID(), ScalarOps);
 
   Scalar->setName(I.getName() + ".scalar");
 
@@ -1163,14 +1166,14 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
 
   // Fold the vector constants in the original vectors into a new base vector.
   Value *NewVecC;
-  if (isa<CmpInst>(I))
-    NewVecC = Builder.CreateCmp(Pred, VecC0, VecC1);
+  if (auto *CI = dyn_cast<CmpInst>(&I))
+    NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
   else if (isa<BinaryOperator>(I))
-    NewVecC = Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1);
+    NewVecC = Builder.CreateNAryOp(Opcode, VecCs);
   else
     NewVecC = Builder.CreateIntrinsic(
-        VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), {VecC0, VecC1});
-  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
+        VecTy, cast<IntrinsicInst>(I).getIntrinsicID(), VecCs);
+  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
   return true;
 }
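In effect, scalarizeBinopOrCmp no longer hard-codes two operands: a trivially vectorizable intrinsic of any arity can be scalarized when every vector operand is either a constant or an insertelement of a scalar at one common index, with the cost model and constant folding applied per operand. A rough before/after sketch in LLVM IR, mirroring the fma_fixed test added below (value names are illustrative, not taken from the pass output):

  ; before: three single-lane inserts feeding a ternary intrinsic
  %x.insert = insertelement <4 x float> poison, float %x, i32 0
  %y.insert = insertelement <4 x float> poison, float %y, i32 0
  %z.insert = insertelement <4 x float> poison, float %z, i32 0
  %v = call <4 x float> @llvm.fma.v4f32(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)

  ; after: one scalar call, plus an insert into the intrinsic applied to the constant base vectors
  %v.scalar = call float @llvm.fma.f32(float %x, float %y, float %z)
  %base = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
  %v = insertelement <4 x float> %base, float %v.scalar, i64 0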

llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll

Lines changed: 56 additions & 0 deletions
@@ -96,6 +96,62 @@ define <4 x i32> @non_trivially_vectorizable(i32 %x, i32 %y) {
   ret <4 x i32> %v
 }
 
+define <4 x float> @fabs_fixed(float %x) {
+; CHECK-LABEL: define <4 x float> @fabs_fixed(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fabs.v4f32(<4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %v = call <4 x float> @llvm.fabs(<4 x float> %x.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fabs_scalable(float %x) {
+; CHECK-LABEL: define <vscale x 4 x float> @fabs_scalable(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fabs.f32(float [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fabs.nxv4f32(<vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %v = call <vscale x 4 x float> @llvm.fabs(<vscale x 4 x float> %x.insert)
+  ret <vscale x 4 x float> %v
+}
+
+define <4 x float> @fma_fixed(float %x, float %y, float %z) {
+; CHECK-LABEL: define <4 x float> @fma_fixed(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.fma.v4f32(<4 x float> poison, <4 x float> poison, <4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[V]]
+;
+  %x.insert = insertelement <4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <4 x float> poison, float %z, i32 0
+  %v = call <4 x float> @llvm.fma(<4 x float> %x.insert, <4 x float> %y.insert, <4 x float> %z.insert)
+  ret <4 x float> %v
+}
+
+define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
+; CHECK-LABEL: define <vscale x 4 x float> @fma_scalable(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]]) {
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.fma.f32(float [[X]], float [[Y]], float [[Z]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 4 x float> @llvm.fma.nxv4f32(<vscale x 4 x float> poison, <vscale x 4 x float> poison, <vscale x 4 x float> poison)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <vscale x 4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
+; CHECK-NEXT:    ret <vscale x 4 x float> [[V]]
+;
+  %x.insert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %y.insert = insertelement <vscale x 4 x float> poison, float %y, i32 0
+  %z.insert = insertelement <vscale x 4 x float> poison, float %z, i32 0
+  %v = call <vscale x 4 x float> @llvm.fma(<vscale x 4 x float> %x.insert, <vscale x 4 x float> %y.insert, <vscale x 4 x float> %z.insert)
+  ret <vscale x 4 x float> %v
+}
+
 ; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
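Side note (an assumption, since the file's RUN line is outside this hunk): the CHECK lines above follow the auto-generated style, so they would typically be refreshed with llvm/utils/update_test_checks.py after running the file through opt -passes=vector-combine -S.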
