Skip to content

Commit b1fba56

Browse files
authored
[SVE] Don't require lookup when demangling vector function mappings (#72260)
We can determine the VF from a combination of the mangled name (which indicates the arguments that take vectors) and the element sizes of the arguments for the scalar function the mapping has been established for. The assert when demangling fails has been removed in favour of just not adding the mapping, which prevents the crash seen in #71892 This patch also stops using _LLVM_ as an ISA for scalable vector tests, since there aren't defined rules for the way vector arguments should be handled (e.g. packed vs. unpacked representation).
1 parent 3114bd3 commit b1fba56

File tree

10 files changed

+339
-151
lines changed

10 files changed

+339
-151
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
174174
///
175175
/// \param MangledName -> input string in the format
176176
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
177-
/// \param M -> Module used to retrieve informations about the vector
178-
/// function that are not possible to retrieve from the mangled
179-
/// name. At the moment, this parameter is needed only to retrieve the
180-
/// Vectorization Factor of scalable vector functions from their
181-
/// respective IR declarations.
177+
/// \param CI -> A call to the scalar function which we're trying to find
178+
/// a vectorized variant for. This is required to determine the vectorization
179+
/// factor for scalable vectors, since the mangled name doesn't encode that;
180+
/// it needs to be derived from the widest element types of vector arguments
181+
/// or return values.
182182
std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
183-
const Module &M);
183+
const CallInst &CI);
184184

185185
/// Retrieve the `VFParamKind` from a string token.
186186
VFParamKind getVFParamKindFromString(const StringRef Token);
@@ -227,7 +227,7 @@ class VFDatabase {
227227
return;
228228
for (const auto &MangledName : ListOfStrings) {
229229
const std::optional<VFInfo> Shape =
230-
VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
230+
VFABI::tryDemangleForVFABI(MangledName, CI);
231231
// A match is found via scalar and vector names, and also by
232232
// ensuring that the variant described in the attribute has a
233233
// corresponding definition or declaration of the vector

llvm/lib/Analysis/VFABIDemangling.cpp

Lines changed: 135 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Analysis/VectorUtils.h"
10+
#include "llvm/Support/Debug.h"
11+
#include "llvm/Support/raw_ostream.h"
12+
#include <limits>
1013

1114
using namespace llvm;
1215

16+
#define DEBUG_TYPE "vfabi-demangling"
17+
1318
namespace {
1419
/// Utilities for the Vector Function ABI name parser.
1520

@@ -21,8 +26,9 @@ enum class ParseRet {
2126
};
2227

2328
/// Extracts the `<isa>` information from the mangled string, and
24-
/// sets the `ISA` accordingly.
25-
ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
29+
/// sets the `ISA` accordingly. If successful, the <isa> token is removed
30+
/// from the input string `MangledName`.
31+
static ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
2632
if (MangledName.empty())
2733
return ParseRet::Error;
2834

@@ -45,9 +51,9 @@ ParseRet tryParseISA(StringRef &MangledName, VFISAKind &ISA) {
4551
}
4652

4753
/// Extracts the `<mask>` information from the mangled string, and
48-
/// sets `IsMasked` accordingly. The input string `MangledName` is
49-
/// left unmodified.
50-
ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
54+
/// sets `IsMasked` accordingly. If successful, the <mask> token is removed
55+
/// from the input string `MangledName`.
56+
static ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
5157
if (MangledName.consume_front("M")) {
5258
IsMasked = true;
5359
return ParseRet::OK;
@@ -62,28 +68,36 @@ ParseRet tryParseMask(StringRef &MangledName, bool &IsMasked) {
6268
}
6369

6470
/// Extract the `<vlen>` information from the mangled string, and
65-
/// sets `VF` accordingly. A `<vlen> == "x"` token is interpreted as a scalable
66-
/// vector length. On success, the `<vlen>` token is removed from
67-
/// the input string `ParseString`.
68-
///
69-
ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
71+
/// sets `ParsedVF` accordingly. A `<vlen> == "x"` token is interpreted as a
72+
/// scalable vector length and the boolean is set to true, otherwise a nonzero
73+
/// unsigned integer will be directly used as a VF. On success, the `<vlen>`
74+
/// token is removed from the input string `ParseString`.
75+
static ParseRet tryParseVLEN(StringRef &ParseString, VFISAKind ISA,
76+
std::pair<unsigned, bool> &ParsedVF) {
7077
if (ParseString.consume_front("x")) {
71-
// Set VF to 0, to be later adjusted to a value grater than zero
72-
// by looking at the signature of the vector function with
73-
// `getECFromSignature`.
74-
VF = 0;
75-
IsScalable = true;
78+
// SVE is the only scalable ISA currently supported.
79+
if (ISA != VFISAKind::SVE) {
80+
LLVM_DEBUG(dbgs() << "Vector function variant declared with scalable VF "
81+
<< "but ISA is not SVE\n");
82+
return ParseRet::Error;
83+
}
84+
// We can't determine the VF of a scalable vector by looking at the vlen
85+
// string (just 'x'), so say we successfully parsed it but return a 'true'
86+
// for the scalable field with an invalid VF field so that we know to look
87+
// up the actual VF based on element types from the parameters or return.
88+
ParsedVF = {0, true};
7689
return ParseRet::OK;
7790
}
7891

92+
unsigned VF = 0;
7993
if (ParseString.consumeInteger(10, VF))
8094
return ParseRet::Error;
8195

8296
// The token `0` is invalid for VLEN.
8397
if (VF == 0)
8498
return ParseRet::Error;
8599

86-
IsScalable = false;
100+
ParsedVF = {VF, false};
87101
return ParseRet::OK;
88102
}
89103

@@ -99,9 +113,9 @@ ParseRet tryParseVLEN(StringRef &ParseString, unsigned &VF, bool &IsScalable) {
99113
///
100114
/// The function expects <token> to be one of "ls", "Rs", "Us" or
101115
/// "Ls".
102-
ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
103-
VFParamKind &PKind, int &Pos,
104-
const StringRef Token) {
116+
static ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
117+
VFParamKind &PKind, int &Pos,
118+
const StringRef Token) {
105119
if (ParseString.consume_front(Token)) {
106120
PKind = VFABI::getVFParamKindFromString(Token);
107121
if (ParseString.consumeInteger(10, Pos))
@@ -123,8 +137,9 @@ ParseRet tryParseLinearTokenWithRuntimeStep(StringRef &ParseString,
123137
/// sets `PKind` to the correspondent enum value, sets `StepOrPos` to
124138
/// <number>, and return success. On a syntax error, it return a
125139
/// parsing error. If nothing is parsed, it returns std::nullopt.
126-
ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
127-
VFParamKind &PKind, int &StepOrPos) {
140+
static ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
141+
VFParamKind &PKind,
142+
int &StepOrPos) {
128143
ParseRet Ret;
129144

130145
// "ls" <RuntimeStepPos>
@@ -162,9 +177,10 @@ ParseRet tryParseLinearWithRuntimeStep(StringRef &ParseString,
162177
///
163178
/// The function expects <token> to be one of "l", "R", "U" or
164179
/// "L".
165-
ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
166-
VFParamKind &PKind, int &LinearStep,
167-
const StringRef Token) {
180+
static ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
181+
VFParamKind &PKind,
182+
int &LinearStep,
183+
const StringRef Token) {
168184
if (ParseString.consume_front(Token)) {
169185
PKind = VFABI::getVFParamKindFromString(Token);
170186
const bool Negate = ParseString.consume_front("n");
@@ -187,8 +203,9 @@ ParseRet tryParseCompileTimeLinearToken(StringRef &ParseString,
187203
/// sets `PKind` to the correspondent enum value, sets `LinearStep` to
188204
/// <number>, and return success. On a syntax error, it return a
189205
/// parsing error. If nothing is parsed, it returns std::nullopt.
190-
ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
191-
VFParamKind &PKind, int &StepOrPos) {
206+
static ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
207+
VFParamKind &PKind,
208+
int &StepOrPos) {
192209
// "l" {"n"} <CompileTimeStep>
193210
if (tryParseCompileTimeLinearToken(ParseString, PKind, StepOrPos, "l") ==
194211
ParseRet::OK)
@@ -220,8 +237,8 @@ ParseRet tryParseLinearWithCompileTimeStep(StringRef &ParseString,
220237
/// sets `PKind` to the correspondent enum value, sets `StepOrPos`
221238
/// accordingly, and return success. On a syntax error, it return a
222239
/// parsing error. If nothing is parsed, it returns std::nullopt.
223-
ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
224-
int &StepOrPos) {
240+
static ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
241+
int &StepOrPos) {
225242
if (ParseString.consume_front("v")) {
226243
PKind = VFParamKind::Vector;
227244
StepOrPos = 0;
@@ -255,7 +272,7 @@ ParseRet tryParseParameter(StringRef &ParseString, VFParamKind &PKind,
255272
/// sets `PKind` to the correspondent enum value, sets `StepOrPos`
256273
/// accordingly, and return success. On a syntax error, it return a
257274
/// parsing error. If nothing is parsed, it returns std::nullopt.
258-
ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
275+
static ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
259276
uint64_t Val;
260277
// "a" <number>
261278
if (ParseString.consume_front("a")) {
@@ -273,49 +290,86 @@ ParseRet tryParseAlign(StringRef &ParseString, Align &Alignment) {
273290
return ParseRet::None;
274291
}
275292

276-
#ifndef NDEBUG
277-
// Verify the assumtion that all vectors in the signature of a vector
278-
// function have the same number of elements.
279-
bool verifyAllVectorsHaveSameWidth(FunctionType *Signature) {
280-
SmallVector<VectorType *, 2> VecTys;
281-
if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
282-
VecTys.push_back(RetTy);
283-
for (auto *Ty : Signature->params())
284-
if (auto *VTy = dyn_cast<VectorType>(Ty))
285-
VecTys.push_back(VTy);
286-
287-
if (VecTys.size() <= 1)
288-
return true;
289-
290-
assert(VecTys.size() > 1 && "Invalid number of elements.");
291-
const ElementCount EC = VecTys[0]->getElementCount();
292-
return llvm::all_of(llvm::drop_begin(VecTys), [&EC](VectorType *VTy) {
293-
return (EC == VTy->getElementCount());
294-
});
293+
// Returns the 'natural' VF for a given scalar element type, based on the
294+
// current architecture.
295+
//
296+
// For SVE (currently the only scalable architecture with a defined name
297+
// mangling), we assume a minimum vector size of 128b and return a VF based on
298+
// the number of elements of the given type which would fit in such a vector.
299+
static std::optional<ElementCount> getElementCountForTy(const VFISAKind ISA,
300+
const Type *Ty) {
301+
// Only AArch64 SVE is supported at present.
302+
assert(ISA == VFISAKind::SVE &&
303+
"Scalable VF decoding only implemented for SVE\n");
304+
305+
if (Ty->isIntegerTy(64) || Ty->isDoubleTy() || Ty->isPointerTy())
306+
return ElementCount::getScalable(2);
307+
if (Ty->isIntegerTy(32) || Ty->isFloatTy())
308+
return ElementCount::getScalable(4);
309+
if (Ty->isIntegerTy(16) || Ty->is16bitFPTy())
310+
return ElementCount::getScalable(8);
311+
if (Ty->isIntegerTy(8))
312+
return ElementCount::getScalable(16);
313+
314+
return std::nullopt;
295315
}
296-
#endif // NDEBUG
297-
298-
// Extract the VectorizationFactor from a given function signature,
299-
// under the assumtion that all vectors have the same number of
300-
// elements, i.e. same ElementCount.Min.
301-
ElementCount getECFromSignature(FunctionType *Signature) {
302-
assert(verifyAllVectorsHaveSameWidth(Signature) &&
303-
"Invalid vector signature.");
304-
305-
if (auto *RetTy = dyn_cast<VectorType>(Signature->getReturnType()))
306-
return RetTy->getElementCount();
307-
for (auto *Ty : Signature->params())
308-
if (auto *VTy = dyn_cast<VectorType>(Ty))
309-
return VTy->getElementCount();
310-
311-
return ElementCount::getFixed(/*Min=*/1);
316+
317+
// Extract the VectorizationFactor from a given function signature, based
318+
// on the widest scalar element types that will become vector parameters.
319+
static std::optional<ElementCount>
320+
getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
321+
const SmallVectorImpl<VFParameter> &Params) {
322+
// Start with a very wide EC and drop when we find smaller ECs based on type.
323+
ElementCount MinEC =
324+
ElementCount::getScalable(std::numeric_limits<unsigned int>::max());
325+
for (auto &Param : Params) {
326+
// Only vector parameters are used when determining the VF; uniform or
327+
// linear are left as scalars, so do not affect VF.
328+
if (Param.ParamKind == VFParamKind::Vector) {
329+
// If the scalar function doesn't actually have a corresponding argument,
330+
// reject the mapping.
331+
if (Param.ParamPos >= Signature->getNumParams())
332+
return std::nullopt;
333+
Type *PTy = Signature->getParamType(Param.ParamPos);
334+
335+
std::optional<ElementCount> EC = getElementCountForTy(ISA, PTy);
336+
// If we have an unknown scalar element type we can't find a reasonable
337+
// VF.
338+
if (!EC)
339+
return std::nullopt;
340+
341+
// Find the smallest VF, based on the widest scalar type.
342+
if (ElementCount::isKnownLT(*EC, MinEC))
343+
MinEC = *EC;
344+
}
345+
}
346+
347+
// Also check the return type if not void.
348+
Type *RetTy = Signature->getReturnType();
349+
if (!RetTy->isVoidTy()) {
350+
std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
351+
// If we have an unknown scalar element type we can't find a reasonable VF.
352+
if (!ReturnEC)
353+
return std::nullopt;
354+
if (ElementCount::isKnownLT(*ReturnEC, MinEC))
355+
MinEC = *ReturnEC;
356+
}
357+
358+
// The SVE Vector function call ABI bases the VF on the widest element types
359+
// present, and vector arguments containing types of that width are always
360+
// considered to be packed. Arguments with narrower elements are considered
361+
// to be unpacked.
362+
if (MinEC.getKnownMinValue() < std::numeric_limits<unsigned int>::max())
363+
return MinEC;
364+
365+
return std::nullopt;
312366
}
313367
} // namespace
314368

315369
// Format of the ABI name:
316370
// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
317371
std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
318-
const Module &M) {
372+
const CallInst &CI) {
319373
const StringRef OriginalName = MangledName;
320374
// Assume there is no custom name <redirection>, and therefore the
321375
// vector name consists of
@@ -338,9 +392,8 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
338392
return std::nullopt;
339393

340394
// Parse the variable size, starting from <vlen>.
341-
unsigned VF;
342-
bool IsScalable;
343-
if (tryParseVLEN(MangledName, VF, IsScalable) != ParseRet::OK)
395+
std::pair<unsigned, bool> ParsedVF;
396+
if (tryParseVLEN(MangledName, ISA, ParsedVF) != ParseRet::OK)
344397
return std::nullopt;
345398

346399
// Parse the <parameters>.
@@ -374,6 +427,19 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
374427
if (Parameters.empty())
375428
return std::nullopt;
376429

430+
// Figure out the number of lanes in vectors for this function variant. This
431+
// is easy for fixed length, as the vlen encoding just gives us the value
432+
// directly. However, if the vlen mangling indicated that this function
433+
// variant expects scalable vectors we need to work it out based on the
434+
// demangled parameter types and the scalar function signature.
435+
std::optional<ElementCount> EC;
436+
if (ParsedVF.second) {
437+
EC = getScalableECFromSignature(CI.getFunctionType(), ISA, Parameters);
438+
if (!EC)
439+
return std::nullopt;
440+
} else
441+
EC = ElementCount::getFixed(ParsedVF.first);
442+
377443
// Check for the <scalarname> and the optional <redirection>, which
378444
// are separated from the prefix with "_"
379445
if (!MangledName.consume_front("_"))
@@ -426,32 +492,7 @@ std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
426492
assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
427493
"The global predicate must be the last parameter");
428494

429-
// Adjust the VF for scalable signatures. The EC.Min is not encoded
430-
// in the name of the function, but it is encoded in the IR
431-
// signature of the function. We need to extract this information
432-
// because it is needed by the loop vectorizer, which reasons in
433-
// terms of VectorizationFactor or ElementCount. In particular, we
434-
// need to make sure that the VF field of the VFShape class is never
435-
// set to 0.
436-
if (IsScalable) {
437-
const Function *F = M.getFunction(VectorName);
438-
// The declaration of the function must be present in the module
439-
// to be able to retrieve its signature.
440-
if (!F)
441-
return std::nullopt;
442-
const ElementCount EC = getECFromSignature(F->getFunctionType());
443-
VF = EC.getKnownMinValue();
444-
}
445-
446-
// 1. We don't accept a zero lanes vectorization factor.
447-
// 2. We don't accept the demangling if the vector function is not
448-
// present in the module.
449-
if (VF == 0)
450-
return std::nullopt;
451-
if (!M.getFunction(VectorName))
452-
return std::nullopt;
453-
454-
const VFShape Shape({ElementCount::get(VF, IsScalable), Parameters});
495+
const VFShape Shape({*EC, Parameters});
455496
return VFInfo({Shape, std::string(ScalarName), std::string(VectorName), ISA});
456497
}
457498

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,15 +1466,13 @@ void VFABI::getVectorVariantNames(
14661466
S.split(ListAttr, ",");
14671467

14681468
for (const auto &S : SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
1469-
#ifndef NDEBUG
1470-
LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << S << "'\n");
1471-
std::optional<VFInfo> Info =
1472-
VFABI::tryDemangleForVFABI(S, *(CI.getModule()));
1473-
assert(Info && "Invalid name for a VFABI variant.");
1474-
assert(CI.getModule()->getFunction(Info->VectorName) &&
1475-
"Vector function is missing.");
1476-
#endif
1477-
VariantMappings.push_back(std::string(S));
1469+
std::optional<VFInfo> Info = VFABI::tryDemangleForVFABI(S, CI);
1470+
if (Info && CI.getModule()->getFunction(Info->VectorName)) {
1471+
LLVM_DEBUG(dbgs() << "VFABI: Adding mapping '" << S << "' for " << CI
1472+
<< "\n");
1473+
VariantMappings.push_back(std::string(S));
1474+
} else
1475+
LLVM_DEBUG(dbgs() << "VFABI: Invalid mapping '" << S << "'\n");
14781476
}
14791477
}
14801478

0 commit comments

Comments
 (0)