Skip to content

Commit 1dae495

Browse files
committed
[Clang][AArch64] Add fp8 variants for untyped NEON intrinsics
This patch adds fp8 variants of existing intrinsics whose operation does not depend on the arguments being of a specific type.
1 parent 8aa3ab1 commit 1dae495

File tree

5 files changed

+1256
-5
lines changed

5 files changed

+1256
-5
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,17 +2090,17 @@ let ArchGuard = "defined(__aarch64__) || defined(__arm64ec__)", TargetGuard = "r
20902090

20912091
// Lookup table read with 2-bit/4-bit indices
20922092
let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
2093-
def VLUTI2_B : SInst<"vluti2_lane", "Q.(qU)I", "cUcPcQcQUcQPc",
2093+
def VLUTI2_B : SInst<"vluti2_lane", "Q.(qU)I", "cUcPcmQcQUcQPcQm",
20942094
[ImmCheck<2, ImmCheck0_1>]>;
2095-
def VLUTI2_B_Q : SInst<"vluti2_laneq", "Q.(QU)I", "cUcPcQcQUcQPc",
2095+
def VLUTI2_B_Q : SInst<"vluti2_laneq", "Q.(QU)I", "cUcPcmQcQUcQPcQm",
20962096
[ImmCheck<2, ImmCheck0_3>]>;
20972097
def VLUTI2_H : SInst<"vluti2_lane", "Q.(<qU)I", "sUsPshQsQUsQPsQh",
20982098
[ImmCheck<2, ImmCheck0_3>]>;
20992099
def VLUTI2_H_Q : SInst<"vluti2_laneq", "Q.(<QU)I", "sUsPshQsQUsQPsQh",
21002100
[ImmCheck<2, ImmCheck0_7>]>;
2101-
def VLUTI4_B : SInst<"vluti4_lane", "..(qU)I", "QcQUcQPc",
2101+
def VLUTI4_B : SInst<"vluti4_lane", "..(qU)I", "QcQUcQPcQm",
21022102
[ImmCheck<2, ImmCheck0_0>]>;
2103-
def VLUTI4_B_Q : SInst<"vluti4_laneq", "..UI", "QcQUcQPc",
2103+
def VLUTI4_B_Q : SInst<"vluti4_laneq", "..UI", "QcQUcQPcQm",
21042104
[ImmCheck<2, ImmCheck0_1>]>;
21052105
def VLUTI4_H_X2 : SInst<"vluti4_lane_x2", ".2(<qU)I", "QsQUsQPsQh",
21062106
[ImmCheck<3, ImmCheck0_1>]>;
@@ -2194,4 +2194,70 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
21942194
// fscale
21952195
def FSCALE_V128 : WInst<"vscale", "..(.S)", "QdQfQh">;
21962196
def FSCALE_V64 : WInst<"vscale", "(.q)(.q)(.qS)", "fh">;
2197+
}
2198+
2199+
// FP8 versions of untyped intrinsics
2200+
let ArchGuard = "defined(__aarch64__)" in {
2201+
def VGET_LANE_MF8 : IInst<"vget_lane", "1.I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
2202+
def SPLAT_MF8 : WInst<"splat_lane", ".(!q)I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
2203+
def SPLATQ_MF8 : WInst<"splat_laneq", ".(!Q)I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
2204+
def VSET_LANE_MF8 : IInst<"vset_lane", ".1.I", "mQm", [ImmCheck<2, ImmCheckLaneIndex, 1>]>;
2205+
def VCREATE_MF8 : NoTestOpInst<"vcreate", ".(IU>)", "m", OP_CAST> { let BigEndianSafe = 1; }
2206+
let InstName = "vmov" in {
2207+
def VDUP_N_MF8 : WOpInst<"vdup_n", ".1", "mQm", OP_DUP>;
2208+
def VMOV_N_MF8 : WOpInst<"vmov_n", ".1", "mQm", OP_DUP>;
2209+
}
2210+
let InstName = "" in
2211+
def VDUP_LANE_MF8: WOpInst<"vdup_lane", ".qI", "mQm", OP_DUP_LN>;
2212+
def VCOMBINE_MF8 : NoTestOpInst<"vcombine", "Q..", "m", OP_CONC>;
2213+
let InstName = "vmov" in {
2214+
def VGET_HIGH_MF8 : NoTestOpInst<"vget_high", ".Q", "m", OP_HI>;
2215+
def VGET_LOW_MF8 : NoTestOpInst<"vget_low", ".Q", "m", OP_LO>;
2216+
}
2217+
let InstName = "vtbl" in {
2218+
def VTBL1_MF8 : WInst<"vtbl1", "..p", "m">;
2219+
def VTBL2_MF8 : WInst<"vtbl2", ".2p", "m">;
2220+
def VTBL3_MF8 : WInst<"vtbl3", ".3p", "m">;
2221+
def VTBL4_MF8 : WInst<"vtbl4", ".4p", "m">;
2222+
}
2223+
let InstName = "vtbx" in {
2224+
def VTBX1_MF8 : WInst<"vtbx1", "...p", "m">;
2225+
def VTBX2_MF8 : WInst<"vtbx2", "..2p", "m">;
2226+
def VTBX3_MF8 : WInst<"vtbx3", "..3p", "m">;
2227+
def VTBX4_MF8 : WInst<"vtbx4", "..4p", "m">;
2228+
}
2229+
def VEXT_MF8 : WInst<"vext", "...I", "mQm", [ImmCheck<2, ImmCheckLaneIndex, 0>]>;
2230+
def VREV64_MF8 : WOpInst<"vrev64", "..", "mQm", OP_REV64>;
2231+
def VREV32_MF8 : WOpInst<"vrev32", "..", "mQm", OP_REV32>;
2232+
def VREV16_MF8 : WOpInst<"vrev16", "..", "mQm", OP_REV16>;
2233+
let isHiddenLInst = 1 in
2234+
def VBSL_MF8 : SInst<"vbsl", ".U..", "mQm">;
2235+
def VTRN_MF8 : WInst<"vtrn", "2..", "mQm">;
2236+
def VZIP_MF8 : WInst<"vzip", "2..", "mQm">;
2237+
def VUZP_MF8 : WInst<"vuzp", "2..", "mQm">;
2238+
def COPY_LANE_MF8 : IOpInst<"vcopy_lane", "..I.I", "m", OP_COPY_LN>;
2239+
def COPYQ_LANE_MF8 : IOpInst<"vcopy_lane", "..IqI", "Qm", OP_COPY_LN>;
2240+
def COPY_LANEQ_MF8 : IOpInst<"vcopy_laneq", "..IQI", "m", OP_COPY_LN>;
2241+
def COPYQ_LANEQ_MF8 : IOpInst<"vcopy_laneq", "..I.I", "Qm", OP_COPY_LN>;
2242+
def VDUP_LANE2_MF8 : WOpInst<"vdup_laneq", ".QI", "mQm", OP_DUP_LN>;
2243+
def VTRN1_MF8 : SOpInst<"vtrn1", "...", "mQm", OP_TRN1>;
2244+
def VZIP1_MF8 : SOpInst<"vzip1", "...", "mQm", OP_ZIP1>;
2245+
def VUZP1_MF8 : SOpInst<"vuzp1", "...", "mQm", OP_UZP1>;
2246+
def VTRN2_MF8 : SOpInst<"vtrn2", "...", "mQm", OP_TRN2>;
2247+
def VZIP2_MF8 : SOpInst<"vzip2", "...", "mQm", OP_ZIP2>;
2248+
def VUZP2_MF8 : SOpInst<"vuzp2", "...", "mQm", OP_UZP2>;
2249+
let InstName = "vtbl" in {
2250+
def VQTBL1_A64_MF8 : WInst<"vqtbl1", ".QU", "mQm">;
2251+
def VQTBL2_A64_MF8 : WInst<"vqtbl2", ".(2Q)U", "mQm">;
2252+
def VQTBL3_A64_MF8 : WInst<"vqtbl3", ".(3Q)U", "mQm">;
2253+
def VQTBL4_A64_MF8 : WInst<"vqtbl4", ".(4Q)U", "mQm">;
2254+
}
2255+
let InstName = "vtbx" in {
2256+
def VQTBX1_A64_MF8 : WInst<"vqtbx1", "..QU", "mQm">;
2257+
def VQTBX2_A64_MF8 : WInst<"vqtbx2", "..(2Q)U", "mQm">;
2258+
def VQTBX3_A64_MF8 : WInst<"vqtbx3", "..(3Q)U", "mQm">;
2259+
def VQTBX4_A64_MF8 : WInst<"vqtbx4", "..(4Q)U", "mQm">;
2260+
}
2261+
def SCALAR_VDUP_LANE_MF8 : IInst<"vdup_lane", "1.I", "Sm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
2262+
def SCALAR_VDUP_LANEQ_MF8 : IInst<"vdup_laneq", "1QI", "Sm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
21972263
}

clang/lib/AST/Type.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2777,6 +2777,11 @@ static bool isTriviallyCopyableTypeImpl(const QualType &type,
27772777
if (CanonicalType->isScalarType() || CanonicalType->isVectorType())
27782778
return true;
27792779

2780+
// Mfloat8 type is a special case as it is not scalar, but is still trivially
2781+
// copyable.
2782+
if (CanonicalType->isMFloat8Type())
2783+
return true;
2784+
27802785
if (const auto *RT = CanonicalType->getAs<RecordType>()) {
27812786
if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getDecl())) {
27822787
if (IsCopyConstructible) {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9082,22 +9082,26 @@ static bool HasExtraNeonArgument(unsigned BuiltinID) {
90829082
case NEON::BI__builtin_neon_vget_lane_bf16:
90839083
case NEON::BI__builtin_neon_vget_lane_i32:
90849084
case NEON::BI__builtin_neon_vget_lane_i64:
9085+
case NEON::BI__builtin_neon_vget_lane_mf8:
90859086
case NEON::BI__builtin_neon_vget_lane_f32:
90869087
case NEON::BI__builtin_neon_vgetq_lane_i8:
90879088
case NEON::BI__builtin_neon_vgetq_lane_i16:
90889089
case NEON::BI__builtin_neon_vgetq_lane_bf16:
90899090
case NEON::BI__builtin_neon_vgetq_lane_i32:
90909091
case NEON::BI__builtin_neon_vgetq_lane_i64:
9092+
case NEON::BI__builtin_neon_vgetq_lane_mf8:
90919093
case NEON::BI__builtin_neon_vgetq_lane_f32:
90929094
case NEON::BI__builtin_neon_vduph_lane_bf16:
90939095
case NEON::BI__builtin_neon_vduph_laneq_bf16:
90949096
case NEON::BI__builtin_neon_vset_lane_i8:
9097+
case NEON::BI__builtin_neon_vset_lane_mf8:
90959098
case NEON::BI__builtin_neon_vset_lane_i16:
90969099
case NEON::BI__builtin_neon_vset_lane_bf16:
90979100
case NEON::BI__builtin_neon_vset_lane_i32:
90989101
case NEON::BI__builtin_neon_vset_lane_i64:
90999102
case NEON::BI__builtin_neon_vset_lane_f32:
91009103
case NEON::BI__builtin_neon_vsetq_lane_i8:
9104+
case NEON::BI__builtin_neon_vsetq_lane_mf8:
91019105
case NEON::BI__builtin_neon_vsetq_lane_i16:
91029106
case NEON::BI__builtin_neon_vsetq_lane_bf16:
91039107
case NEON::BI__builtin_neon_vsetq_lane_i32:
@@ -12600,6 +12604,11 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1260012604
case NEON::BI__builtin_neon_vsetq_lane_f32:
1260112605
Ops.push_back(EmitScalarExpr(E->getArg(2)));
1260212606
return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
12607+
case NEON::BI__builtin_neon_vset_lane_mf8:
12608+
case NEON::BI__builtin_neon_vsetq_lane_mf8:
12609+
Ops.push_back(EmitScalarExpr(E->getArg(2)));
12610+
Ops[0] = Builder.CreateExtractElement(Ops[0], Builder.getInt64(0));
12611+
return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
1260312612
case NEON::BI__builtin_neon_vset_lane_f64:
1260412613
// The vector type needs a cast for the v1f64 variant.
1260512614
Ops[1] =
@@ -12625,6 +12634,12 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1262512634
Builder.CreateBitCast(Ops[0], llvm::FixedVectorType::get(Int8Ty, 16));
1262612635
return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
1262712636
"vgetq_lane");
12637+
case NEON::BI__builtin_neon_vget_lane_mf8:
12638+
case NEON::BI__builtin_neon_vdupb_lane_mf8:
12639+
case NEON::BI__builtin_neon_vgetq_lane_mf8:
12640+
case NEON::BI__builtin_neon_vdupb_laneq_mf8:
12641+
return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
12642+
"vget_lane");
1262812643
case NEON::BI__builtin_neon_vget_lane_i16:
1262912644
case NEON::BI__builtin_neon_vduph_lane_i16:
1263012645
Ops[0] =
@@ -14073,7 +14088,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1407314088
Int = Intrinsic::aarch64_neon_suqadd;
1407414089
return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vuqadd");
1407514090
}
14076-
14091+
case NEON::BI__builtin_neon_vluti2_laneq_mf8:
1407714092
case NEON::BI__builtin_neon_vluti2_laneq_bf16:
1407814093
case NEON::BI__builtin_neon_vluti2_laneq_f16:
1407914094
case NEON::BI__builtin_neon_vluti2_laneq_p16:
@@ -14089,6 +14104,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1408914104
/*isQuad*/ false));
1409014105
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_laneq");
1409114106
}
14107+
case NEON::BI__builtin_neon_vluti2q_laneq_mf8:
1409214108
case NEON::BI__builtin_neon_vluti2q_laneq_bf16:
1409314109
case NEON::BI__builtin_neon_vluti2q_laneq_f16:
1409414110
case NEON::BI__builtin_neon_vluti2q_laneq_p16:
@@ -14104,6 +14120,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1410414120
/*isQuad*/ true));
1410514121
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_laneq");
1410614122
}
14123+
case NEON::BI__builtin_neon_vluti2_lane_mf8:
1410714124
case NEON::BI__builtin_neon_vluti2_lane_bf16:
1410814125
case NEON::BI__builtin_neon_vluti2_lane_f16:
1410914126
case NEON::BI__builtin_neon_vluti2_lane_p16:
@@ -14119,6 +14136,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1411914136
/*isQuad*/ false));
1412014137
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_lane");
1412114138
}
14139+
case NEON::BI__builtin_neon_vluti2q_lane_mf8:
1412214140
case NEON::BI__builtin_neon_vluti2q_lane_bf16:
1412314141
case NEON::BI__builtin_neon_vluti2q_lane_f16:
1412414142
case NEON::BI__builtin_neon_vluti2q_lane_p16:
@@ -14134,12 +14152,14 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1413414152
/*isQuad*/ true));
1413514153
return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_lane");
1413614154
}
14155+
case NEON::BI__builtin_neon_vluti4q_lane_mf8:
1413714156
case NEON::BI__builtin_neon_vluti4q_lane_p8:
1413814157
case NEON::BI__builtin_neon_vluti4q_lane_s8:
1413914158
case NEON::BI__builtin_neon_vluti4q_lane_u8: {
1414014159
Int = Intrinsic::aarch64_neon_vluti4q_lane;
1414114160
return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vluti4q_lane");
1414214161
}
14162+
case NEON::BI__builtin_neon_vluti4q_laneq_mf8:
1414314163
case NEON::BI__builtin_neon_vluti4q_laneq_p8:
1414414164
case NEON::BI__builtin_neon_vluti4q_laneq_s8:
1414514165
case NEON::BI__builtin_neon_vluti4q_laneq_u8: {

clang/lib/Sema/SemaInit.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,6 +1940,8 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
19401940
typeCode = "s";
19411941
else if (elementType->isUnsignedIntegerType())
19421942
typeCode = "u";
1943+
else if (elementType->isMFloat8Type())
1944+
typeCode = "mf";
19431945
else
19441946
llvm_unreachable("Invalid element type!");
19451947

0 commit comments

Comments
 (0)