Skip to content

Commit ddb6db4

Browse files
[VFABI] Create FunctionType for vector functions (#75058)
`createFunctionType` returns a FunctionType that may contain a mask, which is currently placed as the last parameter to the Function. The placement happens according to `VFParameters` of `VFInfo`, and it should be able to handle VFABI specification changes. Regarding the return type, it uses the scalar type of the input instruction, as the specification does not encode in the mangled name such information. If that ever happens, that information should be available from `VFInfo`.
1 parent 67fd4e3 commit ddb6db4

File tree

4 files changed

+211
-33
lines changed

4 files changed

+211
-33
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
195195
/// the presence of the attribute (see InjectTLIMappings).
196196
void getVectorVariantNames(const CallInst &CI,
197197
SmallVectorImpl<std::string> &VariantMappings);
198+
199+
/// Constructs a FunctionType by applying vector function information to the
200+
/// type of a matching scalar function.
201+
/// \param Info gets the vectorization factor (VF) and the VFParamKind of the
202+
/// parameters.
203+
/// \param ScalarFTy gets the Type information of parameters, as it is not
204+
/// stored in \p Info.
205+
/// \returns a pointer to a newly created vector FunctionType
206+
FunctionType *createFunctionType(const VFInfo &Info,
207+
const FunctionType *ScalarFTy);
198208
} // end namespace VFABI
199209

200210
/// The Vector Function Database.

llvm/lib/Analysis/VFABIDemangling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
376376
// _ZGV<isa><mask><vlen><parameters>_<scalarname>.
377377
StringRef VectorName = MangledName;
378378

379-
// Parse the fixed size part of the manled name
379+
// Parse the fixed size part of the mangled name
380380
if (!MangledName.consume_front("_ZGV"))
381381
return std::nullopt;
382382

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "llvm/Analysis/VectorUtils.h"
1414
#include "llvm/ADT/EquivalenceClasses.h"
15+
#include "llvm/ADT/SmallVector.h"
1516
#include "llvm/Analysis/DemandedBits.h"
1617
#include "llvm/Analysis/LoopInfo.h"
1718
#include "llvm/Analysis/LoopIterator.h"
@@ -20,6 +21,7 @@
2021
#include "llvm/Analysis/TargetTransformInfo.h"
2122
#include "llvm/Analysis/ValueTracking.h"
2223
#include "llvm/IR/Constants.h"
24+
#include "llvm/IR/DerivedTypes.h"
2325
#include "llvm/IR/IRBuilder.h"
2426
#include "llvm/IR/PatternMatch.h"
2527
#include "llvm/IR/Value.h"
@@ -1477,6 +1479,32 @@ void VFABI::getVectorVariantNames(
14771479
}
14781480
}
14791481

1482+
FunctionType *VFABI::createFunctionType(const VFInfo &Info,
1483+
const FunctionType *ScalarFTy) {
1484+
// Create vector parameter types
1485+
SmallVector<Type *, 8> VecTypes;
1486+
ElementCount VF = Info.Shape.VF;
1487+
int ScalarParamIndex = 0;
1488+
for (auto VFParam : Info.Shape.Parameters) {
1489+
if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
1490+
VectorType *MaskTy =
1491+
VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF);
1492+
VecTypes.push_back(MaskTy);
1493+
continue;
1494+
}
1495+
1496+
Type *OperandTy = ScalarFTy->getParamType(ScalarParamIndex++);
1497+
if (VFParam.ParamKind == VFParamKind::Vector)
1498+
OperandTy = VectorType::get(OperandTy, VF);
1499+
VecTypes.push_back(OperandTy);
1500+
}
1501+
1502+
auto *RetTy = ScalarFTy->getReturnType();
1503+
if (!RetTy->isVoidTy())
1504+
RetTy = VectorType::get(RetTy, VF);
1505+
return FunctionType::get(RetTy, VecTypes, false);
1506+
}
1507+
14801508
bool VFShape::hasValidParameterList() const {
14811509
for (unsigned Pos = 0, NumParams = Parameters.size(); Pos < NumParams;
14821510
++Pos) {

0 commit comments

Comments
 (0)