
Commit c331c4c

[Clang][AArch64] Add fp8 variants for untyped NEON intrinsics
This patch adds fp8 variants of existing intrinsics whose operation does not depend on the arguments being of a specific type.
1 parent: 290d7b8
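For illustration, a rough sketch of how the new variants are meant to be used from C/C++ (assuming <arm_neon.h> from a compiler carrying this patch and the FP8 vector types mfloat8_t/mfloat8x8_t/mfloat8x16_t; the helper functions are invented for the example and are not part of the patch):

#include <arm_neon.h>

// Read lane 3 as a scalar and splat it across a 128-bit fp8 vector.
mfloat8x16_t broadcast_lane3(mfloat8x8_t v) {
  mfloat8_t lane = vget_lane_mf8(v, 3);   // new fp8 lane read
  return vmovq_n_mf8(lane);               // new fp8 duplicate-to-all-lanes
}

// Swap the 64-bit halves of a 128-bit fp8 vector.
mfloat8x16_t swap_halves(mfloat8x16_t v) {
  return vcombine_mf8(vget_high_mf8(v), vget_low_mf8(v));
}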

6 files changed: +1220, −4 lines

clang/include/clang/Basic/arm_neon.td

Lines changed: 70 additions & 4 deletions
@@ -2090,17 +2090,17 @@ let ArchGuard = "defined(__aarch64__) || defined(__arm64ec__)", TargetGuard = "r
 
 // Lookup table read with 2-bit/4-bit indices
 let ArchGuard = "defined(__aarch64__)", TargetGuard = "lut" in {
-def VLUTI2_B : SInst<"vluti2_lane", "Q.(qU)I", "cUcPcQcQUcQPc",
+def VLUTI2_B : SInst<"vluti2_lane", "Q.(qU)I", "cUcPcmQcQUcQPcQm",
               [ImmCheck<2, ImmCheck0_1>]>;
-def VLUTI2_B_Q : SInst<"vluti2_laneq", "Q.(QU)I", "cUcPcQcQUcQPc",
+def VLUTI2_B_Q : SInst<"vluti2_laneq", "Q.(QU)I", "cUcPcmQcQUcQPcQm",
               [ImmCheck<2, ImmCheck0_3>]>;
 def VLUTI2_H : SInst<"vluti2_lane", "Q.(<qU)I", "sUsPshQsQUsQPsQh",
               [ImmCheck<2, ImmCheck0_3>]>;
 def VLUTI2_H_Q : SInst<"vluti2_laneq", "Q.(<QU)I", "sUsPshQsQUsQPsQh",
               [ImmCheck<2, ImmCheck0_7>]>;
-def VLUTI4_B : SInst<"vluti4_lane", "..(qU)I", "QcQUcQPc",
+def VLUTI4_B : SInst<"vluti4_lane", "..(qU)I", "QcQUcQPcQm",
               [ImmCheck<2, ImmCheck0_0>]>;
-def VLUTI4_B_Q : SInst<"vluti4_laneq", "..UI", "QcQUcQPc",
+def VLUTI4_B_Q : SInst<"vluti4_laneq", "..UI", "QcQUcQPcQm",
               [ImmCheck<2, ImmCheck0_1>]>;
 def VLUTI4_H_X2 : SInst<"vluti4_lane_x2", ".2(<qU)I", "QsQUsQPsQh",
               [ImmCheck<3, ImmCheck0_1>]>;
@@ -2194,4 +2194,70 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
 // fscale
 def FSCALE_V128 : WInst<"vscale", "..(.S)", "QdQfQh">;
 def FSCALE_V64 : WInst<"vscale", "(.q)(.q)(.qS)", "fh">;
+}
+
+//FP8 versions of untyped intrinsics
+let ArchGuard = "defined(__aarch64__)" in {
+  def VGET_LANE_MF8 : IInst<"vget_lane", "1.I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
+  def SPLAT_MF8 : WInst<"splat_lane", ".(!q)I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
+  def SPLATQ_MF8 : WInst<"splat_laneq", ".(!Q)I", "mQm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
+  def VSET_LANE_MF8 : IInst<"vset_lane", ".1.I", "mQm", [ImmCheck<2, ImmCheckLaneIndex, 1>]>;
+  def VCREATE_MF8 : NoTestOpInst<"vcreate", ".(IU>)", "m", OP_CAST> { let BigEndianSafe = 1; }
+  let InstName = "vmov" in {
+    def VDUP_N_MF8 : WOpInst<"vdup_n", ".1", "mQm", OP_DUP>;
+    def VMOV_N_MF8 : WOpInst<"vmov_n", ".1", "mQm", OP_DUP>;
+  }
+  let InstName = "" in
+    def VDUP_LANE_MF8: WOpInst<"vdup_lane", ".qI", "mQm", OP_DUP_LN>;
+  def VCOMBINE_MF8 : NoTestOpInst<"vcombine", "Q..", "m", OP_CONC>;
+  let InstName = "vmov" in {
+    def VGET_HIGH_MF8 : NoTestOpInst<"vget_high", ".Q", "m", OP_HI>;
+    def VGET_LOW_MF8 : NoTestOpInst<"vget_low", ".Q", "m", OP_LO>;
+  }
+  let InstName = "vtbl" in {
+    def VTBL1_MF8 : WInst<"vtbl1", "..p", "m">;
+    def VTBL2_MF8 : WInst<"vtbl2", ".2p", "m">;
+    def VTBL3_MF8 : WInst<"vtbl3", ".3p", "m">;
+    def VTBL4_MF8 : WInst<"vtbl4", ".4p", "m">;
+  }
+  let InstName = "vtbx" in {
+    def VTBX1_MF8 : WInst<"vtbx1", "...p", "m">;
+    def VTBX2_MF8 : WInst<"vtbx2", "..2p", "m">;
+    def VTBX3_MF8 : WInst<"vtbx3", "..3p", "m">;
+    def VTBX4_MF8 : WInst<"vtbx4", "..4p", "m">;
+  }
+  def VEXT_MF8 : WInst<"vext", "...I", "mQm", [ImmCheck<2, ImmCheckLaneIndex, 0>]>;
+  def VREV64_MF8 : WOpInst<"vrev64", "..", "mQm", OP_REV64>;
+  def VREV32_MF8 : WOpInst<"vrev32", "..", "mQm", OP_REV32>;
+  def VREV16_MF8 : WOpInst<"vrev16", "..", "mQm", OP_REV16>;
+  let isHiddenLInst = 1 in
+    def VBSL_MF8 : SInst<"vbsl", ".U..", "mQm">;
+  def VTRN_MF8 : WInst<"vtrn", "2..", "mQm">;
+  def VZIP_MF8 : WInst<"vzip", "2..", "mQm">;
+  def VUZP_MF8 : WInst<"vuzp", "2..", "mQm">;
+  def COPY_LANE_MF8 : IOpInst<"vcopy_lane", "..I.I", "m", OP_COPY_LN>;
+  def COPYQ_LANE_MF8 : IOpInst<"vcopy_lane", "..IqI", "Qm", OP_COPY_LN>;
+  def COPY_LANEQ_MF8 : IOpInst<"vcopy_laneq", "..IQI", "m", OP_COPY_LN>;
+  def COPYQ_LANEQ_MF8 : IOpInst<"vcopy_laneq", "..I.I", "Qm", OP_COPY_LN>;
+  def VDUP_LANE2_MF8 : WOpInst<"vdup_laneq", ".QI", "mQm", OP_DUP_LN>;
+  def VTRN1_MF8 : SOpInst<"vtrn1", "...", "mQm", OP_TRN1>;
+  def VZIP1_MF8 : SOpInst<"vzip1", "...", "mQm", OP_ZIP1>;
+  def VUZP1_MF8 : SOpInst<"vuzp1", "...", "mQm", OP_UZP1>;
+  def VTRN2_MF8 : SOpInst<"vtrn2", "...", "mQm", OP_TRN2>;
+  def VZIP2_MF8 : SOpInst<"vzip2", "...", "mQm", OP_ZIP2>;
+  def VUZP2_MF8 : SOpInst<"vuzp2", "...", "mQm", OP_UZP2>;
+  let InstName = "vtbl" in {
+    def VQTBL1_A64_MF8 : WInst<"vqtbl1", ".QU", "mQm">;
+    def VQTBL2_A64_MF8 : WInst<"vqtbl2", ".(2Q)U", "mQm">;
+    def VQTBL3_A64_MF8 : WInst<"vqtbl3", ".(3Q)U", "mQm">;
+    def VQTBL4_A64_MF8 : WInst<"vqtbl4", ".(4Q)U", "mQm">;
+  }
+  let InstName = "vtbx" in {
+    def VQTBX1_A64_MF8 : WInst<"vqtbx1", "..QU", "mQm">;
+    def VQTBX2_A64_MF8 : WInst<"vqtbx2", "..(2Q)U", "mQm">;
+    def VQTBX3_A64_MF8 : WInst<"vqtbx3", "..(3Q)U", "mQm">;
+    def VQTBX4_A64_MF8 : WInst<"vqtbx4", "..(4Q)U", "mQm">;
+  }
+  def SCALAR_VDUP_LANE_MF8 : IInst<"vdup_lane", "1.I", "Sm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
+  def SCALAR_VDUP_LANEQ_MF8 : IInst<"vdup_laneq", "1QI", "Sm", [ImmCheck<1, ImmCheckLaneIndex, 0>]>;
 }
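For readers decoding the prototype strings: the type code m selects the mfloat8 element type, and the type strings "m"/"Qm" request the 64-bit and 128-bit vector forms. A hedged sketch of the declarations a few of the entries above should produce in the generated arm_neon.h (my own decoding, not copied from the patch; exact spellings and attributes may differ):

// VQTBL1_A64_MF8, prototype ".QU", type "m": 128-bit table, unsigned byte indices.
mfloat8x8_t  vqtbl1_mf8(mfloat8x16_t table, uint8x8_t idx);

// VEXT_MF8, prototype "...I", types "m" and "Qm".
mfloat8x8_t  vext_mf8(mfloat8x8_t a, mfloat8x8_t b, const int n);
mfloat8x16_t vextq_mf8(mfloat8x16_t a, mfloat8x16_t b, const int n);

// VSET_LANE_MF8, prototype ".1.I": insert a scalar mfloat8_t into a lane.
mfloat8x8_t  vset_lane_mf8(mfloat8_t value, mfloat8x8_t vec, const int lane);
mfloat8x16_t vsetq_lane_mf8(mfloat8_t value, mfloat8x16_t vec, const int lane);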

clang/lib/AST/Type.cpp

Lines changed: 5 additions & 0 deletions
@@ -2782,6 +2782,11 @@ static bool isTriviallyCopyableTypeImpl(const QualType &type,
   if (CanonicalType->isScalarType() || CanonicalType->isVectorType())
     return true;
 
+  // Mfloat8 type is a special case as it is not scalar, but is still
+  // trivially copyable.
+  if (CanonicalType->isMFloat8Type())
+    return true;
+
   if (const auto *RT = CanonicalType->getAs<RecordType>()) {
     if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getDecl())) {
       if (IsCopyConstructible) {
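In practice this lets the scalar FP8 type pass the trivially-copyable query even though it is not classified as a scalar type. A minimal sketch of what that allows, assuming the __mfp8 builtin spelling (which is not shown in this diff):

// Sketch: with this change, trait queries and by-value copies of the builtin
// AArch64 FP8 scalar type are accepted as trivially copyable.
static_assert(__is_trivially_copyable(__mfp8),
              "mfloat8 is not scalar, but must still be trivially copyable");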

clang/lib/CodeGen/CGCall.cpp

Lines changed: 9 additions & 0 deletions
@@ -5464,6 +5464,15 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
         Builder.CreateStore(errorValue, swiftErrorTemp);
       }
 
+      // Mfloat8 type is loaded as a scalar type, but is treated as a
+      // single-element vector type for other operations. We need to bitcast
+      // it to the vector type here.
+      if (auto *EltTy =
+              dyn_cast<llvm::FixedVectorType>(ArgInfo.getCoerceToType());
+          EltTy && EltTy->getNumElements() == 1 &&
+          V->getType() == EltTy->getScalarType())
+        V = Builder.CreateBitCast(V, EltTy);
+
       // We might have to widen integers, but we should never truncate.
       if (ArgInfo.getCoerceToType() != V->getType() &&
           V->getType()->isIntegerTy())
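As a rough illustration of when this fires (the callee below is hypothetical, and the coercion to a one-element vector is inferred from the check above rather than shown in this diff): a scalar mfloat8_t value, such as the result of a lane read, is passed to a call whose coerced argument type is a one-element vector, so the scalar has to be bitcast before the call is emitted.

#include <arm_neon.h>

void consume(mfloat8_t b);               // hypothetical callee taking FP8 by value

void forward(mfloat8x8_t v) {
  mfloat8_t lane = vget_lane_mf8(v, 0);  // lane read produces a scalar value
  consume(lane);                         // argument coerced to a one-element vector
}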

clang/lib/CodeGen/TargetBuiltins/ARM.cpp

Lines changed: 20 additions & 0 deletions
@@ -2623,22 +2623,26 @@ static bool HasExtraNeonArgument(unsigned BuiltinID) {
   case NEON::BI__builtin_neon_vget_lane_bf16:
   case NEON::BI__builtin_neon_vget_lane_i32:
   case NEON::BI__builtin_neon_vget_lane_i64:
+  case NEON::BI__builtin_neon_vget_lane_mf8:
   case NEON::BI__builtin_neon_vget_lane_f32:
   case NEON::BI__builtin_neon_vgetq_lane_i8:
   case NEON::BI__builtin_neon_vgetq_lane_i16:
   case NEON::BI__builtin_neon_vgetq_lane_bf16:
   case NEON::BI__builtin_neon_vgetq_lane_i32:
   case NEON::BI__builtin_neon_vgetq_lane_i64:
+  case NEON::BI__builtin_neon_vgetq_lane_mf8:
   case NEON::BI__builtin_neon_vgetq_lane_f32:
   case NEON::BI__builtin_neon_vduph_lane_bf16:
   case NEON::BI__builtin_neon_vduph_laneq_bf16:
   case NEON::BI__builtin_neon_vset_lane_i8:
+  case NEON::BI__builtin_neon_vset_lane_mf8:
   case NEON::BI__builtin_neon_vset_lane_i16:
   case NEON::BI__builtin_neon_vset_lane_bf16:
   case NEON::BI__builtin_neon_vset_lane_i32:
   case NEON::BI__builtin_neon_vset_lane_i64:
   case NEON::BI__builtin_neon_vset_lane_f32:
   case NEON::BI__builtin_neon_vsetq_lane_i8:
+  case NEON::BI__builtin_neon_vsetq_lane_mf8:
   case NEON::BI__builtin_neon_vsetq_lane_i16:
   case NEON::BI__builtin_neon_vsetq_lane_bf16:
   case NEON::BI__builtin_neon_vsetq_lane_i32:
@@ -6161,6 +6165,10 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
         Builder.CreateBitCast(Ops[1], llvm::FixedVectorType::get(DoubleTy, 1));
     Ops.push_back(EmitScalarExpr(E->getArg(2)));
     return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
+  case NEON::BI__builtin_neon_vset_lane_mf8:
+  case NEON::BI__builtin_neon_vsetq_lane_mf8:
+    Ops.push_back(EmitScalarExpr(E->getArg(2)));
+    return Builder.CreateInsertElement(Ops[1], Ops[0], Ops[2], "vset_lane");
   case NEON::BI__builtin_neon_vsetq_lane_f64:
     // The vector type needs a cast for the v2f64 variant.
     Ops[1] =
@@ -6180,6 +6188,12 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
         Builder.CreateBitCast(Ops[0], llvm::FixedVectorType::get(Int8Ty, 16));
     return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
                                         "vgetq_lane");
+  case NEON::BI__builtin_neon_vget_lane_mf8:
+  case NEON::BI__builtin_neon_vdupb_lane_mf8:
+  case NEON::BI__builtin_neon_vgetq_lane_mf8:
+  case NEON::BI__builtin_neon_vdupb_laneq_mf8:
+    return Builder.CreateExtractElement(Ops[0], EmitScalarExpr(E->getArg(1)),
+                                        "vget_lane");
   case NEON::BI__builtin_neon_vget_lane_i16:
   case NEON::BI__builtin_neon_vduph_lane_i16:
     Ops[0] =
@@ -7629,6 +7643,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
     return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vuqadd");
   }
 
+  case NEON::BI__builtin_neon_vluti2_laneq_mf8:
   case NEON::BI__builtin_neon_vluti2_laneq_bf16:
   case NEON::BI__builtin_neon_vluti2_laneq_f16:
   case NEON::BI__builtin_neon_vluti2_laneq_p16:
@@ -7644,6 +7659,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
                                  /*isQuad*/ false));
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_laneq");
   }
+  case NEON::BI__builtin_neon_vluti2q_laneq_mf8:
   case NEON::BI__builtin_neon_vluti2q_laneq_bf16:
   case NEON::BI__builtin_neon_vluti2q_laneq_f16:
   case NEON::BI__builtin_neon_vluti2q_laneq_p16:
@@ -7659,6 +7675,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
                                  /*isQuad*/ true));
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_laneq");
   }
+  case NEON::BI__builtin_neon_vluti2_lane_mf8:
   case NEON::BI__builtin_neon_vluti2_lane_bf16:
   case NEON::BI__builtin_neon_vluti2_lane_f16:
   case NEON::BI__builtin_neon_vluti2_lane_p16:
@@ -7674,6 +7691,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
                                  /*isQuad*/ false));
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_lane");
   }
+  case NEON::BI__builtin_neon_vluti2q_lane_mf8:
   case NEON::BI__builtin_neon_vluti2q_lane_bf16:
   case NEON::BI__builtin_neon_vluti2q_lane_f16:
   case NEON::BI__builtin_neon_vluti2q_lane_p16:
@@ -7689,12 +7707,14 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
                                  /*isQuad*/ true));
     return EmitNeonCall(CGM.getIntrinsic(Int, Tys), Ops, "vluti2_lane");
   }
+  case NEON::BI__builtin_neon_vluti4q_lane_mf8:
   case NEON::BI__builtin_neon_vluti4q_lane_p8:
   case NEON::BI__builtin_neon_vluti4q_lane_s8:
   case NEON::BI__builtin_neon_vluti4q_lane_u8: {
     Int = Intrinsic::aarch64_neon_vluti4q_lane;
     return EmitNeonCall(CGM.getIntrinsic(Int, Ty), Ops, "vluti4q_lane");
   }
+  case NEON::BI__builtin_neon_vluti4q_laneq_mf8:
   case NEON::BI__builtin_neon_vluti4q_laneq_p8:
   case NEON::BI__builtin_neon_vluti4q_laneq_s8:
   case NEON::BI__builtin_neon_vluti4q_laneq_u8: {
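With the mf8 cases folded into the existing handling, the fp8 lookup-table builtins lower through the same aarch64_neon_vluti* intrinsics as the other element types. A rough usage sketch (assuming a target with the lut feature and FP8 vector types available; the immediates are chosen within the ImmCheck ranges declared in arm_neon.td):

#include <arm_neon.h>

// 2-bit-index lookup into a 64-bit fp8 table; segment 0 of the index vector.
mfloat8x16_t lut2(mfloat8x8_t table, uint8x8_t idx) {
  return vluti2_lane_mf8(table, idx, 0);
}

// 4-bit-index lookup into a 128-bit fp8 table.
mfloat8x16_t lut4(mfloat8x16_t table, uint8x16_t idx) {
  return vluti4q_laneq_mf8(table, idx, 0);
}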

clang/lib/Sema/SemaInit.cpp

Lines changed: 2 additions & 0 deletions
@@ -1944,6 +1944,8 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
       typeCode = "s";
     else if (elementType->isUnsignedIntegerType())
       typeCode = "u";
+    else if (elementType->isMFloat8Type())
+      typeCode = "mf";
     else
       llvm_unreachable("Invalid element type!");
 
