Skip to content

Commit fc97e36

Browse files
[NFC][Clang][AArch64]Refactor implementation of Neon vectors MFloat8x8 and MFloat8x16
This patch removes the builtins for MFloat8x8 and Mfloat8x16 and build these types the same way the other neon vectors are build. It uses the scalar type(mfloat8).
1 parent 91aad9b commit fc97e36

File tree

14 files changed

+68
-41
lines changed

14 files changed

+68
-41
lines changed

clang/include/clang/AST/Type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,6 +2534,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25342534
bool isFloat32Type() const;
25352535
bool isDoubleType() const;
25362536
bool isBFloat16Type() const;
2537+
bool isMFloat8Type() const;
25372538
bool isFloat128Type() const;
25382539
bool isIbm128Type() const;
25392540
bool isRealType() const; // C99 6.2.5p17 (real floating + integer)
@@ -8542,6 +8543,10 @@ inline bool Type::isBFloat16Type() const {
85428543
return isSpecificBuiltinType(BuiltinType::BFloat16);
85438544
}
85448545

8546+
inline bool Type::isMFloat8Type() const {
8547+
return isSpecificBuiltinType(BuiltinType::MFloat8);
8548+
}
8549+
85458550
inline bool Type::isFloat128Type() const {
85468551
return isSpecificBuiltinType(BuiltinType::Float128);
85478552
}

clang/include/clang/Basic/AArch64SVEACLETypes.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", "svboolx4_t", SveBoolx4, SveBoolx4T
201201
SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
202202

203203
AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
204-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x8_t", "__MFloat8x8_t", MFloat8x8, MFloat8x8Ty, 8, 8, 1)
205-
AARCH64_VECTOR_TYPE_MFLOAT("__MFloat8x16_t", "__MFloat8x16_t", MFloat8x16, MFloat8x16Ty, 16, 8, 1)
206204

207205
#undef SVE_VECTOR_TYPE
208206
#undef SVE_VECTOR_TYPE_BFLOAT

clang/include/clang/Basic/TargetBuiltins.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ namespace clang {
200200
Float16,
201201
Float32,
202202
Float64,
203-
BFloat16
203+
BFloat16,
204+
MFloat8
204205
};
205206

206207
NeonTypeFlags(unsigned F) : Flags(F) {}
@@ -222,6 +223,7 @@ namespace clang {
222223
switch (getEltType()) {
223224
case Int8:
224225
case Poly8:
226+
case MFloat8:
225227
return 8;
226228
case Int16:
227229
case Float16:

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def OP_UNAVAILABLE : Operation {
218218
// h: half-float
219219
// d: double
220220
// b: bfloat16
221+
// m: mfloat8
221222
//
222223
// Typespec modifiers
223224
// ------------------

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3901,6 +3901,8 @@ static StringRef mangleAArch64VectorBase(const BuiltinType *EltType) {
39013901
return "Float64";
39023902
case BuiltinType::BFloat16:
39033903
return "Bfloat16";
3904+
case BuiltinType::MFloat8:
3905+
return "Mfloat8";
39043906
default:
39053907
llvm_unreachable("Unexpected vector element base type");
39063908
}

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6594,6 +6594,7 @@ static llvm::FixedVectorType *GetNeonType(CodeGenFunction *CGF,
65946594
switch (TypeFlags.getEltType()) {
65956595
case NeonTypeFlags::Int8:
65966596
case NeonTypeFlags::Poly8:
6597+
case NeonTypeFlags::MFloat8:
65976598
return llvm::FixedVectorType::get(CGF->Int8Ty, V1Ty ? 1 : (8 << IsQuad));
65986599
case NeonTypeFlags::Int16:
65996600
case NeonTypeFlags::Poly16:

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,11 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
647647
case Type::ExtVector:
648648
case Type::Vector: {
649649
const auto *VT = cast<VectorType>(Ty);
650+
if (VT->getElementType()->isMFloat8Type()) {
651+
ResultType = llvm::FixedVectorType::get(
652+
llvm::Type::getInt8Ty(getLLVMContext()), VT->getNumElements());
653+
break;
654+
}
650655
// An ext_vector_type of Bool is really a vector of bits.
651656
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
652657
? llvm::Type::getInt1Ty(getLLVMContext())

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,6 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
375375
NSRN = std::min(NSRN + 1, 8u);
376376
else {
377377
switch (BT->getKind()) {
378-
case BuiltinType::MFloat8x8:
379-
case BuiltinType::MFloat8x16:
380-
NSRN = std::min(NSRN + 1, 8u);
381-
break;
382378
case BuiltinType::SveBool:
383379
case BuiltinType::SveCount:
384380
NPRN = std::min(NPRN + 1, 4u);
@@ -620,8 +616,7 @@ bool AArch64ABIInfo::isHomogeneousAggregateBaseType(QualType Ty) const {
620616
// but with the difference that any floating-point type is allowed,
621617
// including __fp16.
622618
if (const BuiltinType *BT = Ty->getAs<BuiltinType>()) {
623-
if (BT->isFloatingPoint() || BT->getKind() == BuiltinType::MFloat8x16 ||
624-
BT->getKind() == BuiltinType::MFloat8x8)
619+
if (BT->isFloatingPoint())
625620
return true;
626621
} else if (const VectorType *VT = Ty->getAs<VectorType>()) {
627622
if (auto Kind = VT->getVectorKind();

clang/lib/Sema/SemaARM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
323323
switch (Flags.getEltType()) {
324324
case NeonTypeFlags::Int8:
325325
return Flags.isUnsigned() ? Context.UnsignedCharTy : Context.SignedCharTy;
326+
case NeonTypeFlags::MFloat8:
327+
return Context.MFloat8Ty;
326328
case NeonTypeFlags::Int16:
327329
return Flags.isUnsigned() ? Context.UnsignedShortTy : Context.ShortTy;
328330
case NeonTypeFlags::Int32:

clang/lib/Sema/SemaExpr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10199,6 +10199,11 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1019910199
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
1020010200
IsCompAssign);
1020110201

10202+
// Any operation with MFloat8 type is only possible with C intrinsics
10203+
if ((LHSVecType && LHSVecType->getElementType()->isMFloat8Type()) ||
10204+
(RHSVecType && RHSVecType->getElementType()->isMFloat8Type()))
10205+
return InvalidOperands(Loc, LHS, RHS);
10206+
1020210207
// AltiVec-style "vector bool op vector bool" combinations are allowed
1020310208
// for some operators but not others.
1020410209
if (!AllowBothBool && LHSVecType &&

clang/lib/Sema/SemaType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8240,7 +8240,8 @@ static bool isPermittedNeonBaseType(QualType &Ty, VectorKind VecKind, Sema &S) {
82408240
BTy->getKind() == BuiltinType::ULongLong ||
82418241
BTy->getKind() == BuiltinType::Float ||
82428242
BTy->getKind() == BuiltinType::Half ||
8243-
BTy->getKind() == BuiltinType::BFloat16;
8243+
BTy->getKind() == BuiltinType::BFloat16 ||
8244+
BTy->getKind() == BuiltinType::MFloat8;
82448245
}
82458246

82468247
static bool verifyValidIntegerConstantExpr(Sema &S, const ParsedAttr &Attr,

clang/test/CodeGen/arm-mfp8.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// CHECK-C-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr [[V_ADDR]], align 16
1616
// CHECK-C-NEXT: ret <16 x i8> [[TMP0]]
1717
//
18-
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_tu14__MFloat8x16_t(
18+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z21test_ret_mfloat8x16_t14__Mfloat8x16_t(
1919
// CHECK-CXX-SAME: <16 x i8> [[V:%.*]]) #[[ATTR0:[0-9]+]] {
2020
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
2121
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <16 x i8>, align 16
@@ -35,7 +35,7 @@ mfloat8x16_t test_ret_mfloat8x16_t(mfloat8x16_t v) {
3535
// CHECK-C-NEXT: [[TMP0:%.*]] = load <8 x i8>, ptr [[V_ADDR]], align 8
3636
// CHECK-C-NEXT: ret <8 x i8> [[TMP0]]
3737
//
38-
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_tu13__MFloat8x8_t(
38+
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z20test_ret_mfloat8x8_t13__Mfloat8x8_t(
3939
// CHECK-CXX-SAME: <8 x i8> [[V:%.*]]) #[[ATTR0]] {
4040
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
4141
// CHECK-CXX-NEXT: [[V_ADDR:%.*]] = alloca <8 x i8>, align 8

clang/test/Sema/arm-mfp8.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,20 @@ void test_vector_sve(svmfloat8_t a, svuint8_t c) {
4444
a / c; // sve-error {{cannot convert between vector type 'svuint8_t' (aka '__SVUint8_t') and vector type 'svmfloat8_t' (aka '__SVMfloat8_t') as implicit conversion would cause truncation}}
4545
}
4646

47-
4847
#include <arm_neon.h>
4948

5049
void test_vector(mfloat8x8_t a, mfloat8x16_t b, uint8x8_t c) {
51-
a + b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
52-
a - b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
53-
a * b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
54-
a / b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
55-
56-
a + c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
57-
a - c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
58-
a * c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
59-
a / c; // neon-error {{cannot convert between vector and non-scalar values ('mfloat8x8_t' (aka '__MFloat8x8_t') and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
60-
c + b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
61-
c - b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
62-
c * b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
63-
c / b; // neon-error {{cannot convert between vector and non-scalar values ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (aka '__MFloat8x16_t'))}}
50+
a + b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
51+
a - b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
52+
a * b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
53+
a / b; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
54+
55+
a + c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
56+
a - c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
57+
a * c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
58+
a / c; // neon-error {{invalid operands to binary expression ('mfloat8x8_t' (vector of 8 'mfloat8_t' values) and 'uint8x8_t' (vector of 8 'uint8_t' values))}}
59+
c + b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
60+
c - b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
61+
c * b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
62+
c / b; // neon-error {{invalid operands to binary expression ('uint8x8_t' (vector of 8 'uint8_t' values) and 'mfloat8x16_t' (vector of 16 'mfloat8_t' values))}}
6463
}

clang/utils/TableGen/NeonEmitter.cpp

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ enum EltType {
101101
Float16,
102102
Float32,
103103
Float64,
104-
BFloat16
104+
BFloat16,
105+
MFloat8
105106
};
106107

107108
} // end namespace NeonTypeFlags
@@ -143,14 +144,7 @@ class Type {
143144
private:
144145
TypeSpec TS;
145146

146-
enum TypeKind {
147-
Void,
148-
Float,
149-
SInt,
150-
UInt,
151-
Poly,
152-
BFloat16
153-
};
147+
enum TypeKind { Void, Float, SInt, UInt, Poly, BFloat16, MFloat8 };
154148
TypeKind Kind;
155149
bool Immediate, Constant, Pointer;
156150
// ScalarForMangling and NoManglingQ are really not suited to live here as
@@ -203,6 +197,7 @@ class Type {
203197
bool isLong() const { return isInteger() && ElementBitwidth == 64; }
204198
bool isVoid() const { return Kind == Void; }
205199
bool isBFloat16() const { return Kind == BFloat16; }
200+
bool isMFloat8() const { return Kind == MFloat8; }
206201
unsigned getNumElements() const { return Bitwidth / ElementBitwidth; }
207202
unsigned getSizeInBits() const { return Bitwidth; }
208203
unsigned getElementSizeInBits() const { return ElementBitwidth; }
@@ -657,6 +652,8 @@ std::string Type::str() const {
657652
S += "float";
658653
else if (isBFloat16())
659654
S += "bfloat";
655+
else if (isMFloat8())
656+
S += "mfloat";
660657
else
661658
S += "int";
662659

@@ -699,6 +696,9 @@ std::string Type::builtin_str() const {
699696
else if (isBFloat16()) {
700697
assert(ElementBitwidth == 16 && "BFloat16 can only be 16 bits");
701698
S += "y";
699+
} else if (isMFloat8()) {
700+
assert(ElementBitwidth == 8 && "MFloat8 can only be 8 bits");
701+
S += "m";
702702
} else
703703
switch (ElementBitwidth) {
704704
case 16: S += "h"; break;
@@ -758,6 +758,10 @@ unsigned Type::getNeonEnum() const {
758758
Base = (unsigned)NeonTypeFlags::BFloat16;
759759
}
760760

761+
if (isMFloat8()) {
762+
Base = (unsigned)NeonTypeFlags::MFloat8;
763+
}
764+
761765
if (Bitwidth == 128)
762766
Base |= (unsigned)NeonTypeFlags::QuadFlag;
763767
if (isInteger() && !isSigned())
@@ -779,6 +783,8 @@ Type Type::fromTypedefName(StringRef Name) {
779783
T.Kind = Poly;
780784
} else if (Name.consume_front("bfloat")) {
781785
T.Kind = BFloat16;
786+
} else if (Name.consume_front("mfloat")) {
787+
T.Kind = MFloat8;
782788
} else {
783789
assert(Name.starts_with("int"));
784790
Name = Name.drop_front(3);
@@ -879,6 +885,10 @@ void Type::applyTypespec(bool &Quad) {
879885
Kind = BFloat16;
880886
ElementBitwidth = 16;
881887
break;
888+
case 'm':
889+
Kind = MFloat8;
890+
ElementBitwidth = 8;
891+
break;
882892
default:
883893
llvm_unreachable("Unhandled type code!");
884894
}
@@ -993,6 +1003,9 @@ std::string Intrinsic::getInstTypeCode(Type T, ClassKind CK) const {
9931003
if (T.isBFloat16())
9941004
return "bf16";
9951005

1006+
if (T.isMFloat8())
1007+
return "mfp8";
1008+
9961009
if (T.isPoly())
9971010
typeCode = 'p';
9981011
else if (T.isInteger())
@@ -1030,7 +1043,7 @@ std::string Intrinsic::getBuiltinTypeStr() {
10301043

10311044
Type RetT = getReturnType();
10321045
if ((LocalCK == ClassI || LocalCK == ClassW) && RetT.isScalar() &&
1033-
!RetT.isFloating() && !RetT.isBFloat16())
1046+
!RetT.isFloating() && !RetT.isBFloat16() && !RetT.isMFloat8())
10341047
RetT.makeInteger(RetT.getElementSizeInBits(), false);
10351048

10361049
// Since the return value must be one type, return a vector type of the
@@ -2270,7 +2283,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
22702283
for (auto &TS : TDTypeVec) {
22712284
bool IsA64 = false;
22722285
Type T(TS, ".");
2273-
if (T.isDouble())
2286+
if (T.isDouble() || T.isMFloat8())
22742287
IsA64 = true;
22752288

22762289
if (InIfdef && !IsA64) {
@@ -2303,7 +2316,7 @@ static void emitNeonTypeDefs(const std::string& types, raw_ostream &OS) {
23032316
for (auto &TS : TDTypeVec) {
23042317
bool IsA64 = false;
23052318
Type T(TS, ".");
2306-
if (T.isDouble())
2319+
if (T.isDouble() || T.isMFloat8())
23072320
IsA64 = true;
23082321

23092322
if (InIfdef && !IsA64) {
@@ -2589,8 +2602,6 @@ void NeonEmitter::runVectorTypes(raw_ostream &OS) {
25892602

25902603
OS << "#if defined(__aarch64__) || defined(__arm64ec__)\n";
25912604
OS << "typedef __mfp8 mfloat8_t;\n";
2592-
OS << "typedef __MFloat8x8_t mfloat8x8_t;\n";
2593-
OS << "typedef __MFloat8x16_t mfloat8x16_t;\n";
25942605
OS << "typedef double float64_t;\n";
25952606
OS << "#endif\n\n";
25962607

@@ -2648,7 +2659,7 @@ __arm_set_fpm_lscale2(fpm_t __fpm, uint64_t __scale) {
26482659
26492660
)";
26502661

2651-
emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlhQhfQfdQd", OS);
2662+
emitNeonTypeDefs("cQcsQsiQilQlUcQUcUsQUsUiQUiUlQUlmQmhQhfQfdQd", OS);
26522663

26532664
emitNeonTypeDefs("bQb", OS);
26542665
OS << "#endif // __ARM_NEON_TYPES_H\n";

0 commit comments

Comments
 (0)