Skip to content

Commit 635d20e

Browse files
authored
[RISCV] full support for riscv_rvv_vector_bits attribute (llvm#100110)
Add support for using attribute((rvv_vector_bits(N))), when N < 8. It allows using all fixed length vector mask types regardless VLEN value.
1 parent 7a51dde commit 635d20e

16 files changed

+586
-99
lines changed

clang/include/clang/AST/Type.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3989,6 +3989,10 @@ enum class VectorKind {
39893989

39903990
/// is RISC-V RVV fixed-length mask vector
39913991
RVVFixedLengthMask,
3992+
3993+
RVVFixedLengthMask_1,
3994+
RVVFixedLengthMask_2,
3995+
RVVFixedLengthMask_4
39923996
};
39933997

39943998
/// Represents a GCC generic vector type. This type is created using

clang/lib/AST/ASTContext.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
19891989
// Adjust the alignment for fixed-length SVE predicates.
19901990
Align = 16;
19911991
else if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
1992-
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
1992+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask ||
1993+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_1 ||
1994+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_2 ||
1995+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_4)
19931996
// Adjust the alignment for fixed-length RVV vectors.
19941997
Align = std::min<unsigned>(64, Width);
19951998
break;
@@ -9922,7 +9925,13 @@ bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
99229925
First->getVectorKind() != VectorKind::RVVFixedLengthData &&
99239926
Second->getVectorKind() != VectorKind::RVVFixedLengthData &&
99249927
First->getVectorKind() != VectorKind::RVVFixedLengthMask &&
9925-
Second->getVectorKind() != VectorKind::RVVFixedLengthMask)
9928+
Second->getVectorKind() != VectorKind::RVVFixedLengthMask &&
9929+
First->getVectorKind() != VectorKind::RVVFixedLengthMask_1 &&
9930+
Second->getVectorKind() != VectorKind::RVVFixedLengthMask_1 &&
9931+
First->getVectorKind() != VectorKind::RVVFixedLengthMask_2 &&
9932+
Second->getVectorKind() != VectorKind::RVVFixedLengthMask_2 &&
9933+
First->getVectorKind() != VectorKind::RVVFixedLengthMask_4 &&
9934+
Second->getVectorKind() != VectorKind::RVVFixedLengthMask_4)
99269935
return true;
99279936

99289937
return false;
@@ -10040,7 +10049,25 @@ bool ASTContext::areCompatibleRVVTypes(QualType FirstType,
1004010049
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
1004110050
return FirstType->isRVVVLSBuiltinType() &&
1004210051
Info.ElementType == BoolTy &&
10043-
getTypeSize(SecondType) == getRVVTypeSize(*this, BT);
10052+
getTypeSize(SecondType) == ((getRVVTypeSize(*this, BT)));
10053+
}
10054+
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask_1) {
10055+
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
10056+
return FirstType->isRVVVLSBuiltinType() &&
10057+
Info.ElementType == BoolTy &&
10058+
getTypeSize(SecondType) == ((getRVVTypeSize(*this, BT) * 8));
10059+
}
10060+
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask_2) {
10061+
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
10062+
return FirstType->isRVVVLSBuiltinType() &&
10063+
Info.ElementType == BoolTy &&
10064+
getTypeSize(SecondType) == ((getRVVTypeSize(*this, BT)) * 4);
10065+
}
10066+
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask_4) {
10067+
BuiltinVectorTypeInfo Info = getBuiltinVectorTypeInfo(BT);
10068+
return FirstType->isRVVVLSBuiltinType() &&
10069+
Info.ElementType == BoolTy &&
10070+
getTypeSize(SecondType) == ((getRVVTypeSize(*this, BT)) * 2);
1004410071
}
1004510072
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
1004610073
VT->getVectorKind() == VectorKind::Generic)

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4011,7 +4011,10 @@ void CXXNameMangler::mangleAArch64FixedSveVectorType(
40114011

40124012
void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
40134013
assert((T->getVectorKind() == VectorKind::RVVFixedLengthData ||
4014-
T->getVectorKind() == VectorKind::RVVFixedLengthMask) &&
4014+
T->getVectorKind() == VectorKind::RVVFixedLengthMask ||
4015+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_1 ||
4016+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_2 ||
4017+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_4) &&
40154018
"expected fixed-length RVV vector!");
40164019

40174020
QualType EltType = T->getElementType();
@@ -4062,7 +4065,21 @@ void CXXNameMangler::mangleRISCVFixedRVVVectorType(const VectorType *T) {
40624065
llvm_unreachable("unexpected element type for fixed-length RVV vector!");
40634066
}
40644067

4065-
unsigned VecSizeInBits = getASTContext().getTypeInfo(T).Width;
4068+
unsigned VecSizeInBits;
4069+
switch (T->getVectorKind()) {
4070+
case VectorKind::RVVFixedLengthMask_1:
4071+
VecSizeInBits = 1;
4072+
break;
4073+
case VectorKind::RVVFixedLengthMask_2:
4074+
VecSizeInBits = 2;
4075+
break;
4076+
case VectorKind::RVVFixedLengthMask_4:
4077+
VecSizeInBits = 4;
4078+
break;
4079+
default:
4080+
VecSizeInBits = getASTContext().getTypeInfo(T).Width;
4081+
break;
4082+
}
40664083

40674084
// Apend the LMUL suffix.
40684085
auto VScale = getASTContext().getTargetInfo().getVScaleRange(
@@ -4118,7 +4135,10 @@ void CXXNameMangler::mangleType(const VectorType *T) {
41184135
mangleAArch64FixedSveVectorType(T);
41194136
return;
41204137
} else if (T->getVectorKind() == VectorKind::RVVFixedLengthData ||
4121-
T->getVectorKind() == VectorKind::RVVFixedLengthMask) {
4138+
T->getVectorKind() == VectorKind::RVVFixedLengthMask ||
4139+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_1 ||
4140+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_2 ||
4141+
T->getVectorKind() == VectorKind::RVVFixedLengthMask_4) {
41224142
mangleRISCVFixedRVVVectorType(T);
41234143
return;
41244144
}

clang/lib/AST/JSONNodeDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,9 @@ void JSONNodeDumper::VisitVectorType(const VectorType *VT) {
737737
JOS.attribute("vectorKind", "fixed-length rvv data vector");
738738
break;
739739
case VectorKind::RVVFixedLengthMask:
740+
case VectorKind::RVVFixedLengthMask_1:
741+
case VectorKind::RVVFixedLengthMask_2:
742+
case VectorKind::RVVFixedLengthMask_4:
740743
JOS.attribute("vectorKind", "fixed-length rvv mask vector");
741744
break;
742745
}

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,9 @@ void TextNodeDumper::VisitVectorType(const VectorType *T) {
18591859
OS << " fixed-length rvv data vector";
18601860
break;
18611861
case VectorKind::RVVFixedLengthMask:
1862+
case VectorKind::RVVFixedLengthMask_1:
1863+
case VectorKind::RVVFixedLengthMask_2:
1864+
case VectorKind::RVVFixedLengthMask_4:
18621865
OS << " fixed-length rvv mask vector";
18631866
break;
18641867
}

clang/lib/AST/TypePrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,9 @@ void TypePrinter::printVectorBefore(const VectorType *T, raw_ostream &OS) {
721721
break;
722722
case VectorKind::RVVFixedLengthData:
723723
case VectorKind::RVVFixedLengthMask:
724+
case VectorKind::RVVFixedLengthMask_1:
725+
case VectorKind::RVVFixedLengthMask_2:
726+
case VectorKind::RVVFixedLengthMask_4:
724727
// FIXME: We prefer to print the size directly here, but have no way
725728
// to get the size of the type.
726729
OS << "__attribute__((__riscv_rvv_vector_bits__(";
@@ -801,6 +804,9 @@ void TypePrinter::printDependentVectorBefore(
801804
break;
802805
case VectorKind::RVVFixedLengthData:
803806
case VectorKind::RVVFixedLengthMask:
807+
case VectorKind::RVVFixedLengthMask_1:
808+
case VectorKind::RVVFixedLengthMask_2:
809+
case VectorKind::RVVFixedLengthMask_4:
804810
// FIXME: We prefer to print the size directly here, but have no way
805811
// to get the size of the type.
806812
OS << "__attribute__((__riscv_rvv_vector_bits__(";

clang/lib/CodeGen/Targets/RISCV.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,20 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty) const {
327327
getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts());
328328

329329
unsigned NumElts = VT->getNumElements();
330-
llvm::Type *EltType;
331-
if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) {
330+
llvm::Type *EltType = llvm::Type::getInt1Ty(getVMContext());
331+
switch (VT->getVectorKind()) {
332+
case VectorKind::RVVFixedLengthMask_1:
333+
break;
334+
case VectorKind::RVVFixedLengthMask_2:
335+
NumElts *= 2;
336+
break;
337+
case VectorKind::RVVFixedLengthMask_4:
338+
NumElts *= 4;
339+
break;
340+
case VectorKind::RVVFixedLengthMask:
332341
NumElts *= 8;
333-
EltType = llvm::Type::getInt1Ty(getVMContext());
334-
} else {
342+
break;
343+
default:
335344
assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData &&
336345
"Unexpected vector kind");
337346
EltType = CGT.ConvertType(VT->getElementType());
@@ -453,7 +462,10 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
453462

454463
if (const VectorType *VT = Ty->getAs<VectorType>())
455464
if (VT->getVectorKind() == VectorKind::RVVFixedLengthData ||
456-
VT->getVectorKind() == VectorKind::RVVFixedLengthMask)
465+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask ||
466+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_1 ||
467+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_2 ||
468+
VT->getVectorKind() == VectorKind::RVVFixedLengthMask_4)
457469
return coerceVLSVector(Ty);
458470

459471
// Aggregates which are <= 2*XLen will be passed in registers if possible,

clang/lib/Sema/SemaExpr.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10135,7 +10135,10 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1013510135
VecType->getVectorKind() == VectorKind::SveFixedLengthPredicate)
1013610136
return true;
1013710137
if (VecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
10138-
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
10138+
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask ||
10139+
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask_1 ||
10140+
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask_2 ||
10141+
VecType->getVectorKind() == VectorKind::RVVFixedLengthMask_4) {
1013910142
SVEorRVV = 1;
1014010143
return true;
1014110144
}
@@ -10167,7 +10170,13 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1016710170
VectorKind::SveFixedLengthPredicate)
1016810171
return true;
1016910172
if (SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthData ||
10170-
SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthMask) {
10173+
SecondVecType->getVectorKind() == VectorKind::RVVFixedLengthMask ||
10174+
SecondVecType->getVectorKind() ==
10175+
VectorKind::RVVFixedLengthMask_1 ||
10176+
SecondVecType->getVectorKind() ==
10177+
VectorKind::RVVFixedLengthMask_2 ||
10178+
SecondVecType->getVectorKind() ==
10179+
VectorKind::RVVFixedLengthMask_4) {
1017110180
SVEorRVV = 1;
1017210181
return true;
1017310182
}

clang/lib/Sema/SemaType.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8355,14 +8355,28 @@ static void HandleRISCVRVVVectorBitsTypeAttr(QualType &CurType,
83558355
unsigned NumElts;
83568356
if (Info.ElementType == S.Context.BoolTy) {
83578357
NumElts = VecSize / S.Context.getCharWidth();
8358-
VecKind = VectorKind::RVVFixedLengthMask;
8358+
if (!NumElts) {
8359+
NumElts = 1;
8360+
switch (VecSize) {
8361+
case 1:
8362+
VecKind = VectorKind::RVVFixedLengthMask_1;
8363+
break;
8364+
case 2:
8365+
VecKind = VectorKind::RVVFixedLengthMask_2;
8366+
break;
8367+
case 4:
8368+
VecKind = VectorKind::RVVFixedLengthMask_4;
8369+
break;
8370+
}
8371+
} else
8372+
VecKind = VectorKind::RVVFixedLengthMask;
83598373
} else {
83608374
ExpectedSize *= EltSize;
83618375
NumElts = VecSize / EltSize;
83628376
}
83638377

83648378
// The attribute vector size must match -mrvv-vector-bits.
8365-
if (ExpectedSize % 8 != 0 || VecSize != ExpectedSize) {
8379+
if (VecSize != ExpectedSize) {
83668380
S.Diag(Attr.getLoc(), diag::err_attribute_bad_rvv_vector_size)
83678381
<< VecSize << ExpectedSize;
83688382
Attr.setInvalid();

0 commit comments

Comments
 (0)