Skip to content

Commit eddb79d

Browse files
authored
1 parent d8ebb08 commit eddb79d

32 files changed

+585
-8
lines changed

clang/docs/ReleaseNotes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ X86 Support
740740
- Support ISA of ``AMX-FP8``.
741741
- Support ISA of ``AMX-TRANSPOSE``.
742742
- Support ISA of ``AMX-AVX512``.
743+
- Support ISA of ``AMX-TF32``.
743744

744745
Arm and AArch64 Support
745746
^^^^^^^^^^^^^^^^^^^^^^^

clang/include/clang/Basic/BuiltinsX86_64.def

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2pbf16l_internal, "V32yUsUsV256iUi", "n",
139139
TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
140140
TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl_internal, "V32xUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
141141
TARGET_BUILTIN(__builtin_ia32_tilemovrow_internal, "V16iUsUsV256iUi", "n", "amx-avx512,avx10.2-512")
142+
TARGET_BUILTIN(__builtin_ia32_tmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32")
143+
TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-tf32,amx-transpose")
144+
142145
// AMX
143146
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
144147
TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")
@@ -172,10 +175,6 @@ TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phh, "V32xIUcUi", "n", "amx-avx512,avx10
172175
TARGET_BUILTIN(__builtin_ia32_tcvtrowps2phl, "V32xIUcUi", "n", "amx-avx512,avx10.2-512")
173176
TARGET_BUILTIN(__builtin_ia32_tilemovrow, "V16iIUcUi", "n", "amx-avx512,avx10.2-512")
174177

175-
TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
176-
TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
177-
TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd")
178-
179178
// AMX_FP16 FP16
180179
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps, "vIUcIUcIUc", "n", "amx-fp16")
181180

@@ -185,6 +184,14 @@ TARGET_BUILTIN(__builtin_ia32_tdpbhf8ps, "vIUcUIcUIc", "n", "amx-fp8")
185184
TARGET_BUILTIN(__builtin_ia32_tdphbf8ps, "vIUcUIcUIc", "n", "amx-fp8")
186185
TARGET_BUILTIN(__builtin_ia32_tdphf8ps, "vIUcUIcUIc", "n", "amx-fp8")
187186

187+
// AMX TF32
188+
TARGET_BUILTIN(__builtin_ia32_tmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32")
189+
TARGET_BUILTIN(__builtin_ia32_ttmmultf32ps, "vIUcIUcIUc", "n", "amx-tf32,amx-transpose")
190+
191+
TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
192+
TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
193+
TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiSLLi*SLLiSLLiIi", "n", "cmpccxadd")
194+
188195
// RAO-INT
189196
TARGET_BUILTIN(__builtin_ia32_aadd64, "vv*SOi", "n", "raoint")
190197
TARGET_BUILTIN(__builtin_ia32_aand64, "vv*SOi", "n", "raoint")

clang/include/clang/Driver/Options.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6297,6 +6297,8 @@ def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
62976297
def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
62986298
def mamx_fp8 : Flag<["-"], "mamx-fp8">, Group<m_x86_Features_Group>;
62996299
def mno_amx_fp8 : Flag<["-"], "mno-amx-fp8">, Group<m_x86_Features_Group>;
6300+
def mamx_tf32 : Flag<["-"], "mamx-tf32">, Group<m_x86_Features_Group>;
6301+
def mno_amx_tf32 : Flag<["-"], "mno-amx-tf32">, Group<m_x86_Features_Group>;
63006302
def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
63016303
def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
63026304
def mamx_transpose : Flag<["-"], "mamx-transpose">, Group<m_x86_Features_Group>;

clang/lib/Basic/Targets/X86.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
434434
HasAMXTRANSPOSE = true;
435435
} else if (Feature == "+amx-avx512") {
436436
HasAMXAVX512 = true;
437+
} else if (Feature == "+amx-tf32") {
438+
HasAMXTF32 = true;
437439
} else if (Feature == "+cmpccxadd") {
438440
HasCMPCCXADD = true;
439441
} else if (Feature == "+raoint") {
@@ -959,6 +961,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
959961
Builder.defineMacro("__AMX_TRANSPOSE__");
960962
if (HasAMXAVX512)
961963
Builder.defineMacro("__AMX_AVX512__");
964+
if (HasAMXTF32)
965+
Builder.defineMacro("__AMX_TF32__");
962966
if (HasCMPCCXADD)
963967
Builder.defineMacro("__CMPCCXADD__");
964968
if (HasRAOINT)
@@ -1090,6 +1094,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
10901094
.Case("amx-fp16", true)
10911095
.Case("amx-fp8", true)
10921096
.Case("amx-int8", true)
1097+
.Case("amx-tf32", true)
10931098
.Case("amx-tile", true)
10941099
.Case("amx-transpose", true)
10951100
.Case("avx", true)
@@ -1211,6 +1216,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
12111216
.Case("amx-fp16", HasAMXFP16)
12121217
.Case("amx-fp8", HasAMXFP8)
12131218
.Case("amx-int8", HasAMXINT8)
1219+
.Case("amx-tf32", HasAMXTF32)
12141220
.Case("amx-tile", HasAMXTILE)
12151221
.Case("amx-transpose", HasAMXTRANSPOSE)
12161222
.Case("avx", SSELevel >= AVX)

clang/lib/Basic/Targets/X86.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
160160
bool HasAMXFP8 = false;
161161
bool HasAMXTRANSPOSE = false;
162162
bool HasAMXAVX512 = false;
163+
bool HasAMXTF32 = false;
163164
bool HasSERIALIZE = false;
164165
bool HasTSXLDTRK = false;
165166
bool HasUSERMSR = false;

clang/lib/Headers/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ set(x86_files
151151
amxfp16intrin.h
152152
amxfp8intrin.h
153153
amxintrin.h
154+
amxtf32intrin.h
155+
amxtf32transposeintrin.h
154156
amxtransposeintrin.h
155157
avx10_2_512bf16intrin.h
156158
avx10_2_512convertintrin.h

clang/lib/Headers/amxtf32intrin.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*===------------- amxtf32intrin.h - AMX_TF32 intrinsics -*- C++ -*---------===
2+
*
3+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
* See https://llvm.org/LICENSE.txt for license information.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*
7+
*===------------------------------------------------------------------------===
8+
*/
9+
10+
#ifndef __IMMINTRIN_H
11+
#error "Never use <amxtf32intrin.h> directly; include <immintrin.h> instead."
12+
#endif // __IMMINTRIN_H
13+
14+
#ifndef __AMX_TF32INTRIN_H
15+
#define __AMX_TF32INTRIN_H
16+
#ifdef __x86_64__
17+
18+
#define __DEFAULT_FN_ATTRS_TF32 \
19+
__attribute__((__always_inline__, __nodebug__, __target__("amx-tf32")))
20+
21+
/// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus
22+
/// with \a srcdst.
23+
/// All the calculation is base on float32 but with the lower 13-bit set to 0.
24+
///
25+
/// \headerfile <immintrin.h>
26+
///
27+
/// \code
28+
/// void _tile_mmultf32ps(constexpr int srcdst, constexpr int a, \
29+
/// constexpr int b);
30+
/// \endcode
31+
///
32+
/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
33+
///
34+
/// \param srcdst
35+
/// The destination tile. Max size is 1024 Bytes.
36+
/// \param a
37+
/// The 1st source tile. Max size is 1024 Bytes.
38+
/// \param b
39+
/// The 2nd source tile. Max size is 1024 Bytes.
40+
///
41+
/// \code{.operation}
42+
/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
43+
/// dword[12:0] := 0
44+
/// dword[31:13] := x[31:13]
45+
/// return dword
46+
/// }
47+
///
48+
/// DEFINE silence_snan_fp32(x[31:0]) {
49+
/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
50+
/// x.fraction[22] := 1
51+
/// return x
52+
/// }
53+
///
54+
/// elements_a := a.colsb / 4
55+
/// elements_dest := srcdst.colsb / 4
56+
///
57+
/// FOR m = 0 TO (srcdst.rows-1)
58+
/// tmp[511:0] := 0
59+
/// FOR k = 0 TO (elements_a-1)
60+
/// FOR n = 0 TO (elements_dest-1)
61+
/// af := silence_snan_fp32(a.row[m].fp32[k])
62+
/// bf := silence_snan_fp32(b.row[k].fp32[n])
63+
/// tmp.fp32[n] += zero_lower_mantissa_bits_fp32(af)
64+
/// * zero_lower_mantissa_bits_fp32(bf)
65+
/// ENDFOR
66+
/// ENDFOR
67+
///
68+
/// FOR n = 0 TO (elements_dest-1)
69+
/// tmp.fp32[n] += srcdst.row[m].fp32[n]
70+
/// ENDFOR
71+
/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
72+
///
73+
/// ENDFOR
74+
///
75+
/// zero_upper_rows(srcdst, srcdst.rows)
76+
/// zero_tileconfig_start()
77+
/// \endcode
78+
#define _tile_mmultf32ps(srcdst, a, b) \
79+
__builtin_ia32_tmmultf32ps((srcdst), (a), (b))
80+
81+
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32
82+
_tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
83+
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
84+
return __builtin_ia32_tmmultf32ps_internal(m, n, k, dst, src1, src2);
85+
}
86+
87+
/// Do Matrix Multiplication of src0 and src1, and then do Matrix Plus with dst.
88+
/// All the calculation is base on float32 but with the lower 13-bit set to 0.
89+
///
90+
/// \headerfile <immintrin.h>
91+
///
92+
/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
93+
///
94+
/// \param dst
95+
/// The destination tile. Max size is 1024 Bytes.
96+
/// \param src0
97+
/// The 1st source tile. Max size is 1024 Bytes.
98+
/// \param src1
99+
/// The 2nd source tile. Max size is 1024 Bytes.
100+
__DEFAULT_FN_ATTRS_TF32
101+
static void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0,
102+
__tile1024i src1) {
103+
dst->tile = _tile_mmultf32ps_internal(src0.row, src1.col, src0.col, dst->tile,
104+
src0.tile, src1.tile);
105+
}
106+
107+
#endif // __x86_64__
108+
#endif // __AMX_TF32INTRIN_H
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*===--------- amxtf32transposeintrin.h - AMX-TF32 and AMX-TRANSPOSE --------===
2+
*
3+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
* See https://llvm.org/LICENSE.txt for license information.
5+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
*
7+
*===------------------------------------------------------------------------===
8+
*/
9+
#ifndef __IMMINTRIN_H
10+
#error \
11+
"Never use <amxtf32tranposeintrin.h> directly; include <immintrin.h> instead."
12+
#endif // __IMMINTRIN_H
13+
14+
#ifndef __AMX_TF32TRANSPOSEINTRIN_H
15+
#define __AMX_TF32TRANSPOSEINTRIN_H
16+
#ifdef __x86_64__
17+
18+
#define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \
19+
__attribute__((__always_inline__, __nodebug__, \
20+
__target__("amx-tf32,amx-transpose")))
21+
22+
/// \code
23+
/// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \
24+
/// constexpr int b);
25+
/// \endcode
26+
///
27+
/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
28+
///
29+
/// \param srcdst
30+
/// The destination tile. Max size is 1024 Bytes.
31+
/// \param a
32+
/// The 1st source tile. Max size is 1024 Bytes.
33+
/// \param b
34+
/// The 2nd source tile. Max size is 1024 Bytes.
35+
///
36+
/// \code{.operation}
37+
/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
38+
/// dword[12:0] := 0
39+
/// dword[31:13] := x[31:13]
40+
/// return dword
41+
/// }
42+
///
43+
/// DEFINE silence_snan_fp32(x[31:0]) {
44+
/// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
45+
/// x.fraction[22] := 1
46+
/// return x
47+
/// }
48+
///
49+
/// elements_dest:= srcdst.colsb/4
50+
///
51+
/// FOR m := 0 TO (srcdst.rows-1)
52+
/// tmp[511:0] := 0
53+
/// FOR k := 0 TO (a.rows-1)
54+
/// FOR n := 0 TO (elements_dest-1)
55+
/// a1e := silence_snan_fp32(a.row[k].fp32[m])
56+
/// a2e := silence_snan_fp32(b.row[k].fp32[n])
57+
/// s1e := zero_lower_mantissa_bits_fp32(a1e)
58+
/// s2e := zero_lower_mantissa_bits_fp32(a2e)
59+
/// tmp.fp32[n] += s1e * s2e
60+
/// ENDFOR
61+
/// ENDFOR
62+
///
63+
/// FOR n := 0 TO (elements_dest-1)
64+
/// tmp.fp32[n] += srcdst.row[m].fp32[n]
65+
/// ENDFOR
66+
/// write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
67+
///
68+
/// ENDFOR
69+
///
70+
/// zero_upper_rows(srcdst, srcdst.rows)
71+
/// zero_tileconfig_start()
72+
/// \endcode
73+
#define _tile_tmmultf32ps(srcdst, a, b) \
74+
__builtin_ia32_ttmmultf32ps((srcdst), (a), (b))
75+
76+
// dst = m x n (srcdest), src1 = k x m, src2 = k x n
77+
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
78+
_tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
79+
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
80+
return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2);
81+
}
82+
83+
/// Compute transpose and do Matrix Multiplication of src0 and src1, and then do
84+
/// Matrix Plus with dst. All the calculation is base on float32 but with the
85+
/// lower 13-bit set to 0.
86+
///
87+
/// \headerfile <immintrin.h>
88+
///
89+
/// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
90+
///
91+
/// \param dst
92+
/// The destination tile. Max size is 1024 Bytes.
93+
/// \param src0
94+
/// The 1st source tile. Max size is 1024 Bytes.
95+
/// \param src1
96+
/// The 2nd source tile. Max size is 1024 Bytes.
97+
__DEFAULT_FN_ATTRS_TF32_TRANSPOSE
98+
static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0,
99+
__tile1024i src1) {
100+
dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col,
101+
dst->tile, src0.tile, src1.tile);
102+
}
103+
104+
#endif // __x86_64__
105+
#endif // __AMX_TF32TRANSPOSEINTRIN_H

clang/lib/Headers/immintrin.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,15 @@ _storebe_i64(void * __P, long long __D) {
660660
#include <amxavx512intrin.h>
661661
#endif
662662

663+
#if !defined(__SCE__) || __has_feature(modules) || defined(__AMX_TF32__)
664+
#include <amxtf32intrin.h>
665+
#endif
666+
667+
#if !defined(__SCE__) || __has_feature(modules) || \
668+
(defined(__AMX_TF32__) && defined(__AMX_TRANSPOSE__))
669+
#include <amxtf32transposeintrin.h>
670+
#endif
671+
663672
#if !defined(__SCE__) || __has_feature(modules) || \
664673
defined(__AVX512VP2INTERSECT__)
665674
#include <avx512vp2intersectintrin.h>

clang/lib/Sema/SemaX86.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ bool SemaX86::CheckBuiltinTileArguments(unsigned BuiltinID, CallExpr *TheCall) {
654654
case X86::BI__builtin_ia32_tdpbhf8ps:
655655
case X86::BI__builtin_ia32_tdphbf8ps:
656656
case X86::BI__builtin_ia32_tdphf8ps:
657+
case X86::BI__builtin_ia32_tmmultf32ps:
658+
case X86::BI__builtin_ia32_ttmmultf32ps:
657659
return CheckBuiltinTileRangeAndDuplicate(TheCall, {0, 1, 2});
658660
case X86::BI__builtin_ia32_ttransposed:
659661
return CheckBuiltinTileArgumentsRange(TheCall, {0, 1});

clang/test/CodeGen/X86/amx_tf32.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +amx-tile -target-feature +amx-tf32 \
2+
// RUN: -target-feature +amx-transpose -emit-llvm -o - -Wall -Werror -pedantic -Wno-gnu-statement-expression | FileCheck %s
3+
4+
#include <immintrin.h>
5+
#include <stddef.h>
6+
7+
void test_tile_mmultf32ps(void) {
8+
// CHECK-LABEL: @test_tile_mmultf32ps(
9+
// CHECK: call void @llvm.x86.tmmultf32ps(i8 1, i8 2, i8 3)
10+
_tile_mmultf32ps(1, 2, 3);
11+
}
12+
13+
void test_tile_tmmultf32ps(void) {
14+
// CHECK-LABEL: @test_tile_tmmultf32ps(
15+
// CHECK: call void @llvm.x86.ttmmultf32ps(i8 1, i8 2, i8 3)
16+
_tile_tmmultf32ps(1, 2, 3);
17+
}

clang/test/CodeGen/X86/amx_tf32_api.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown \
2+
// RUN: -target-feature +amx-tf32 -target-feature +amx-transpose \
3+
// RUN: -target-feature +amx-bf16 -target-feature +avx512f \
4+
// RUN: -emit-llvm -o - -Werror -pedantic | FileCheck %s
5+
6+
#include <immintrin.h>
7+
8+
char buf[1024];
9+
#define STRIDE 32
10+
11+
char buf2[1024];
12+
13+
void test_tile_mmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) {
14+
//CHECK-LABEL: @test_tile_mmultf32ps
15+
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
16+
//CHECK-DAG: call x86_amx @llvm.x86.tmmultf32ps.internal
17+
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
18+
__tile_mmultf32ps(&c, a, b);
19+
}
20+
21+
void test_tile_tmmultf32ps(__tile1024i a, __tile1024i b, __tile1024i c) {
22+
//CHECK-LABEL: @test_tile_tmmultf32ps
23+
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
24+
//CHECK-DAG: call x86_amx @llvm.x86.ttmmultf32ps.internal
25+
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
26+
__tile_tmmultf32ps(&c, a, b);
27+
}

0 commit comments

Comments
 (0)