Skip to content

Commit 2eed88d

Browse files
[AArch64] Implement FP8 SVE intrinsics for fused multiply-add (#118126)
This patch adds the following intrinsics: * 8-bit floating-point multiply-add long to half-precision (bottom). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat16_t svmlalb[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat16_t svmlalb[_n_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long to half-precision (bottom, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat16_t svmlalb_lane[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm); * 8-bit floating-point multiply-add long to half-precision (top). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat16_t svmlalt[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat16_t svmlalt[_n_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long to half-precision (top, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat16_t svmlalt_lane[_f16_mf8]_fpm(svfloat16_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (bottom bottom). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlallbb[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat32_t svmlallbb[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (bottom bottom, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlallbb_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (bottom top). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlallbt[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat32_t svmlallbt[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (bottom top, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlallbt_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (top bottom). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlalltb[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat32_t svmlalltb[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (top bottom, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlalltb_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (top top). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlalltt[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm); svfloat32_t svmlalltt[_n_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, mfloat8_t zm, fpm_t fpm); * 8-bit floating-point multiply-add long long to single-precision (top top, indexed). // Only if (__ARM_FEATURE_SVE2 && __ARM_FEATURE_FP8FMA) || __ARM_FEATURE_SSVE_FP8FMA svfloat32_t svmlalltt_lane[_f32_mf8]_fpm(svfloat32_t zda, svmfloat8_t zn, svmfloat8_t zm, uint64_t imm0_15, fpm_t fpm);
1 parent 2244d2e commit 2eed88d

File tree

7 files changed

+636
-16
lines changed

7 files changed

+636
-16
lines changed

clang/include/clang/Basic/arm_sve.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,3 +2495,34 @@ let SVETargetGuard = "sve2,fp8dot4", SMETargetGuard ="sme,ssve-fp8dot4" in {
24952495
def SVFDOT_LANE_4WAY : SInst<"svdot_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fdot_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_3>]>;
24962496
}
24972497

2498+
let SVETargetGuard = "sve2,fp8fma", SMETargetGuard = "sme,ssve-fp8fma" in {
2499+
// 8-bit floating-point multiply-add long to half-precision (bottom)
2500+
def SVFMLALB : SInst<"svmlalb[_f16_mf8]_fpm", "dd~~>", "h", MergeNone, "aarch64_sve_fp8_fmlalb", [VerifyRuntimeMode, SetsFPMR]>;
2501+
def SVFMLALB_N : SInst<"svmlalb[_n_f16_mf8]_fpm", "dd~!>", "h", MergeNone, "aarch64_sve_fp8_fmlalb", [VerifyRuntimeMode, SetsFPMR]>;
2502+
2503+
// 8-bit floating-point multiply-add long to ha_fpmlf-precision (bottom, indexed)
2504+
def SVFMLALB_LANE : SInst<"svmlalb_lane[_f16_mf8]_fpm", "dd~~i>", "h", MergeNone, "aarch64_sve_fp8_fmlalb_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_15>]>;
2505+
2506+
// 8-bit floating-point multiply-add long to half-precision (top)
2507+
def SVFMLALT : SInst<"svmlalt[_f16_mf8]_fpm", "dd~~>", "h", MergeNone, "aarch64_sve_fp8_fmlalt", [VerifyRuntimeMode, SetsFPMR]>;
2508+
def SVFMLALT_N : SInst<"svmlalt[_n_f16_mf8]_fpm", "dd~!>", "h", MergeNone, "aarch64_sve_fp8_fmlalt", [VerifyRuntimeMode, SetsFPMR]>;
2509+
2510+
// 8-bit floating-point multiply-add long to half-precision (top, indexed)
2511+
def SVFMLALT_LANE : SInst<"svmlalt_lane[_f16_mf8]_fpm", "dd~~i>", "h", MergeNone, "aarch64_sve_fp8_fmlalt_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_15>]>;
2512+
2513+
// 8-bit floating-point multiply-add long long to single-precision (all top/bottom variants)
2514+
def SVFMLALLBB : SInst<"svmlallbb[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fmlallbb", [VerifyRuntimeMode, SetsFPMR]>;
2515+
def SVFMLALLBB_N : SInst<"svmlallbb[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fmlallbb", [VerifyRuntimeMode, SetsFPMR]>;
2516+
def SVFMLALLBT : SInst<"svmlallbt[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fmlallbt", [VerifyRuntimeMode, SetsFPMR]>;
2517+
def SVFMLALLBT_N : SInst<"svmlallbt[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fmlallbt", [VerifyRuntimeMode, SetsFPMR]>;
2518+
def SVFMLALLTB : SInst<"svmlalltb[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fmlalltb", [VerifyRuntimeMode, SetsFPMR]>;
2519+
def SVFMLALLTB_N : SInst<"svmlalltb[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fmlalltb", [VerifyRuntimeMode, SetsFPMR]>;
2520+
def SVFMLALLTT : SInst<"svmlalltt[_f32_mf8]_fpm", "dd~~>", "f", MergeNone, "aarch64_sve_fp8_fmlalltt", [VerifyRuntimeMode, SetsFPMR]>;
2521+
def SVFMLALLTT_N : SInst<"svmlalltt[_n_f32_mf8]_fpm", "dd~!>", "f", MergeNone, "aarch64_sve_fp8_fmlalltt", [VerifyRuntimeMode, SetsFPMR]>;
2522+
2523+
// 8-bit floating-point multiply-add long long to single-precision (indexed, all top/bottom variants)
2524+
def SVFMLALLBB_LANE : SInst<"svmlallbb_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fmlallbb_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
2525+
def SVFMLALLBT_LANE : SInst<"svmlallbt_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fmlallbt_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
2526+
def SVFMLALLTB_LANE : SInst<"svmlalltb_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fmlalltb_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
2527+
def SVFMLALLTT_LANE : SInst<"svmlalltt_lane[_f32_mf8]_fpm", "dd~~i>", "f", MergeNone, "aarch64_sve_fp8_fmlalltt_lane", [VerifyRuntimeMode, SetsFPMR], [ImmCheck<3, ImmCheck0_7>]>;
2528+
}

0 commit comments

Comments
 (0)