Skip to content

Commit 3f9b733

Browse files
committed
[X86][AMX] Support AMX-TRANSPOSE
Ref.: https://cdrdv2.intel.com/v1/dl/getContent/671368
1 parent d1fae59 commit 3f9b733

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+2751
-120
lines changed

clang/docs/ReleaseNotes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,8 @@ X86 Support
623623

624624
- All intrinsics in tbmintrin.h can now be used in constant expressions.
625625

626+
- Support ISA of ``AMX-TRANSPOSE``.
627+
626628
Arm and AArch64 Support
627629
^^^^^^^^^^^^^^^^^^^^^^^
628630

clang/include/clang/Basic/BuiltinsX86_64.def

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i",
128128
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp16")
129129
TARGET_BUILTIN(__builtin_ia32_tcmmimfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex")
130130
TARGET_BUILTIN(__builtin_ia32_tcmmrlfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-complex")
131+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
132+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
133+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
134+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1_internal, "vUsUsUsV256i*V256i*vC*z", "n", "amx-transpose")
135+
TARGET_BUILTIN(__builtin_ia32_ttransposed_internal, "V256iUsUsV256i", "n", "amx-transpose")
131136
// AMX
132137
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
133138
TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")
@@ -148,6 +153,12 @@ TARGET_BUILTIN(__builtin_ia32_ptwrite64, "vUOi", "n", "ptwrite")
148153
TARGET_BUILTIN(__builtin_ia32_tcmmimfp16ps, "vIUcIUcIUc", "n", "amx-complex")
149154
TARGET_BUILTIN(__builtin_ia32_tcmmrlfp16ps, "vIUcIUcIUc", "n", "amx-complex")
150155

156+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0, "vIUcvC*z", "n", "amx-transpose")
157+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz0t1, "vIUcvC*z", "n","amx-transpose")
158+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1, "vIUcvC*z", "n", "amx-transpose")
159+
TARGET_BUILTIN(__builtin_ia32_t2rpntlvwz1t1, "vIUcvC*z", "n","amx-transpose")
160+
TARGET_BUILTIN(__builtin_ia32_ttransposed, "vIUcIUc", "n", "amx-transpose")
161+
151162
TARGET_BUILTIN(__builtin_ia32_prefetchi, "vvC*Ui", "nc", "prefetchi")
152163
TARGET_BUILTIN(__builtin_ia32_cmpccxadd32, "Siv*SiSiIi", "n", "cmpccxadd")
153164
TARGET_BUILTIN(__builtin_ia32_cmpccxadd64, "SLLiv*SLLiSLLiIi", "n", "cmpccxadd")

clang/include/clang/Driver/Options.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6287,6 +6287,8 @@ def mamx_int8 : Flag<["-"], "mamx-int8">, Group<m_x86_Features_Group>;
62876287
def mno_amx_int8 : Flag<["-"], "mno-amx-int8">, Group<m_x86_Features_Group>;
62886288
def mamx_tile : Flag<["-"], "mamx-tile">, Group<m_x86_Features_Group>;
62896289
def mno_amx_tile : Flag<["-"], "mno-amx-tile">, Group<m_x86_Features_Group>;
6290+
def mamx_transpose : Flag<["-"], "mamx-transpose">, Group<m_x86_Features_Group>;
6291+
def mno_amx_transpose : Flag<["-"], "mno-amx-transpose">, Group<m_x86_Features_Group>;
62906292
def mcmpccxadd : Flag<["-"], "mcmpccxadd">, Group<m_x86_Features_Group>;
62916293
def mno_cmpccxadd : Flag<["-"], "mno-cmpccxadd">, Group<m_x86_Features_Group>;
62926294
def msse : Flag<["-"], "msse">, Group<m_x86_Features_Group>;

clang/lib/Basic/Targets/X86.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
418418
HasAMXTILE = true;
419419
} else if (Feature == "+amx-complex") {
420420
HasAMXCOMPLEX = true;
421+
} else if (Feature == "+amx-transpose") {
422+
HasAMXTRANSPOSE = true;
421423
} else if (Feature == "+cmpccxadd") {
422424
HasCMPCCXADD = true;
423425
} else if (Feature == "+raoint") {
@@ -935,6 +937,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
935937
Builder.defineMacro("__AMX_FP16__");
936938
if (HasAMXCOMPLEX)
937939
Builder.defineMacro("__AMX_COMPLEX__");
940+
if (HasAMXTRANSPOSE)
941+
Builder.defineMacro("__AMX_TRANSPOSE__");
938942
if (HasCMPCCXADD)
939943
Builder.defineMacro("__CMPCCXADD__");
940944
if (HasRAOINT)
@@ -1065,6 +1069,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
10651069
.Case("amx-fp16", true)
10661070
.Case("amx-int8", true)
10671071
.Case("amx-tile", true)
1072+
.Case("amx-transpose", true)
10681073
.Case("avx", true)
10691074
.Case("avx10.1-256", true)
10701075
.Case("avx10.1-512", true)
@@ -1182,6 +1187,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
11821187
.Case("amx-fp16", HasAMXFP16)
11831188
.Case("amx-int8", HasAMXINT8)
11841189
.Case("amx-tile", HasAMXTILE)
1190+
.Case("amx-transpose", HasAMXTRANSPOSE)
11851191
.Case("avx", SSELevel >= AVX)
11861192
.Case("avx10.1-256", HasAVX10_1)
11871193
.Case("avx10.1-512", HasAVX10_1_512)

clang/lib/Basic/Targets/X86.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
156156
bool HasAMXINT8 = false;
157157
bool HasAMXBF16 = false;
158158
bool HasAMXCOMPLEX = false;
159+
bool HasAMXTRANSPOSE = false;
159160
bool HasSERIALIZE = false;
160161
bool HasTSXLDTRK = false;
161162
bool HasUSERMSR = false;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16920,6 +16920,58 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
1692016920
// instruction, but it will create a memset that won't be optimized away.
1692116921
return Builder.CreateMemSet(Ops[0], Ops[1], Ops[2], Align(1), true);
1692216922
}
16923+
// Corresponding to intrinsics which will return 2 tiles (tile0_tile1).
16924+
case X86::BI__builtin_ia32_t2rpntlvwz0_internal:
16925+
case X86::BI__builtin_ia32_t2rpntlvwz0t1_internal:
16926+
case X86::BI__builtin_ia32_t2rpntlvwz1_internal:
16927+
case X86::BI__builtin_ia32_t2rpntlvwz1t1_internal: {
16928+
Intrinsic::ID IID;
16929+
switch (BuiltinID) {
16930+
default:
16931+
llvm_unreachable("Unsupported intrinsic!");
16932+
case X86::BI__builtin_ia32_t2rpntlvwz0_internal:
16933+
IID = Intrinsic::x86_t2rpntlvwz0_internal;
16934+
break;
16935+
case X86::BI__builtin_ia32_t2rpntlvwz0t1_internal:
16936+
IID = Intrinsic::x86_t2rpntlvwz0t1_internal;
16937+
break;
16938+
case X86::BI__builtin_ia32_t2rpntlvwz1_internal:
16939+
IID = Intrinsic::x86_t2rpntlvwz1_internal;
16940+
break;
16941+
case X86::BI__builtin_ia32_t2rpntlvwz1t1_internal:
16942+
IID = Intrinsic::x86_t2rpntlvwz1t1_internal;
16943+
break;
16944+
}
16945+
16946+
// Ops = (Row0, Col0, Col1, DstPtr0, DstPtr1, SrcPtr, Stride)
16947+
Value *Call = Builder.CreateCall(CGM.getIntrinsic(IID),
16948+
{Ops[0], Ops[1], Ops[2], Ops[5], Ops[6]});
16949+
16950+
auto *PtrTy = E->getArg(3)->getType()->getAs<PointerType>();
16951+
assert(PtrTy && "arg3 must be of pointer type");
16952+
QualType PtreeTy = PtrTy->getPointeeType();
16953+
llvm::Type *TyPtee = ConvertType(PtreeTy);
16954+
16955+
// Bitcast amx type (x86_amx) to vector type (256 x i32)
16956+
// Then store tile0 into DstPtr0
16957+
Value *T0 = Builder.CreateExtractValue(Call, 0);
16958+
Value *VecT0 = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
16959+
{TyPtee}, {T0});
16960+
Builder.CreateDefaultAlignedStore(VecT0, Ops[3]);
16961+
16962+
// Then store tile1 into DstPtr1
16963+
Value *T1 = Builder.CreateExtractValue(Call, 1);
16964+
Value *VecT1 = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
16965+
{TyPtee}, {T1});
16966+
Value *Store = Builder.CreateDefaultAlignedStore(VecT1, Ops[4]);
16967+
16968+
// Note: We deliberately avoid using x86_tilestored64_internal to store
16969+
// the results, because it cannot guarantee the scope of the memory
16970+
// written. Using it may cause shape reloads after the first AMX intrinsic,
16971+
// which the current AMX register allocation is not able to handle.
16972+
16973+
return Store;
16974+
}
1692316975
case X86::BI__ud2:
1692416976
// llvm.trap makes a ud2a instruction on x86.
1692516977
return EmitTrapCall(Intrinsic::trap);

clang/lib/Headers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ set(x86_files
149149
amxcomplexintrin.h
150150
amxfp16intrin.h
151151
amxintrin.h
152+
amxtransposeintrin.h
152153
avx10_2_512bf16intrin.h
153154
avx10_2_512convertintrin.h
154155
avx10_2_512minmaxintrin.h

clang/lib/Headers/amxintrin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ static __inline__ void __DEFAULT_FN_ATTRS_TILE _tile_release(void) {
232232
/// bytes. Since there is no 2D type in llvm IR, we use vector type to
233233
/// represent 2D tile and the fixed size is maximum amx tile register size.
234234
typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
235+
typedef int _tile1024i_1024a
236+
__attribute__((__vector_size__(1024), __aligned__(1024)));
235237

236238
/// This is internal intrinsic. C/C++ user should avoid calling it directly.
237239
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8

0 commit comments

Comments
 (0)