Skip to content

Commit b4c4a53

Browse files
Improve testing for createFunctionType.
Refactored 'createFunctionType'
1 parent af468f3 commit b4c4a53

File tree

3 files changed

+151
-115
lines changed

3 files changed

+151
-115
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,8 @@ static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
196196
void getVectorVariantNames(const CallInst &CI,
197197
SmallVectorImpl<std::string> &VariantMappings);
198198

199-
/// Returns a vectorized FunctionType that was previously found in
200-
/// TargetLibraryInfo. It uses \p ScalarFTy for the types, and \p Info to get
201-
/// the vectorization factor and whether a particular parameter is indeed a
202-
/// vector, since some of them may be scalars.
199+
/// Constructs a FunctionType by applying vector function information to the
200+
/// type of a matching scalar function.
203201
FunctionType *createFunctionType(const VFInfo &Info,
204202
const FunctionType *ScalarFTy);
205203
} // end namespace VFABI

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "llvm/IR/PatternMatch.h"
2727
#include "llvm/IR/Value.h"
2828
#include "llvm/Support/CommandLine.h"
29-
#include <optional>
3029

3130
#define DEBUG_TYPE "vectorutils"
3231

@@ -1482,27 +1481,26 @@ void VFABI::getVectorVariantNames(
14821481

14831482
FunctionType *VFABI::createFunctionType(const VFInfo &Info,
14841483
const FunctionType *ScalarFTy) {
1485-
ElementCount VF = Info.Shape.VF;
14861484
// Create vector parameter types
14871485
SmallVector<Type *, 8> VecTypes;
1488-
for (auto [STy, VFParam] : zip(ScalarFTy->params(), Info.Shape.Parameters)) {
1486+
ElementCount VF = Info.Shape.VF;
1487+
for (auto [Idx, VFParam] : enumerate(Info.Shape.Parameters)) {
1488+
if (VFParam.ParamKind == VFParamKind::GlobalPredicate) {
1489+
VectorType *MaskTy =
1490+
VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF);
1491+
VecTypes.push_back(MaskTy);
1492+
continue;
1493+
}
1494+
1495+
Type *OperandTy = ScalarFTy->getParamType(Idx);
14891496
if (VFParam.ParamKind == VFParamKind::Vector)
1490-
VecTypes.push_back(VectorType::get(STy, VF));
1491-
else
1492-
VecTypes.push_back(STy);
1497+
OperandTy = VectorType::get(OperandTy, VF);
1498+
VecTypes.push_back(OperandTy);
14931499
}
14941500

1495-
// Get mask's position mask and append one if not present in the Instruction.
1496-
if (auto OptMaskPos = Info.getParamIndexForOptionalMask()) {
1497-
if (!OptMaskPos)
1498-
return nullptr;
1499-
VectorType *MaskTy =
1500-
VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF);
1501-
VecTypes.insert(VecTypes.begin() + OptMaskPos.value(), MaskTy);
1502-
}
15031501
auto *RetTy = ScalarFTy->getReturnType();
15041502
if (!RetTy->isVoidTy())
1505-
RetTy = VectorType::get(ScalarFTy->getReturnType(), VF);
1503+
RetTy = VectorType::get(RetTy, VF);
15061504
return FunctionType::get(RetTy, VecTypes, false);
15071505
}
15081506

0 commit comments

Comments
 (0)