Skip to content

Commit 8127162

Browse files
authored
1 parent 14f3cdc commit 8127162

File tree

24 files changed

+384
-1
lines changed

24 files changed

+384
-1
lines changed

clang/docs/ReleaseNotes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ X86 Support
661661

662662
- Supported intrinsics for ``MOVRS AND AVX10.2``.
663663
* Supported intrinsics of ``_mm(256|512)_(mask(z))_loadrs_epi(8|16|32|64)``.
664+
- Support ISA of ``AMX-FP8``.
664665

665666
Arm and AArch64 Support
666667
^^^^^^^^^^^^^^^^^^^^^^^

clang/include/clang/Basic/BuiltinsX86_64.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiv*SLLiSLLiIi", "n", "cmpccxadd")
155155
// AMX_FP16 FP16
156156
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16")
157157

158+
// AMX FP8
159+
TARGET_BUILTIN(__builtin_ia32_tdpbf8ps, "vIUcUIcUIc", "n", "amx-fp8")
160+
TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcUIcUIc", "n", "amx-fp8")
161+
TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcUIcUIc", "n", "amx-fp8")
162+
TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcUIcUIc", "n", "amx-fp8")
163+
158164
// RAO-INT
159165
TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint")
160166
TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint")

clang/include/clang/Driver/Options.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6300,6 +6300,8 @@ def mamx_fp16 : Flag<["-"], "mamx-fp16">, Group<m_x86_Features_Group>;
63006300
def mno_amx_fp16 : Flag<["-"], "mno-amx-fp16">, Group<m_x86_Features_Group>;
63016301
def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
63026302
def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
6303+
def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
6304+
def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
63036305
def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
63046306
def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
63056307
def mcmpccxadd : Flag<["-"], "mcmpccxadd">, Group<m_x86_Features_Group>;

clang/lib/Basic/Targets/X86.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
428428
HasAMXTILE = true;
429429
} else if (Feature == "+amx-complex") {
430430
HasAMXCOMPLEX = true;
431+
} else if (Feature == "+amx-fp8") {
432+
HasAMXFP8 = true;
431433
} else if (Feature == "+cmpccxadd") {
432434
HasCMPCCXADD = true;
433435
} else if (Feature == "+raoint") {
@@ -947,6 +949,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
947949
Builder.defineMacro("__AMX_FP16__");
948950
if (HasAMXCOMPLEX)
949951
Builder.defineMacro("__AMX_COMPLEX__");
952+
if (HasAMXFP8)
953+
Builder.defineMacro("__AMX_FP8__");
950954
if (HasCMPCCXADD)
951955
Builder.defineMacro("__CMPCCXADD__");
952956
if (HasRAOINT)
@@ -1077,6 +1081,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
10771081
.Case("amx-fp16", true)
10781082
.Case("amx-int8", true)
10791083
.Case("amx-tile", true)
1084+
.Case("amx-fp8", true)
10801085
.Case("avx", true)
10811086
.Case("avx10.1-256", true)
10821087
.Case("avx10.1-512", true)
@@ -1195,6 +1200,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
11951200
.Case("amx-fp16", HasAMXFP16)
11961201
.Case("amx-int8", HasAMXINT8)
11971202
.Case("amx-tile", HasAMXTILE)
1203+
.Case("amx-fp8", HasAMXFP8)
11981204
.Case("avx", SSELevel >= AVX)
11991205
.Case("avx10.1-256", HasAVX10_1)
12001206
.Case("avx10.1-512", HasAVX10_1_512)

clang/lib/Basic/Targets/X86.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
157157
bool HasAMXINT8 = false;
158158
bool HasAMXBF16 = false;
159159
bool HasAMXCOMPLEX = false;
160+
bool HasAMXFP8 = false;
160161
bool HasSERIALIZE = false;
161162
bool HasTSXLDTRK = false;
162163
bool HasUSERMSR = false;

clang/lib/Headers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ set(x86_files
149149
amxcomplexintrin.h
150150
amxfp16intrin.h
151151
amxintrin.h
152+
amxfp8intrin.h
152153
avx10_2_512bf16intrin.h
153154
avx10_2_512convertintrin.h
154155
avx10_2_512minmaxintrin.h

clang/lib/Headers/amxfp8intrin.h

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*===------------- amxfp8intrin.h - AMX intrinsics -*- C++ -*----------------===
2+
*
3+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
* See https://llvm.org/LICENSE.txt for license information.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*
7+
*===------------------------------------------------------------------------===
8+
*/
9+
10+
#ifndef __IMMINTRIN_H
11+
#error "Never use <amxfp8intrin.h> directly; include <immintrin.h> instead."
12+
#endif /* __IMMINTRIN_H */
13+
14+
#ifndef __AMXFP8INTRIN_H
15+
#define __AMXFP8INTRIN_H
16+
#ifdef __x86_64__
17+
18+
/// Peform the dot product of a BF8 value \a a by a BF8 value \a b accumulating
19+
/// into a Single Precision (FP32) source/dest \a dst.
20+
///
21+
/// \headerfile <immintrin.h>
22+
///
23+
/// \code
24+
/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
25+
/// \endcode
26+
///
27+
/// This intrinsic corresponds to the \c TDPBF8PS instruction.
28+
///
29+
/// \param dst
30+
/// The destination tile. Max size is 1024 Bytes.
31+
/// \param a
32+
/// The 1st source tile. Max size is 1024 Bytes.
33+
/// \param b
34+
/// The 2nd source tile. Max size is 1024 Bytes.
35+
#define _tile_dpbf8ps(dst, a, b) __builtin_ia32_tdpbf8ps((dst), (a), (b))
36+
37+
/// Perform the dot product of a BF8 value \a a by an HF8 value \a b
38+
/// accumulating into a Single Precision (FP32) source/dest \a dst.
39+
///
40+
/// \headerfile <immintrin.h>
41+
///
42+
/// \code
43+
/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
44+
/// \endcode
45+
///
46+
/// This intrinsic corresponds to the \c TDPBHF8PS instruction.
47+
///
48+
/// \param dst
49+
/// The destination tile. Max size is 1024 Bytes.
50+
/// \param a
51+
/// The 1st source tile. Max size is 1024 Bytes.
52+
/// \param b
53+
/// The 2nd source tile. Max size is 1024 Bytes.
54+
#define _tile_dpbhf8ps(dst, a, b) __builtin_ia32_tdpbhf8ps((dst), (a), (b))
55+
56+
/// Perform the dot product of an HF8 value \a a by a BF8 value \a b
57+
/// accumulating into a Single Precision (FP32) source/dest \a dst.
58+
///
59+
/// \headerfile <immintrin.h>
60+
///
61+
/// \code
62+
/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
63+
/// \endcode
64+
///
65+
/// This intrinsic corresponds to the \c TDPHBF8PS instruction.
66+
///
67+
/// \param dst
68+
/// The destination tile. Max size is 1024 Bytes.
69+
/// \param a
70+
/// The 1st source tile. Max size is 1024 Bytes.
71+
/// \param b
72+
/// The 2nd source tile. Max size is 1024 Bytes.
73+
#define _tile_dphbf8ps(dst, a, b) __builtin_ia32_tdphbf8ps((dst), (a), (b))
74+
75+
/// Perform the dot product of an HF8 value \a a by an HF8 value \a b
76+
/// accumulating into a Single Precision (FP32) source/dest \a dst.
77+
///
78+
/// \headerfile <immintrin.h>
79+
///
80+
/// \code
81+
/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
82+
/// \endcode
83+
///
84+
/// This intrinsic corresponds to the \c TDPHF8PS instruction.
85+
///
86+
/// \param dst
87+
/// The destination tile. Max size is 1024 Bytes.
88+
/// \param a
89+
/// The 1st source tile. Max size is 1024 Bytes.
90+
/// \param b
91+
/// The 2nd source tile. Max size is 1024 Bytes.
92+
#define _tile_dphf8ps(dst, a, b) __builtin_ia32_tdphf8ps((dst), (a), (b))
93+
94+
#endif /* __x86_64__ */
95+
#endif /* __AMXFP8INTRIN_H */

clang/lib/Headers/immintrin.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,10 @@ _storebe_i64(void * __P, long long __D) {
648648
#include <amxcomplexintrin.h>
649649
#endif
650650

651+
#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_FP8__)
652+
#include <amxfp8intrin.h>
653+
#endif
654+
651655
#if !defined(__SCE__) || __has_feature(modules) || \
652656
defined(__AVX512VP2INTERSECT__)
653657
#include <avx512vp2intersectintrin.h>

clang/lib/Sema/SemaX86.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) {
640640
case X86::BI__builtin_ia32_tdpfp16ps:
641641
case X86::BI__builtin_ia32_tcmmimfp16ps:
642642
case X86::BI__builtin_ia32_tcmmrlfp16ps:
643+
case X86::BI__builtin_ia32_tdpbf8ps:
644+
case X86::BI__builtin_ia32_tdpbhf8ps:
645+
case X86::BI__builtin_ia32_tdphbf8ps:
646+
case X86::BI__builtin_ia32_tdphf8ps:
643647
return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2});
644648
}
645649
}

clang/test/CodeGen/X86/amx_fp8.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 \
2+
// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
3+
#include <immintrin.h>
4+
5+
void test_amx(void *data) {
6+
//CHECK-LABEL: @test_amx
7+
//CHECK: call void @llvm.x86.tdpbf8ps(i8 1, i8 2, i8 3)
8+
_tile_dpbf8ps(1, 2, 3);
9+
}
10+
11+
void test_amx2(void *data) {
12+
//CHECK-LABEL: @test_amx2
13+
//CHECK: call void @llvm.x86.tdpbhf8ps(i8 1, i8 2, i8 3)
14+
_tile_dpbhf8ps(1, 2, 3);
15+
}
16+
17+
void test_amx3(void *data) {
18+
//CHECK-LABEL: @test_amx3
19+
//CHECK: call void @llvm.x86.tdphbf8ps(i8 1, i8 2, i8 3)
20+
_tile_dphbf8ps(1, 2, 3);
21+
}
22+
23+
void test_amx4(void *data) {
24+
//CHECK-LABEL: @test_amx4
25+
//CHECK: call void @llvm.x86.tdphf8ps(i8 1, i8 2, i8 3)
26+
_tile_dphf8ps(1, 2, 3);
27+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-fp8 -verify
2+
3+
#include <immintrin.h>
4+
5+
void test_amx(void *data) {
6+
_tile_dpbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
7+
_tile_dpbhf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
8+
_tile_dphbf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
9+
_tile_dphf8ps(4, 3, 3); // expected-error {{tile arguments must refer to different tiles}}
10+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-fp8 -emit-llvm -o - -Wall -Werror -pedantic | FileCheck %s
2+
3+
void f_tilemul(short a)
4+
{
5+
//CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdpbf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
6+
__asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
7+
"tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
8+
"tdpbf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
9+
"tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
10+
::: "memory", "tmm0", "tmm6", "tmm7");
11+
12+
//CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdpbhf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
13+
__asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
14+
"tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
15+
"tdpbhf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
16+
"tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
17+
::: "memory", "tmm0", "tmm6", "tmm7");
18+
19+
//CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdphbf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
20+
__asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
21+
"tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
22+
"tdphbf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
23+
"tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
24+
::: "memory", "tmm0", "tmm6", "tmm7");
25+
26+
//CHECK: call void asm sideeffect "tileloadd 0(%rsi,%r13,4), %tmm0 \0A\09tileloadd 0(%rdx,%r14,4), %tmm6 \0A\09tdphf8ps %tmm6, %tmm0, %tmm7 \0A\09tilestored %tmm7, 0(%r12,%r15,4) \0A\09", "~{memory},~{tmm0},~{tmm6},~{tmm7},~{dirflag},~{fpsr},~{flags}"()
27+
__asm__ volatile ("tileloadd 0(%%rsi,%%r13,4), %%tmm0 \n\t"
28+
"tileloadd 0(%%rdx,%%r14,4), %%tmm6 \n\t"
29+
"tdphf8ps %%tmm6, %%tmm0, %%tmm7 \n\t"
30+
"tilestored %%tmm7, 0(%%r12,%%r15,4) \n\t"
31+
::: "memory", "tmm0", "tmm6", "tmm7");
32+
}

llvm/include/llvm/IR/IntrinsicsX86.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5994,6 +5994,23 @@ let TargetPrefix = "x86" in {
59945994
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
59955995
llvm_x86amx_ty, llvm_x86amx_ty,
59965996
llvm_x86amx_ty], []>;
5997+
5998+
def int_x86_tdpbf8ps : ClangBuiltin<"__builtin_ia32_tdpbf8ps">,
5999+
Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
6000+
[ImmArg<ArgIndex<0>>,
6001+
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
6002+
def int_x86_tdpbhf8ps : ClangBuiltin<"__builtin_ia32_tdpbhf8ps">,
6003+
Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
6004+
[ImmArg<ArgIndex<0>>,
6005+
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
6006+
def int_x86_tdphbf8ps : ClangBuiltin<"__builtin_ia32_tdphbf8ps">,
6007+
Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
6008+
[ImmArg<ArgIndex<0>>,
6009+
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
6010+
def int_x86_tdphf8ps : ClangBuiltin<"__builtin_ia32_tdphf8ps">,
6011+
Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
6012+
[ImmArg<ArgIndex<0>>,
6013+
ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>]>;
59976014
}
59986015

59996016
//===----------------------------------------------------------------------===//

llvm/include/llvm/TargetParser/X86TargetParser.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ X86_FEATURE_COMPAT(AVX10_2_512, "avx10.2-512", 0)
264264
//FIXME: make MOVRS _COMPAT defined when gcc landed relate patch.
265265
X86_FEATURE (MOVRS, "movrs")
266266
X86_FEATURE (ZU, "zu")
267+
X86_FEATURE (AMX_FP8, "amx-fp8")
267268
// These features aren't really CPU features, but the frontend can set them.
268269
X86_FEATURE (RETPOLINE_EXTERNAL_THUNK, "retpoline-external-thunk")
269270
X86_FEATURE (RETPOLINE_INDIRECT_BRANCHES, "retpoline-indirect-branches")

llvm/lib/Target/X86/X86.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ def FeatureAMXFP16 : SubtargetFeature<"amx-fp16", "HasAMXFP16", "true",
270270
def FeatureAMXCOMPLEX : SubtargetFeature<"amx-complex", "HasAMXCOMPLEX", "true",
271271
"Support AMX-COMPLEX instructions",
272272
[FeatureAMXTILE]>;
273+
def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",
274+
"Support AMX-FP8 instructions",
275+
[FeatureAMXTILE]>;
273276
def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true",
274277
"Support CMPCCXADD instructions">;
275278
def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true",

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37420,7 +37420,11 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
3742037420
case X86::PTDPBUSD:
3742137421
case X86::PTDPBUUD:
3742237422
case X86::PTDPBF16PS:
37423-
case X86::PTDPFP16PS: {
37423+
case X86::PTDPFP16PS:
37424+
case X86::PTDPBF8PS:
37425+
case X86::PTDPBHF8PS:
37426+
case X86::PTDPHBF8PS:
37427+
case X86::PTDPHF8PS: {
3742437428
unsigned Opc;
3742537429
switch (MI.getOpcode()) {
3742637430
// clang-format off
@@ -37431,6 +37435,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
3743137435
case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
3743237436
case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
3743337437
case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
37438+
case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
37439+
case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
37440+
case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
37441+
case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
3743437442
// clang-format on
3743537443
}
3743637444

llvm/lib/Target/X86/X86InstrAMX.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,42 @@ let Predicates = [HasAMXCOMPLEX, In64BitMode] in {
267267
}
268268
} // SchedRW = [WriteSystem]
269269
}
270+
271+
// AMX-FP8
272+
let Predicates = [HasAMXFP8, In64BitMode] in {
273+
let SchedRW = [WriteSystem] in {
274+
let Constraints = "$src1 = $dst" in {
275+
class AMX_FP8_BASE<bits<8> Opcode, string Opstr> :
276+
I<Opcode, MRMSrcReg4VOp3, (outs TILE:$dst),
277+
(ins TILE:$src1, TILE:$src2, TILE:$src3),
278+
!strconcat(Opstr, "\t{$src3, $src2, $dst|$dst, $src2, $src3}"),
279+
[]>, VEX, VVVV;
280+
}
281+
282+
def TDPBF8PS : AMX_FP8_BASE<0xfd, "tdpbf8ps">, T_MAP5, PS;
283+
def TDPBHF8PS : AMX_FP8_BASE<0xfd, "tdpbhf8ps">, T_MAP5, XD;
284+
def TDPHBF8PS : AMX_FP8_BASE<0xfd, "tdphbf8ps">, T_MAP5, XS;
285+
def TDPHF8PS : AMX_FP8_BASE<0xfd, "tdphf8ps">, T_MAP5, PD;
286+
287+
let usesCustomInserter = 1 in {
288+
// Pseudo instructions, using immediates instead of tile registers.
289+
// To be translated to the actual instructions in X86ISelLowering.cpp
290+
def PTDPBF8PS : PseudoI<(outs),
291+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
292+
[(int_x86_tdpbf8ps timm:$src1, timm:$src2,
293+
timm:$src3)]>;
294+
def PTDPBHF8PS : PseudoI<(outs),
295+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
296+
[(int_x86_tdpbhf8ps timm:$src1, timm:$src2,
297+
timm:$src3)]>;
298+
def PTDPHBF8PS : PseudoI<(outs),
299+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
300+
[(int_x86_tdphbf8ps timm:$src1, timm:$src2,
301+
timm:$src3)]>;
302+
def PTDPHF8PS : PseudoI<(outs),
303+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
304+
[(int_x86_tdphf8ps timm:$src1, timm:$src2,
305+
timm:$src3)]>;
306+
}
307+
}
308+
}

llvm/lib/Target/X86/X86InstrPredicates.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def HasAMXTILE : Predicate<"Subtarget->hasAMXTILE()">;
183183
def HasAMXBF16 : Predicate<"Subtarget->hasAMXBF16()">;
184184
def HasAMXINT8 : Predicate<"Subtarget->hasAMXINT8()">;
185185
def HasAMXCOMPLEX : Predicate<"Subtarget->hasAMXCOMPLEX()">;
186+
def HasAMXFP8 : Predicate<"Subtarget->hasAMXFP8()">;
186187
def HasUINTR : Predicate<"Subtarget->hasUINTR()">;
187188
def HasUSERMSR : Predicate<"Subtarget->hasUSERMSR()">;
188189
def HasCRC32 : Predicate<"Subtarget->hasCRC32()">;

llvm/lib/TargetParser/Host.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,10 @@ const StringMap<bool> sys::getHostCPUFeatures() {
18761876
MaxLevel >= 0x19 && !getX86CpuIDAndInfo(0x19, &EAX, &EBX, &ECX, &EDX);
18771877
Features["widekl"] = HasLeaf7 && HasLeaf19 && ((EBX >> 2) & 1);
18781878

1879+
bool HasLeaf1E = MaxLevel >= 0x1e &&
1880+
!getX86CpuIDAndInfoEx(0x1e, 0x1, &EAX, &EBX, &ECX, &EDX);
1881+
Features["amx-fp8"] = HasLeaf1E && ((EAX >> 4) & 1) && HasAMXSave;
1882+
18791883
bool HasLeaf24 =
18801884
MaxLevel >= 0x24 && !getX86CpuIDAndInfo(0x24, &EAX, &EBX, &ECX, &EDX);
18811885

0 commit comments

Comments
 (0)