Skip to content

Commit 2a76b2d

Browse files
committed
Utilize x86_amx type for AMX dialect in MLIR.
Signed-off-by: Ilya Enkovich <[email protected]>
1 parent e577f14 commit 2a76b2d

File tree

15 files changed

+287
-126
lines changed

15 files changed

+287
-126
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
1616

1717
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18+
#include "mlir/Dialect/AMX/AMXDialect.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021

@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
258259
/// Convert a 1D vector type into an LLVM vector type.
259260
FailureOr<Type> convertVectorType(VectorType type) const;
260261

262+
/// Convert AMX tile type x86_amx type.
263+
Type convertAMXTileType(amx::TileType type) const;
264+
261265
/// Options for customizing the llvm lowering.
262266
LowerToLLVMOptions options;
263267

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 103 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232
include "mlir/Interfaces/SideEffectInterfaces.td"
33+
include "mlir/IR/AttrTypeBase.td"
34+
include "mlir/IR/BuiltinTypes.td"
3335

3436
//===----------------------------------------------------------------------===//
3537
// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
5557
For details, see the Intel documentation:
5658
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
5759
}];
60+
let useDefaultTypePrinterParser = 1;
5861
}
5962

63+
//===----------------------------------------------------------------------===//
64+
// AMX Tile definition.
65+
//===----------------------------------------------------------------------===//
66+
67+
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68+
: TypeDef<AMX_Dialect, typeName, traits> {
69+
let mnemonic = typeMnemonic;
70+
}
71+
72+
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73+
let cppFunctionName = "isValidTileTypeElementType";
74+
}
75+
76+
def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
77+
let summary = "AMX 2D tile to be used by AMX opertaions.";
78+
79+
let description = [{
80+
This type is used to represent values in AMX tile registers. All AMX operations
81+
work on AMX tiles and these tiles cannot be used in other operations directly.
82+
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83+
element type for IR verification and lowering to LLVMIR dialect.
84+
}];
85+
86+
let parameters = (ins
87+
ArrayRefParameter<"int64_t">:$shape,
88+
AMX_TileTypeElementType:$elementType
89+
);
90+
91+
let builders = [
92+
TypeBuilderWithInferredContext<(ins
93+
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94+
return $_get(elementType.getContext(), shape, elementType);
95+
}]>
96+
];
97+
98+
let extraClassDeclaration = [{
99+
/// Returns if this type is ranked (always true).
100+
bool hasRank() const { return true; }
101+
102+
/// Clone this tile type with the given shape and element type. If the
103+
/// provided shape is `std::nullopt`, the current shape of the type is used.
104+
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105+
Type elementType) const {
106+
return get(shape.value_or(getShape()), elementType);
107+
}
108+
}];
109+
110+
let hasCustomAssemblyFormat = 1;
111+
let skipDefaultBuilders = 1;
112+
}
113+
114+
def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
115+
116+
def IsAMX2DTilePred : And<[IsAMXTilePred,
117+
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
118+
119+
class AMX2DTileOf<list<Type> allowedTypes> :
120+
ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
121+
"::mlir::amx::TileType">;
122+
60123
//===----------------------------------------------------------------------===//
61124
// AMX Op and IntrOp definitions.
62125
//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88151
Example:
89152

90153
```mlir
91-
%0 = amx.tile_zero : vector<16x16xbf16>
154+
%0 = amx.tile_zero : <16x16xbf16>
92155
```
93156
}];
94157
let results = (outs
95-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
158+
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
96159
let extraClassDeclaration = [{
97-
VectorType getVectorType() {
98-
return ::llvm::cast<VectorType>(getRes().getType());
160+
TileType getTileType() {
161+
return ::llvm::cast<TileType>(getRes().getType());
99162
}
100163
}];
101164
let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117180
Example:
118181

119182
```mlir
120-
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
183+
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
121184
```
122185
}];
123186
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124187
Variadic<Index>:$indices);
125188
let results = (outs
126-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
189+
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
127190
let extraClassDeclaration = [{
128191
MemRefType getMemRefType() {
129192
return ::llvm::cast<MemRefType>(getBase().getType());
130193
}
131-
VectorType getVectorType() {
132-
return ::llvm::cast<VectorType>(getRes().getType());
194+
TileType getTileType() {
195+
return ::llvm::cast<TileType>(getRes().getType());
133196
}
134197
}];
135198
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148211
Example:
149212

150213
```mlir
151-
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
214+
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
152215
```
153216
}];
154217
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155218
Variadic<Index>:$indices,
156-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
219+
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
157220
let extraClassDeclaration = [{
158221
MemRefType getMemRefType() {
159222
return ::llvm::cast<MemRefType>(getBase().getType());
160223
}
161-
VectorType getVectorType() {
162-
return ::llvm::cast<VectorType>(getVal().getType());
224+
TileType getTileType() {
225+
return ::llvm::cast<TileType>(getVal().getType());
163226
}
164227
}];
165228
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
183246
Example:
184247

185248
```mlir
186-
%0 = amx.tile_mulf %a, %b, %c
187-
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
249+
%0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
188250
```
189251
}];
190-
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
191-
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
192-
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
193-
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
252+
let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
253+
AMX2DTileOf<[F16, BF16]>:$rhs,
254+
AMX2DTileOf<[F32]>:$acc);
255+
let results = (outs AMX2DTileOf<[F32]>:$res);
194256
let extraClassDeclaration = [{
195-
VectorType getLhsVectorType() {
196-
return ::llvm::cast<VectorType>(getLhs().getType());
257+
TileType getLhsTileType() {
258+
return ::llvm::cast<TileType>(getLhs().getType());
197259
}
198-
VectorType getRhsVectorType() {
199-
return ::llvm::cast<VectorType>(getRhs().getType());
260+
TileType getRhsTileType() {
261+
return ::llvm::cast<TileType>(getRhs().getType());
200262
}
201-
VectorType getVectorType() {
202-
return ::llvm::cast<VectorType>(getRes().getType());
263+
TileType getTileType() {
264+
return ::llvm::cast<TileType>(getRes().getType());
203265
}
204266
}];
205267
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
222284
Example:
223285

224286
```mlir
225-
%0 = amx.tile_muli %a zext, %b zext, %c
226-
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
287+
%0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
227288
```
228289
}];
229-
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
230-
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
231-
VectorOfRankAndType<[2], [I32, I8]>:$acc,
290+
let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
291+
AMX2DTileOf<[I8]>:$rhs,
292+
AMX2DTileOf<[I32]>:$acc,
232293
UnitAttr:$isZextLhs,
233294
UnitAttr:$isZextRhs
234295
);
235-
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
296+
let results = (outs AMX2DTileOf<[I32]>:$res);
236297
let extraClassDeclaration = [{
237-
VectorType getLhsVectorType() {
238-
return ::llvm::cast<VectorType>(getLhs().getType());
298+
TileType getLhsTileType() {
299+
return ::llvm::cast<TileType>(getLhs().getType());
239300
}
240-
VectorType getRhsVectorType() {
241-
return ::llvm::cast<VectorType>(getRhs().getType());
301+
TileType getRhsTileType() {
302+
return ::llvm::cast<TileType>(getRhs().getType());
242303
}
243-
VectorType getVectorType() {
244-
return ::llvm::cast<VectorType>(getRes().getType());
304+
TileType getTileType() {
305+
return ::llvm::cast<TileType>(getRes().getType());
245306
}
246307
}];
247308
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286347
AnyInteger,
287348
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288349

350+
// Dot product of f16 tiles into f32 tile.
351+
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
352+
Arguments<(ins AnyInteger,
353+
AnyInteger,
354+
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
355+
289356
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290357
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291358
Arguments<(ins AnyInteger,

mlir/include/mlir/Dialect/AMX/AMXDialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
2323

24+
#define GET_TYPEDEF_CLASSES
25+
#include "mlir/Dialect/AMX/AMXTypes.h.inc"
26+
2427
#define GET_OP_CLASSES
2528
#include "mlir/Dialect/AMX/AMX.h.inc"
2629

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
250250
}];
251251
}
252252

253+
//===----------------------------------------------------------------------===//
254+
// LLVMX86AMXType
255+
//===----------------------------------------------------------------------===//
256+
257+
def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
258+
let summary = "LLVM x86_amx type.";
259+
let description = [{
260+
The x86_amx type represents a value held in an AMX tile register on an x86
261+
machine. Can only be used in AMX intrinsics calls.
262+
}];
263+
}
264+
253265
#endif // LLVMTYPES_TD

mlir/lib/Conversion/LLVMCommon/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
1212
Core
1313

1414
LINK_LIBS PUBLIC
15+
MLIRAMXDialect
1516
MLIRIR
1617
MLIRLLVMDialect
1718
MLIRSupport

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
6767
return std::nullopt;
6868
return llvmType;
6969
});
70+
addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
7071

7172
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
7273
// before the conversions below since conversions are attempted in reverse
@@ -594,6 +595,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
594595
return vectorType;
595596
}
596597

598+
/// Convert an AMX tile type to LLVM x86_amx type.
599+
/// Shape and element type of the tile are ignored.
600+
Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
601+
return LLVM::LLVMX86AMXType::get(&getContext());
602+
}
603+
597604
/// Convert a type in the context of the default or bare pointer calling
598605
/// convention. Calling convention sensitive types, such as MemRefType and
599606
/// UnrankedMemRefType, are converted following the specific rules for the

0 commit comments

Comments
 (0)