Skip to content

Commit 73d59dd

Browse files
[VFABI] Create FunctionType for vector functions
`createFunctionType` optionally returns a FunctionType and the mask's position when there's one. It requires VFInfo and an Instruction. Add `checkFunctionType` in 'VectorFunctionABITest.cpp' tests to check that both the number and the type of vectorized parameters matches the created `FunctionType`.
1 parent ae7bffd commit 73d59dd

File tree

4 files changed

+181
-30
lines changed

4 files changed

+181
-30
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
195195
/// the presence of the attribute (see InjectTLIMappings).
196196
void getVectorVariantNames(const CallInst &CI,
197197
SmallVectorImpl<std::string> &VariantMappings);
198+
199+
/// Returns a pair of the vectorized FunctionType and the mask's position when
200+
/// there's one, otherwise -1. It rejects any non vectorized calls as this
201+
/// method should be called at a point where the Instruction \p I is already
202+
/// vectorized.
203+
std::optional<std::pair<FunctionType *, int>>
204+
createFunctionType(const VFInfo &Info, const Instruction *I, const Module *M);
198205
} // end namespace VFABI
199206

200207
/// The Vector Function Database.

llvm/lib/Analysis/VFABIDemangling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
376376
// _ZGV<isa><mask><vlen><parameters>_<scalarname>.
377377
StringRef VectorName = MangledName;
378378

379-
// Parse the fixed size part of the manled name
379+
// Parse the fixed size part of the mangled name
380380
if (!MangledName.consume_front("_ZGV"))
381381
return std::nullopt;
382382

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "llvm/Analysis/VectorUtils.h"
1414
#include "llvm/ADT/EquivalenceClasses.h"
15+
#include "llvm/ADT/SmallVector.h"
1516
#include "llvm/Analysis/DemandedBits.h"
1617
#include "llvm/Analysis/LoopInfo.h"
1718
#include "llvm/Analysis/LoopIterator.h"
@@ -24,6 +25,7 @@
2425
#include "llvm/IR/PatternMatch.h"
2526
#include "llvm/IR/Value.h"
2627
#include "llvm/Support/CommandLine.h"
28+
#include <optional>
2729

2830
#define DEBUG_TYPE "vectorutils"
2931

@@ -1477,6 +1479,49 @@ void VFABI::getVectorVariantNames(
14771479
}
14781480
}
14791481

1482+
// Returns whether any of the operands or return type of \p I are vectors.
1483+
static bool isVectorized(const Instruction *I) {
1484+
if (I->getType()->isVectorTy())
1485+
return true;
1486+
for (auto &U : I->operands())
1487+
if (U->getType()->isVectorTy())
1488+
return true;
1489+
return false;
1490+
}
1491+
1492+
std::optional<std::pair<FunctionType *, int>>
1493+
VFABI::createFunctionType(const VFInfo &Info, const Instruction *I,
1494+
const Module *M) {
1495+
// only vectorized calls should reach this method
1496+
if (!isVectorized(I))
1497+
return std::nullopt;
1498+
1499+
ElementCount VF = Info.Shape.VF;
1500+
// get vectorized operands
1501+
const bool IsCall = isa<CallBase>(I);
1502+
SmallVector<Type *, 8> VecParams;
1503+
for (auto [i, U] : enumerate(I->operands())) {
1504+
// ignore the function pointer when the Instruction is a call
1505+
if (IsCall && i == I->getNumOperands() - 1)
1506+
break;
1507+
VecParams.push_back(U->getType());
1508+
}
1509+
1510+
// Append a mask and get its position.
1511+
int MaskPos = -1;
1512+
if (Info.isMasked()) {
1513+
auto OptMaskPos = Info.getParamIndexForOptionalMask();
1514+
if (!OptMaskPos)
1515+
return std::nullopt;
1516+
1517+
MaskPos = OptMaskPos.value();
1518+
VectorType *MaskTy = VectorType::get(Type::getInt1Ty(M->getContext()), VF);
1519+
VecParams.insert(VecParams.begin() + MaskPos, MaskTy);
1520+
}
1521+
FunctionType *VecFTy = FunctionType::get(I->getType(), VecParams, false);
1522+
return std::make_pair(VecFTy, MaskPos);
1523+
}
1524+
14801525
bool VFShape::hasValidParameterList() const {
14811526
for (unsigned Pos = 0, NumParams = Parameters.size(); Pos < NumParams;
14821527
++Pos) {

0 commit comments

Comments
 (0)