Skip to content

Commit 804b81d

Browse files
[AArch64] Add FP8 Neon intrinsics for dot-product (llvm#123613)
This patch adds the following intrinsics:

float16x4_t vdot_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpm)
float16x8_t vdotq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm)
float16x4_t vdot_lane_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x4_t vdot_laneq_f16_mf8_fpm(float16x4_t vd, mfloat8x8_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vdotq_lane_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float16x8_t vdotq_laneq_f16_mf8_fpm(float16x8_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)

float32x2_t vdot_f32_mf8_fpm(float32x2_t vd, mfloat8x8_t vn, mfloat8x8_t vm, fpm_t fpm)
float32x4_t vdotq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, fpm_t fpm)
float32x2_t vdot_lane_f32_mf8_fpm(float32x2_t vd, mfloat8x8_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x2_t vdot_laneq_f32_mf8_fpm(float32x2_t vd, mfloat8x8_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vdotq_lane_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x8_t vm, __builtin_constant_p(lane), fpm_t fpm)
float32x4_t vdotq_laneq_f32_mf8_fpm(float32x4_t vd, mfloat8x16_t vn, mfloat8x16_t vm, __builtin_constant_p(lane), fpm_t fpm)
1 parent 3bf8e67 commit 804b81d

File tree

10 files changed

+529
-40
lines changed

10 files changed

+529
-40
lines changed

clang/include/clang/Basic/arm_neon.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2141,6 +2141,26 @@ let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8,neon" in {
21412141
def VCVTN_F8_F16 : VInst<"vcvt_mf8_f16_fpm", ".(>F)(>F)V", "mQm">;
21422142
}
21432143

2144+
// FP8 dot-product intrinsics whose names carry the f16 result type
// (vdot*_f16_mf8_fpm). Every definition in this group is emitted under the
// "fp8dot2,neon" target guard and the defined(__aarch64__) arch guard.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot2,neon" in {
2145+
def VDOT_F16_MF8 : VInst<"vdot_f16_mf8_fpm", "(>F)(>F)..V", "mQm">;
2146+
2147+
// Lane-indexed forms. Each attaches an ImmCheck to argument index 3; the
// check names (ImmCheck0_3 / ImmCheck0_7) suggest permitted lane ranges of
// 0-3 and 0-7 respectively -- NOTE(review): confirm against the ImmCheck
// definitions.
def VDOT_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F)..IV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
2148+
def VDOT_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F).QIV", "m", [ImmCheck<3, ImmCheck0_7, 0>]>;
2149+
2150+
// Same lane-indexed names, instantiated with the "Qm" type spec instead of "m".
def VDOTQ_LANE_F16_MF8 : VInst<"vdot_lane_f16_mf8_fpm", "(>F)(>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
2151+
def VDOTQ_LANEQ_F16_MF8 : VInst<"vdot_laneq_f16_mf8_fpm", "(>F)(>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_7, 0>]>;
2152+
}
2153+
2154+
// FP8 dot-product intrinsics whose names carry the f32 result type
// (vdot*_f32_mf8_fpm). Every definition in this group is emitted under the
// "fp8dot4,neon" target guard and the defined(__aarch64__) arch guard.
let ArchGuard = "defined(__aarch64__)", TargetGuard = "fp8dot4,neon" in {
2155+
def VDOT_F32_MF8 : VInst<"vdot_f32_mf8_fpm", "(>>F)(>>F)..V", "mQm">;
2156+
2157+
// Lane-indexed forms. Each attaches an ImmCheck to argument index 3; the
// check names (ImmCheck0_1 / ImmCheck0_3) suggest permitted lane ranges of
// 0-1 and 0-3 respectively -- NOTE(review): confirm against the ImmCheck
// definitions.
def VDOT_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F)..IV", "m", [ImmCheck<3, ImmCheck0_1, 0>]>;
2158+
def VDOT_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F).QIV", "m", [ImmCheck<3, ImmCheck0_3, 0>]>;
2159+
2160+
// Same lane-indexed names, instantiated with the "Qm" type spec instead of "m".
def VDOTQ_LANE_F32_MF8 : VInst<"vdot_lane_f32_mf8_fpm", "(>>F)(>>F).qIV", "Qm", [ImmCheck<3, ImmCheck0_1, 0>]>;
2161+
def VDOTQ_LANEQ_F32_MF8 : VInst<"vdot_laneq_f32_mf8_fpm", "(>>F)(>>F)..IV", "Qm", [ImmCheck<3, ImmCheck0_3, 0>]>;
2162+
}
2163+
21442164
let ArchGuard = "defined(__aarch64__)", TargetGuard = "neon,faminmax" in {
21452165
def FAMIN : WInst<"vamin", "...", "fhQdQfQh">;
21462166
def FAMAX : WInst<"vamax", "...", "fhQdQfQh">;

clang/include/clang/Basic/arm_neon_incl.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ class Inst <string n, string p, string t, Operation o, list<ImmCheck> ch = []>{
302302
class SInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
303303
class IInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
304304
class WInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
305-
class VInst<string n, string p, string t> : Inst<n, p, t, OP_NONE> {}
305+
class VInst<string n, string p, string t, list<ImmCheck> ch = []> : Inst<n, p, t, OP_NONE, ch> {}
306306

307307
// The following instruction classes are implemented via operators
308308
// instead of builtins. As such these declarations are only used for

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6766,6 +6766,24 @@ Value *CodeGenFunction::EmitFP8NeonCall(Function *F,
67666766
return EmitNeonCall(F, Ops, name);
67676767
}
67686768

6769+
// Emits a call to the FP8 dot-product intrinsic identified by IID.
//
// IID        - intrinsic ID handed to CGM.getIntrinsic.
// ExtendLane - when true, Ops[2] is first widened into a 16-element Int8Ty
//              vector (see below).
// RetTy      - element type used to build the first overload type.
// Ops        - operand list; Ops[2] may be rewritten in place.
// E          - the call expression; its last argument supplies the FPM value.
// name       - name forwarded to EmitFP8NeonCall.
//
// Returns whatever EmitFP8NeonCall returns for the constructed call.
llvm::Value *CodeGenFunction::EmitFP8NeonFDOTCall(
6770+
unsigned IID, bool ExtendLane, llvm::Type *RetTy,
6771+
SmallVectorImpl<llvm::Value *> &Ops, const CallExpr *E, const char *name) {
6772+
6773+
// Number of RetTy-sized elements that fit in the bit width of Ops[0]'s type.
const unsigned ElemCount = Ops[0]->getType()->getPrimitiveSizeInBits() /
6774+
RetTy->getPrimitiveSizeInBits();
6775+
// Overload types for the intrinsic: a fixed vector of ElemCount x RetTy,
// followed by the type of Ops[1].
llvm::Type *Tys[] = {llvm::FixedVectorType::get(RetTy, ElemCount),
6776+
Ops[1]->getType()};
6777+
if (ExtendLane) {
6778+
// Insert Ops[2] at index 0 of a poison 16 x Int8Ty vector; the remaining
// elements stay poison. NOTE(review): presumably the lane intrinsic only
// accepts the 16-element form of the indexed operand -- confirm.
auto *VT = llvm::FixedVectorType::get(Int8Ty, 16);
6779+
Ops[2] = Builder.CreateInsertVector(VT, PoisonValue::get(VT), Ops[2],
6780+
Builder.getInt64(0));
6781+
}
6782+
// The FPM value is taken from the last argument of the call expression.
llvm::Value *FPM =
6783+
EmitScalarOrConstFoldImmArg(/* ICEArguments */ 0, E->getNumArgs() - 1, E);
6784+
return EmitFP8NeonCall(CGM.getIntrinsic(IID, Tys), Ops, FPM, name);
6785+
}
6786+
67696787
Value *CodeGenFunction::EmitNeonShiftVector(Value *V, llvm::Type *Ty,
67706788
bool neg) {
67716789
int SV = cast<ConstantInt>(V)->getSExtValue();
@@ -12761,6 +12779,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1276112779

1276212780
unsigned Int;
1276312781
bool ExtractLow = false;
12782+
bool ExtendLane = false;
1276412783
switch (BuiltinID) {
1276512784
default: return nullptr;
1276612785
case NEON::BI__builtin_neon_vbsl_v:
@@ -14028,6 +14047,31 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
1402814047
return EmitFP8NeonCvtCall(Intrinsic::aarch64_neon_fp8_fcvtn2, Ty,
1402914048
Ops[1]->getType(), false, Ops, E, "vfcvtn2");
1403014049
}
14050+
14051+
case NEON::BI__builtin_neon_vdot_f16_mf8_fpm:
14052+
case NEON::BI__builtin_neon_vdotq_f16_mf8_fpm:
14053+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2, false, HalfTy,
14054+
Ops, E, "fdot2");
14055+
case NEON::BI__builtin_neon_vdot_lane_f16_mf8_fpm:
14056+
case NEON::BI__builtin_neon_vdotq_lane_f16_mf8_fpm:
14057+
ExtendLane = true;
14058+
LLVM_FALLTHROUGH;
14059+
case NEON::BI__builtin_neon_vdot_laneq_f16_mf8_fpm:
14060+
case NEON::BI__builtin_neon_vdotq_laneq_f16_mf8_fpm:
14061+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot2_lane,
14062+
ExtendLane, HalfTy, Ops, E, "fdot2_lane");
14063+
case NEON::BI__builtin_neon_vdot_f32_mf8_fpm:
14064+
case NEON::BI__builtin_neon_vdotq_f32_mf8_fpm:
14065+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4, false,
14066+
FloatTy, Ops, E, "fdot4");
14067+
case NEON::BI__builtin_neon_vdot_lane_f32_mf8_fpm:
14068+
case NEON::BI__builtin_neon_vdotq_lane_f32_mf8_fpm:
14069+
ExtendLane = true;
14070+
LLVM_FALLTHROUGH;
14071+
case NEON::BI__builtin_neon_vdot_laneq_f32_mf8_fpm:
14072+
case NEON::BI__builtin_neon_vdotq_laneq_f32_mf8_fpm:
14073+
return EmitFP8NeonFDOTCall(Intrinsic::aarch64_neon_fp8_fdot4_lane,
14074+
ExtendLane, FloatTy, Ops, E, "fdot4_lane");
1403114075
case NEON::BI__builtin_neon_vamin_f16:
1403214076
case NEON::BI__builtin_neon_vaminq_f16:
1403314077
case NEON::BI__builtin_neon_vamin_f32:

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4699,6 +4699,10 @@ class CodeGenFunction : public CodeGenTypeCache {
46994699
llvm::Type *Ty1, bool Extract,
47004700
SmallVectorImpl<llvm::Value *> &Ops,
47014701
const CallExpr *E, const char *name);
4702+
llvm::Value *EmitFP8NeonFDOTCall(unsigned IID, bool ExtendLane,
4703+
llvm::Type *RetTy,
4704+
SmallVectorImpl<llvm::Value *> &Ops,
4705+
const CallExpr *E, const char *name);
47024706
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx,
47034707
const llvm::ElementCount &Count);
47044708
llvm::Value *EmitNeonSplat(llvm::Value *V, llvm::Constant *Idx);

0 commit comments

Comments
 (0)