Skip to content

Commit 1cde192

Browse files
[AArch64] Implement NEON FP8 intrinsics for fused multiply-add (indexed)
This patch adds the following intrinsics: * Floating-point multiply-add long to half-precision (vector, by element) float16x8_t vmlalbq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlalbq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float16x8_t vmlaltq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) * Floating-point multiply-add long-long to single-precision (vector, by element) float32x4_t vmlallbbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallbtq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlalltbq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm) float32x4_t vmlallttq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
1 parent 660fdce commit 1cde192

File tree

9 files changed

+429
-24
lines changed

9 files changed

+429
-24
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,6 +2179,20 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8fma,neon" in {
21792179
def VMLALLBT_F32_F8 : VInst<"vmlallbtq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21802180
def VMLALLTB_F32_F8 : VInst<"vmlalltbq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
21812181
def VMLALLTT_F32_F8 : VInst<"vmlallttq_f32_mf8_fpm", "(>>F)(>>F)..V", "Qm">;
2182+
2183+
def VMLALB_F16_F8_LANE : VInst<"vmlalbq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2184+
def VMLALB_F16_F8_LANEQ : VInst<"vmlalbq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2185+
def VMLALT_F16_F8_LANE : VInst<"vmlaltq_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2186+
def VMLALT_F16_F8_LANEQ : VInst<"vmlaltq_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2187+
2188+
def VMLALLBB_F32_F8_LANE : VInst<"vmlallbbq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2189+
def VMLALLBB_F32_F8_LANEQ : VInst<"vmlallbbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2190+
def VMLALLBT_F32_F8_LANE : VInst<"vmlallbtq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2191+
def VMLALLBT_F32_F8_LANEQ : VInst<"vmlallbtq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2192+
def VMLALLTB_F32_F8_LANE : VInst<"vmlalltbq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2193+
def VMLALLTB_F32_F8_LANEQ : VInst<"vmlalltbq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
2194+
def VMLALLTT_F32_F8_LANE : VInst<"vmlallttq_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2195+
def VMLALLTT_F32_F8_LANEQ : VInst<"vmlallttq_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_15, 0>]>;
21822196
}
21832197

21842198
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6880,21 +6880,36 @@ Value *CodeGenFunction::EmitFP8NeonCall(unsigned IID,
68806880
}
68816881

68826882
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6883-
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6883+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
68846884
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
68856885

68866886
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
68876887
RetTy->getPrimitiveSizeInBits();
68886888
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
68896889
Ops[1]->getType()};
6890-
if (ExtendLane) {
6890+
if (ExtendLaneArg) {
68916891
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
68926892
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
68936893
Builder.getInt64(0));
68946894
}
68956895
return EmitFP8NeonCall(IID, Tys, Ops, E, name);
68966896
}
68976897

6898+
llvm::Value *CodeGenFunction::EmitFP8NeonFMLACall(
6899+
unsigned IID, bool ExtendLaneArg, llvm::Type *RetTy,
6900+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6901+
6902+
if (ExtendLaneArg) {
6903+
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6904+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6905+
Builder.getInt64(0));
6906+
}
6907+
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6908+
RetTy->getPrimitiveSizeInBits();
6909+
return EmitFP8NeonCall(IID, {llvm::FixedVectorType::get(RetTy, ElemCount)},
6910+
Ops, E, name);
6911+
}
6912+
68986913
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
68996914
bool neg) {
69006915
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12840,7 +12855,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1284012855

1284112856
unsigned Int;
1284212857
bool ExtractLow = false;
12843-
bool ExtendLane = false;
12858+
bool ExtendLaneArg = false;
1284412859
switch (BuiltinID) {
1284512860
default: return nullptr;
1284612861
case NEON::BI__builtin_neon_vbsl_v:
@@ -14115,24 +14130,24 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1411514130
Ops, E, "fdot2");
1411614131
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
1411714132
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14118-
ExtendLane = true;
14133+
ExtendLaneArg = true;
1411914134
LLVM_FALLTHROUGH;
1412014135
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
1412114136
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
1412214137
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14123-
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14138+
ExtendLaneArg, HalfTy, Ops, E, "fdot2_lane");
1412414139
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
1412514140
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
1412614141
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
1412714142
FloatTy, Ops, E, "fdot4");
1412814143
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
1412914144
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14130-
ExtendLane = true;
14145+
ExtendLaneArg = true;
1413114146
LLVM_FALLTHROUGH;
1413214147
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
1413314148
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
1413414149
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14135-
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
14150+
ExtendLaneArg, FloatTy, Ops, E, "fdot4_lane");
1413614151

1413714152
case NEON::BI__builtin_neon_vmlalbq_f16_mf8_fpm:
1413814153
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalb,
@@ -14158,7 +14173,42 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1415814173
return EmitFP8NeonCall(Intrinsic::aarch64_neon_fp8_fmlalltt,
1415914174
{llvm::FixedVectorType::get(FloatTy, 4)}, Ops, E,
1416014175
"vmlall");
14161-
14176+
case NEON::BI__builtin_neon_vmlalbq_lane_f16_mf8_fpm:
14177+
ExtendLaneArg = true;
14178+
LLVM_FALLTHROUGH;
14179+
case NEON::BI__builtin_neon_vmlalbq_laneq_f16_mf8_fpm:
14180+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalb_lane,
14181+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14182+
case NEON::BI__builtin_neon_vmlaltq_lane_f16_mf8_fpm:
14183+
ExtendLaneArg = true;
14184+
LLVM_FALLTHROUGH;
14185+
case NEON::BI__builtin_neon_vmlaltq_laneq_f16_mf8_fpm:
14186+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalt_lane,
14187+
ExtendLaneArg, HalfTy, Ops, E, "vmlal_lane");
14188+
case NEON::BI__builtin_neon_vmlallbbq_lane_f32_mf8_fpm:
14189+
ExtendLaneArg = true;
14190+
LLVM_FALLTHROUGH;
14191+
case NEON::BI__builtin_neon_vmlallbbq_laneq_f32_mf8_fpm:
14192+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbb_lane,
14193+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14194+
case NEON::BI__builtin_neon_vmlallbtq_lane_f32_mf8_fpm:
14195+
ExtendLaneArg = true;
14196+
LLVM_FALLTHROUGH;
14197+
case NEON::BI__builtin_neon_vmlallbtq_laneq_f32_mf8_fpm:
14198+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlallbt_lane,
14199+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14200+
case NEON::BI__builtin_neon_vmlalltbq_lane_f32_mf8_fpm:
14201+
ExtendLaneArg = true;
14202+
LLVM_FALLTHROUGH;
14203+
case NEON::BI__builtin_neon_vmlalltbq_laneq_f32_mf8_fpm:
14204+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltb_lane,
14205+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
14206+
case NEON::BI__builtin_neon_vmlallttq_lane_f32_mf8_fpm:
14207+
ExtendLaneArg = true;
14208+
LLVM_FALLTHROUGH;
14209+
case NEON::BI__builtin_neon_vmlallttq_laneq_f32_mf8_fpm:
14210+
return EmitFP8NeonFMLACall(Intrinsic::aarch64_neon_fp8_fmlalltt_lane,
14211+
ExtendLaneArg, FloatTy, Ops, E, "vmlall_lane");
1416214212
case NEON::BI__builtin_neon_vamin_f16:
1416314213
case NEON::BI__builtin_neon_vaminq_f16:
1416414214
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4655,7 +4655,11 @@ class CodeGenFunction : public CodeGenTypeCache {
46554655
llvm::Type *Ty1, bool Extract,
46564656
SmallVectorImpl<llvm::Value *> &Ops,
46574657
const CallExpr *E, const char *name);
4658-
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4658+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLaneArg,
4659+
llvm::Type *RetTy,
4660+
SmallVectorImpl<llvm::Value *> &Ops,
4661+
const CallExpr *E, const char *name);
4662+
llvm::Value *EmitFP8NeonFMLACall(unsigned IID, bool ExtendLaneArg,
46594663
llvm::Type *RetTy,
46604664
SmallVectorImpl<llvm::Value *> &Ops,
46614665
const CallExpr *E, const char *name);

0 commit comments

Comments
 (0)