Skip to content

Commit 97ea3b7

Browse files
createFunctionType requires only VFInfo and ScalarFTy
Not returning a pair anymore as the position can be queries directly from VFInfo, which createFunctionType needs to have as an argument to begin with. Also getting the return type from the ScalarFTy, as the specification does not encode in the mangled name such information. Therefore, the VFInfo does not hold such info. If that changes, then it will make its way into VFInfo and one could get it from there.
1 parent 29f11b1 commit 97ea3b7

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,12 @@ static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
196196
void getVectorVariantNames(const CallInst &CI,
197197
SmallVectorImpl<std::string> &VariantMappings);
198198

199-
/// Returns a pair of the vectorized FunctionType and the mask's position when
200-
/// there's one, otherwise -1.
201-
std::optional<std::pair<FunctionType *, int>>
202-
createFunctionType(const VFInfo &Info, const FunctionType *ScalarFTy,
203-
Type *VecRetTy, const Module *M);
199+
/// Returns a vectorized FunctionType that was previously found in
200+
/// TargetLibraryInfo. It uses \p ScalarFTy for the types, and \p Info to get
201+
/// the vectorization factor and whether a particular parameter is indeed a
202+
/// vector, since some of them may be scalars.
203+
std::optional<FunctionType *> createFunctionType(const VFInfo &Info,
204+
const FunctionType *ScalarFTy);
204205
} // end namespace VFABI
205206

206207
/// The Vector Function Database.

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/Analysis/TargetTransformInfo.h"
2222
#include "llvm/Analysis/ValueTracking.h"
2323
#include "llvm/IR/Constants.h"
24+
#include "llvm/IR/DerivedTypes.h"
2425
#include "llvm/IR/IRBuilder.h"
2526
#include "llvm/IR/PatternMatch.h"
2627
#include "llvm/IR/Value.h"
@@ -1479,9 +1480,8 @@ void VFABI::getVectorVariantNames(
14791480
}
14801481
}
14811482

1482-
std::optional<std::pair<FunctionType *, int>>
1483-
VFABI::createFunctionType(const VFInfo &Info, const FunctionType *ScalarFTy,
1484-
Type *VecRetTy, const Module *M) {
1483+
std::optional<FunctionType *>
1484+
VFABI::createFunctionType(const VFInfo &Info, const FunctionType *ScalarFTy) {
14851485
ElementCount VF = Info.Shape.VF;
14861486
// Create vector parameter types
14871487
SmallVector<Type *, 8> VecTypes;
@@ -1500,11 +1500,15 @@ VFABI::createFunctionType(const VFInfo &Info, const FunctionType *ScalarFTy,
15001500
return std::nullopt;
15011501

15021502
MaskPos = OptMaskPos.value();
1503-
VectorType *MaskTy = VectorType::get(Type::getInt1Ty(M->getContext()), VF);
1503+
VectorType *MaskTy =
1504+
VectorType::get(Type::getInt1Ty(ScalarFTy->getContext()), VF);
15041505
VecTypes.insert(VecTypes.begin() + MaskPos, MaskTy);
15051506
}
1506-
FunctionType *VecFTy = FunctionType::get(VecRetTy, VecTypes, false);
1507-
return std::make_pair(VecFTy, MaskPos);
1507+
auto *RetTy = ScalarFTy->getReturnType();
1508+
if (!RetTy->isVoidTy())
1509+
RetTy = VectorType::get(ScalarFTy->getReturnType(), VF);
1510+
FunctionType *VecFTy = FunctionType::get(RetTy, VecTypes, false);
1511+
return VecFTy;
15081512
}
15091513

15101514
bool VFShape::hasValidParameterList() const {

llvm/unittests/Analysis/VectorFunctionABITest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,11 @@ class VFABIParserTest : public ::testing::Test {
126126

127127
// Use VFInfo and the mock CallInst to create a FunctionType that will
128128
// include a mask when relevant.
129-
auto OptVecFTyPos =
130-
VFABI::createFunctionType(Info, ScalarFTy, VecRetTy, M.get());
131-
if (!OptVecFTyPos)
129+
auto OptVecFTy = VFABI::createFunctionType(Info, ScalarFTy);
130+
if (!OptVecFTy)
132131
return false;
133132

134-
FunctionType *VecFTy = OptVecFTyPos->first;
133+
FunctionType *VecFTy = *OptVecFTy;
135134
// Check that vectorized parameters' size match with VFInfo.
136135
// Both may include a mask.
137136
if ((VecFTy->getNumParams() != Info.Shape.Parameters.size()))

0 commit comments

Comments
 (0)