Skip to content

Commit 4b6ed67

Browse files
Better handling of ElementCount
1 parent 285279b commit 4b6ed67

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

llvm/lib/CodeGen/ReplaceWithVeclib.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
//
99
// Replaces LLVM IR instructions with vector operands (i.e., the frem
1010
// instruction or calls to LLVM intrinsics) with matching calls to functions
11-
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface
11+
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
1212
//
1313
//===----------------------------------------------------------------------===//
1414

@@ -70,7 +70,7 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
7070
}
7171

7272
/// Replace the instruction \p I with a call to the corresponding function from
73-
/// the vector library ( \p TLIVecFunc ).
73+
/// the vector library (\p TLIVecFunc).
7474
static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
7575
Function *TLIVecFunc) {
7676
IRBuilder<> IRBuilder(&I);
@@ -100,26 +100,26 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
100100
/// works when \p I is a call to vectorized intrinsic or the frem instruction.
101101
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
102102
Instruction &I) {
103+
auto *VTy = dyn_cast<VectorType>(I.getType());
104+
ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
105+
// Compute the argument types of the corresponding scalar call and the scalar
106+
// function name. For calls, it additionally finds the function to replace
107+
// and checks that all vector operands match the previously found EC.
108+
SmallVector<Type *, 8> ScalarArgTypes;
103109
std::string ScalarName;
104-
ElementCount EC = ElementCount::getFixed(0);
105110
Function *FuncToReplace = nullptr;
106-
SmallVector<Type *, 8> ScalarArgTypes;
107-
// Compute the argument types of the corresponding scalar call, the scalar
108-
// function name, and EC. For calls, it additionally checks if in the vector
109-
// call, all vector operands have the same EC.
110111
if (auto *CI = dyn_cast<CallInst>(&I)) {
111-
Intrinsic::ID IID = CI->getCalledFunction()->getIntrinsicID();
112-
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
113112
FuncToReplace = CI->getCalledFunction();
113+
Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
114+
assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
114115
for (auto Arg : enumerate(CI->args())) {
115116
auto *ArgTy = Arg.value()->getType();
116117
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
117118
ScalarArgTypes.push_back(ArgTy);
118119
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
119120
ScalarArgTypes.push_back(VectorArgTy->getElementType());
120-
// Disallow vector arguments with different VFs. When processing the
121-
// first vector argument, store it's VF, and for the rest ensure that
122-
// they match it.
121+
// When return type is void, set EC to the first vector argument, and
122+
// disallow vector arguments with different ECs.
123123
if (EC.isZero())
124124
EC = VectorArgTy->getElementCount();
125125
else if (EC != VectorArgTy->getElementCount())
@@ -134,15 +134,13 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
134134
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
135135
: Intrinsic::getName(IID).str();
136136
} else {
137-
assert(I.getType()->isVectorTy() && "Instruction must use vectors");
137+
assert(VTy && "Return type must be a vector");
138+
auto *ScalarTy = VTy->getScalarType();
138139
LibFunc Func;
139-
auto *ScalarTy = I.getType()->getScalarType();
140140
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
141141
return false;
142142
ScalarName = TLI.getName(Func);
143143
ScalarArgTypes = {ScalarTy, ScalarTy};
144-
if (auto *VTy = dyn_cast<VectorType>(I.getType()))
145-
EC = VTy->getElementCount();
146144
}
147145

148146
// Try to find the mapping for the scalar version of this intrinsic and the
@@ -180,13 +178,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
180178
return true;
181179
}
182180

183-
/// Supported instructions \p I are either frem or CallInsts to intrinsics.
181+
/// Supported instruction \p I must be a vectorized frem or a call to an
182+
/// intrinsic that returns either void or a vector.
184183
static bool isSupportedInstruction(Instruction *I) {
184+
Type *Ty = I->getType();
185185
if (auto *CI = dyn_cast<CallInst>(I))
186-
return CI->getCalledFunction() &&
186+
return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
187187
CI->getCalledFunction()->getIntrinsicID() !=
188188
Intrinsic::not_intrinsic;
189-
if (I->getOpcode() == Instruction::FRem && I->getType()->isVectorTy())
189+
if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
190190
return true;
191191
return false;
192192
}

0 commit comments

Comments
 (0)