Skip to content

Commit 5d027a3

Browse files
[NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8x8 and MFloat8x16
This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types the same way the other neon vectors are build. It uses the scalar type(mfloat8).
1 parent fcfd643 commit 5d027a3

File tree

14 files changed

+70
-42
lines changed

14 files changed

+70
-42
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25212521
bool isFloat32Type() const;
25222522
bool isDoubleType() const;
25232523
bool isBFloat16Type() const;
2524+
bool isMFloat8Type() const;
25242525
bool isFloat128Type() const;
25252526
bool isIbm128Type() const;
25262527
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8527,6 +8528,10 @@ inline bool Type::isBFloat16Type() const {
85278528
return isSpecificBuiltinType(BuiltinType::BFloat16);
85288529
}
85298530

8531+
inline bool Type::isMFloat8Type() const {
8532+
return isSpecificBuiltinType(BuiltinType::MFloat8);
8533+
}
8534+
85308535
inline bool Type::isFloat128Type() const {
85318536
return isSpecificBuiltinType(BuiltinType::Float128);
85328537
}

clang/include/clang/Basic/AArch64SVEACLETypes.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
201201
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
202202

203203
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8_t", "__MFloat8_t", MFloat8, MFloat8Ty, 1, 8, 1)
204-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
205-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
206204

207205
#undef SVE_VECTOR_TYPE
208206
#undef SVE_VECTOR_TYPE_BFLOAT

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ namespace clang {
198198
Float16,
199199
Float32,
200200
Float64,
201-
BFloat16
201+
BFloat16,
202+
MFloat8
202203
};
203204

204205
NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -220,6 +221,7 @@ namespace clang {
220221
switch (getEltType()) {
221222
case Int8:
222223
case Poly8:
224+
case MFloat8:
223225
return 8;
224226
case Int16:
225227
case Float16:

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
218218
// h: half-float
219219
// d: double
220220
// b: bfloat16
221+
// m: mfloat8
221222
//
222223
// Typespec modifiers
223224
// ------------------

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3902,6 +3902,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
39023902
return "Float64";
39033903
case BuiltinType::BFloat16:
39043904
return "Bfloat16";
3905+
case BuiltinType::MFloat8:
3906+
return "Mfloat8";
39053907
default:
39063908
llvm_unreachable("Unexpected vector element base type");
39073909
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6513,6 +6513,8 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
65136513
case NeonTypeFlags::Int8:
65146514
case NeonTypeFlags::Poly8:
65156515
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
6516+
case NeonTypeFlags::MFloat8:
6517+
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
65166518
case NeonTypeFlags::Int16:
65176519
case NeonTypeFlags::Poly16:
65186520
return llvm::FixedVectorType::get(CGF->Int16Ty, V1Ty ? 1 : (4 << IsQuad));

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
647647
case Type::ExtVector:
648648
case Type::Vector: {
649649
const auto *VT = cast<VectorType>(Ty);
650+
if (VT->getElementType()->isMFloat8Type()) {
651+
ResultType = llvm::FixedVectorType::get(
652+
llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
653+
break;
654+
}
650655
// An ext_vector_type of Bool is really a vector of bits.
651656
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
652657
? llvm::Type::getInt1Ty(getLLVMContext())

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
375375
NSRN = std::min(NSRN + 1, 8u);
376376
else {
377377
switch (BT->getKind()) {
378-
case BuiltinType::MFloat8x8:
379-
case BuiltinType::MFloat8x16:
380-
NSRN = std::min(NSRN + 1, 8u);
381-
break;
382378
case BuiltinType::SveBool:
383379
case BuiltinType::SveCount:
384380
NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
620616
// but with the difference that any floating-point type is allowed,
621617
// including __fp16.
622618
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
623-
if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
624-
BT->getKind() == BuiltinType::MFloat8x8)
619+
if (BT->isFloatingPoint())
625620
return true;
626621
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
627622
if (auto Kind = VT->getVectorKind();

clang/lib/Sema/SemaARM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
323323
switch (Flags.getEltType()) {
324324
case NeonTypeFlags::Int8:
325325
return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
326+
case NeonTypeFlags::MFloat8:
327+
return Context.MFloat8Ty;
326328
case NeonTypeFlags::Int16:
327329
return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
328330
case NeonTypeFlags::Int32:

clang/lib/Sema/SemaExpr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10156,6 +10156,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1015610156
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
1015710157
IsCompAssign);
1015810158

10159+
// Any operation with MFloat8 type is only possible with C intrinsics
10160+
if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
10161+
(RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
10162+
return InvalidOperands(Loc, LHS, RHS);
10163+
1015910164
// AltiVec-style "vector bool op vector bool" combinations are allowed
1016010165
// for some operators but not others.
1016110166
if (!AllowBothBool && LHSVecType &&

clang/lib/Sema/SemaType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8180,7 +8180,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
81808180
BTy->getKind() == BuiltinType::ULongLong ||
81818181
BTy->getKind() == BuiltinType::Float ||
81828182
BTy->getKind() == BuiltinType::Half ||
8183-
BTy->getKind() == BuiltinType::BFloat16;
8183+
BTy->getKind() == BuiltinType::BFloat16 ||
8184+
BTy->getKind() == BuiltinType::MFloat8;
81848185
}
81858186

81868187
static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,

clang/test/CodeGen/arm-mfp8.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// CHECK-C-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
1616
// CHECK-C-NEXT: ret <16 x i8> [[TMP0]]
1717
//
18-
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
18+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
1919
// CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
2020
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
2121
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
3535
// CHECK-C-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
3636
// CHECK-C-NEXT: ret <8 x i8> [[TMP0]]
3737
//
38-
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
38+
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
3939
// CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
4040
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
4141
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8

clang/test/Sema/arm-mfp8.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,22 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
1111
a / c; // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
1212
}
1313

14-
1514
#include <arm_neon.h>
1615

1716
void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
18-
a + b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
19-
a - b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
20-
a * b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
21-
a / b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
22-
23-
a + c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
24-
a - c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
25-
a * c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
26-
a / c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
27-
c + b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
28-
c - b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
29-
c * b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
30-
c / b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
17+
a + b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
18+
a - b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
19+
a * b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
20+
a / b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
21+
22+
a + c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
23+
a - c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
24+
a * c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
25+
a / c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
26+
c + b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
27+
c - b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
28+
c * b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
29+
c / b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
3130
}
3231
__mfp8 test_static_cast_from_char(char in) {
3332
return static_cast<__mfp8>(in); // scalar-error {{static_cast from 'char' to '__mfp8' (aka '__MFloat8_t') is not allowed}}
@@ -60,4 +59,3 @@ void test(bool b) {
6059
u8 = mfp8; // scalar-error {{assigning to 'char' from incompatible type '__mfp8' (aka '__MFloat8_t')}}
6160
mfp8 + (b ? u8 : mfp8); // scalar-error {{incompatible operand types ('char' and '__mfp8' (aka '__MFloat8_t'))}}
6261
}
63-

clang/utils/TableGen/NeonEmitter.cpp

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ enum EltType {
101101
Float16,
102102
Float32,
103103
Float64,
104-
BFloat16
104+
BFloat16,
105+
MFloat8
105106
};
106107

107108
} // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
143144
private:
144145
TypeSpec TS;
145146

146-
enum TypeKind {
147-
Void,
148-
Float,
149-
SInt,
150-
UInt,
151-
Poly,
152-
BFloat16
153-
};
147+
enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
154148
TypeKind Kind;
155149
bool Immediate, Constant, Pointer;
156150
// ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
203197
bool isLong() const { return isInteger() && ElementBitwidth == 64; }
204198
bool isVoid() const { return Kind == Void; }
205199
bool isBFloat16() const { return Kind == BFloat16; }
200+
bool isMFloat8() const { return Kind == MFloat8; }
206201
unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
207202
unsigned getSizeInBits() const { return Bitwidth; }
208203
unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
657652
S += "float";
658653
else if (isBFloat16())
659654
S += "bfloat";
655+
else if (isMFloat8())
656+
S += "mfloat";
660657
else
661658
S += "int";
662659

@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
699696
else if (isBFloat16()) {
700697
assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
701698
S += "y";
699+
} else if (isMFloat8()) {
700+
assert(ElementBitwidth == 8 && "BFloat16 can only be 8 bits");
701+
S += "m";
702702
} else
703703
switch (ElementBitwidth) {
704704
case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
758758
Base = (unsigned)NeonTypeFlags::BFloat16;
759759
}
760760

761+
if (isMFloat8()) {
762+
Base = (unsigned)NeonTypeFlags::MFloat8;
763+
}
764+
761765
if (Bitwidth == 128)
762766
Base |= (unsigned)NeonTypeFlags::QuadFlag;
763767
if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
779783
T.Kind = Poly;
780784
} else if (Name.consume_front("bfloat")) {
781785
T.Kind = BFloat16;
786+
} else if (Name.consume_front("mfloat")) {
787+
T.Kind = MFloat8;
782788
} else {
783789
assert(Name.starts_with("int"));
784790
Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
879885
Kind = BFloat16;
880886
ElementBitwidth = 16;
881887
break;
888+
case 'm':
889+
Kind = MFloat8;
890+
ElementBitwidth = 8;
891+
break;
882892
default:
883893
llvm_unreachable("Unhandled type code!");
884894
}
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
9931003
if (T.isBFloat16())
9941004
return "bf16";
9951005

1006+
if (T.isMFloat8())
1007+
return "mfp8";
1008+
9961009
if (T.isPoly())
9971010
typeCode = 'p';
9981011
else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
10301043

10311044
Type RetT = getReturnType();
10321045
if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
1033-
!RetT.isFloating() && !RetT.isBFloat16())
1046+
!RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
10341047
RetT.makeInteger(RetT.getElementSizeInBits(), false);
10351048

10361049
// Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
22702283
for (auto &TS : TDTypeVec) {
22712284
bool IsA64 = false;
22722285
Type T(TS, ".");
2273-
if (T.isDouble())
2286+
if (T.isDouble() || T.isMFloat8())
22742287
IsA64 = true;
22752288

22762289
if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
23032316
for (auto &TS : TDTypeVec) {
23042317
bool IsA64 = false;
23052318
Type T(TS, ".");
2306-
if (T.isDouble())
2319+
if (T.isDouble() || T.isMFloat8())
23072320
IsA64 = true;
23082321

23092322
if (InIfdef && !IsA64) {
@@ -2589,8 +2602,7 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
25892602

25902603
OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
25912604
OS << "typedef __MFloat8_t __mfp8;\n";
2592-
OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
2593-
OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
2605+
OS << "typedef __mfp8 mfloat8_t;\n";
25942606
OS << "typedef double float64_t;\n";
25952607
OS << "#endif\n\n";
25962608

@@ -2648,7 +2660,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
26482660
26492661
)";
26502662

2651-
emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
2663+
emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
26522664

26532665
emitNeonTypeDefs("bQb", OS);
26542666
OS << "#endif // __ARM_NEON_TYPES_H\n";

0 commit comments

Comments
 (0)