Skip to content

Commit 2f743ac

Browse files
authored
[MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. (#111197)
This patch is intended to resolve #109481 and improve the usability of the AMX dialect. In LLVM IR, AMX intrinsics use `x86_amx` which is one of the primitive types. This type is supposed to be used for AMX intrinsic calls and no other operations. AMX dialect of MLIR uses regular 2D vector types, which are then lowered to arrays of vectors in the LLVMIR dialect. This creates an inconsistency in the types used in the LLVMIR dialect and LLVMIR. Translation of AMX intrinsic calls to LLVM IR doesn't require result types to match and that is where tile loads and mul operation results get `x86_amx` type. This works in very simple cases when mul and tile store operations directly consume the result of another AMX intrinsic call, but it doesn't work when an argument is a block argument (phi node). In addition to translation problems, this inconsistency between types used in MLIR and LLVM IR makes MLIR verification and transformation quite problematic. Both `amx.tileload` and `vector::transfer_read` can load values of the same type, but only one of them can be used in AMX operations. In general, by looking at a type of value, we cannot determine if it can only be used for AMX operations or contrary can be used in other operations but AMX ones. To remove this inconsistency and make AMX operations more explicit in their limitations, I propose to add `LLVMX86AMXType` type to the LLVMIR dialect to match `x86_amx` type in LLVM IR, and introduce `amx::TileType` to be used by AMX operations in MLIR. This resolves translation problems for AMX usage with phi nodes and provides proper type verification in MLIR for AMX operations. P.S. This patch also adds missing FP16 support. It's trivial but unrelated to type system changes, so let me know if I should submit it separately. --------- Signed-off-by: Ilya Enkovich <[email protected]>
1 parent 44ab380 commit 2f743ac

File tree

14 files changed

+329
-140
lines changed

14 files changed

+329
-140
lines changed

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 117 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232
include "mlir/Interfaces/SideEffectInterfaces.td"
33+
include "mlir/IR/AttrTypeBase.td"
34+
include "mlir/IR/BuiltinTypes.td"
3335

3436
//===----------------------------------------------------------------------===//
3537
// AMX dialect definition.
@@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
5557
For details, see the Intel documentation:
5658
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
5759
}];
60+
let useDefaultTypePrinterParser = 1;
5861
}
5962

63+
//===----------------------------------------------------------------------===//
64+
// AMX Tile definition.
65+
//===----------------------------------------------------------------------===//
66+
67+
class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68+
: TypeDef<AMX_Dialect, typeName, traits> {
69+
let mnemonic = typeMnemonic;
70+
}
71+
72+
def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73+
let cppFunctionName = "isValidTileTypeElementType";
74+
}
75+
76+
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
77+
let summary = "AMX 2D tile to be used by AMX opertaions.";
78+
79+
let description = [{
80+
This type is used to represent values in AMX tile registers. All AMX operations
81+
work on AMX tiles and these tiles cannot be used in other operations directly.
82+
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83+
element type for IR verification and lowering to LLVMIR dialect.
84+
}];
85+
86+
let parameters = (ins
87+
ArrayRefParameter<"int64_t">:$shape,
88+
AMX_TileTypeElementType:$elementType
89+
);
90+
91+
let builders = [
92+
TypeBuilderWithInferredContext<(ins
93+
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94+
return $_get(elementType.getContext(), shape, elementType);
95+
}]>
96+
];
97+
98+
let extraClassDeclaration = [{
99+
/// Returns if this type is ranked (always true).
100+
bool hasRank() const { return true; }
101+
102+
/// Clone this tile type with the given shape and element type. If the
103+
/// provided shape is `std::nullopt`, the current shape of the type is used.
104+
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105+
Type elementType) const {
106+
return get(shape.value_or(getShape()), elementType);
107+
}
108+
}];
109+
110+
let hasCustomAssemblyFormat = 1;
111+
let skipDefaultBuilders = 1;
112+
}
113+
114+
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
115+
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
116+
117+
class AMXTileOf<list<Type> allowedTypes> :
118+
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
119+
"::mlir::amx::TileType">;
120+
121+
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
122+
123+
def AMXTileF32 : AMXTileOf<[F32]>;
124+
125+
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
126+
127+
def AMXTileI32 : AMXTileOf<[I32]>;
128+
129+
def AMXTileI8 : AMXTileOf<[I8]>;
130+
60131
//===----------------------------------------------------------------------===//
61132
// AMX Op and IntrOp definitions.
62133
//===----------------------------------------------------------------------===//
@@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88159
Example:
89160

90161
```mlir
91-
%0 = amx.tile_zero : vector<16x16xbf16>
162+
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
92163
```
93164
}];
94-
let results = (outs
95-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
165+
let results = (outs AnyAMXTile:$res);
96166
let extraClassDeclaration = [{
97-
VectorType getVectorType() {
98-
return ::llvm::cast<VectorType>(getRes().getType());
167+
TileType getTileType() {
168+
return ::llvm::cast<TileType>(getRes().getType());
99169
}
100170
}];
101-
let assemblyFormat = "attr-dict `:` type($res)";
171+
let assemblyFormat = "attr-dict `:` qualified(type($res))";
102172
let hasVerifier = 1;
103173
}
104174

@@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117187
Example:
118188

119189
```mlir
120-
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
190+
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
121191
```
122192
}];
123193
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124194
Variadic<Index>:$indices);
125-
let results = (outs
126-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
195+
let results = (outs AnyAMXTile:$res);
127196
let extraClassDeclaration = [{
128197
MemRefType getMemRefType() {
129198
return ::llvm::cast<MemRefType>(getBase().getType());
130199
}
131-
VectorType getVectorType() {
132-
return ::llvm::cast<VectorType>(getRes().getType());
200+
TileType getTileType() {
201+
return ::llvm::cast<TileType>(getRes().getType());
133202
}
134203
}];
135204
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
136-
"type($base) `into` type($res)";
205+
"type($base) `into` qualified(type($res))";
137206
let hasVerifier = 1;
138207
}
139208

@@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148217
Example:
149218

150219
```mlir
151-
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
220+
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
152221
```
153222
}];
154223
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155224
Variadic<Index>:$indices,
156-
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
225+
AnyAMXTile:$val);
157226
let extraClassDeclaration = [{
158227
MemRefType getMemRefType() {
159228
return ::llvm::cast<MemRefType>(getBase().getType());
160229
}
161-
VectorType getVectorType() {
162-
return ::llvm::cast<VectorType>(getVal().getType());
230+
TileType getTileType() {
231+
return ::llvm::cast<TileType>(getVal().getType());
163232
}
164233
}];
165234
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
166-
"type($base) `,` type($val)";
235+
"type($base) `,` qualified(type($val))";
167236
let hasVerifier = 1;
168237
}
169238

@@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
184253

185254
```mlir
186255
%0 = amx.tile_mulf %a, %b, %c
187-
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
256+
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
188257
```
189258
}];
190-
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
191-
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
192-
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
193-
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
259+
let arguments = (ins AMXTileF16OrBF16:$lhs,
260+
AMXTileF16OrBF16:$rhs,
261+
AMXTileF32:$acc);
262+
let results = (outs AMXTileF32:$res);
194263
let extraClassDeclaration = [{
195-
VectorType getLhsVectorType() {
196-
return ::llvm::cast<VectorType>(getLhs().getType());
264+
TileType getLhsTileType() {
265+
return ::llvm::cast<TileType>(getLhs().getType());
197266
}
198-
VectorType getRhsVectorType() {
199-
return ::llvm::cast<VectorType>(getRhs().getType());
267+
TileType getRhsTileType() {
268+
return ::llvm::cast<TileType>(getRhs().getType());
200269
}
201-
VectorType getVectorType() {
202-
return ::llvm::cast<VectorType>(getRes().getType());
270+
TileType getTileType() {
271+
return ::llvm::cast<TileType>(getRes().getType());
203272
}
204273
}];
205274
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
206-
"type($lhs) `,` type($rhs) `,` type($acc) ";
275+
"qualified(type($lhs)) `,` qualified(type($rhs))"
276+
" `,` qualified(type($acc)) ";
207277
let hasVerifier = 1;
208278
}
209279

@@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [
223293

224294
```mlir
225295
%0 = amx.tile_muli %a zext, %b zext, %c
226-
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
296+
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
227297
```
228298
}];
229-
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
230-
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
231-
VectorOfRankAndType<[2], [I32, I8]>:$acc,
299+
let arguments = (ins AMXTileI8:$lhs,
300+
AMXTileI8:$rhs,
301+
AMXTileI32:$acc,
232302
UnitAttr:$isZextLhs,
233303
UnitAttr:$isZextRhs
234304
);
235-
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
305+
let results = (outs AMXTileI32:$res);
236306
let extraClassDeclaration = [{
237-
VectorType getLhsVectorType() {
238-
return ::llvm::cast<VectorType>(getLhs().getType());
307+
TileType getLhsTileType() {
308+
return ::llvm::cast<TileType>(getLhs().getType());
239309
}
240-
VectorType getRhsVectorType() {
241-
return ::llvm::cast<VectorType>(getRhs().getType());
310+
TileType getRhsTileType() {
311+
return ::llvm::cast<TileType>(getRhs().getType());
242312
}
243-
VectorType getVectorType() {
244-
return ::llvm::cast<VectorType>(getRes().getType());
313+
TileType getTileType() {
314+
return ::llvm::cast<TileType>(getRes().getType());
245315
}
246316
}];
247317
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
248-
"type($lhs) `,` type($rhs) `,` type($acc) ";
318+
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
249319
let hasVerifier = 1;
250320
}
251321

@@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286356
AnyInteger,
287357
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288358

359+
// Dot product of f16 tiles into f32 tile.
360+
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361+
Arguments<(ins AnyInteger,
362+
AnyInteger,
363+
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364+
289365
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290366
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291367
Arguments<(ins AnyInteger,

mlir/include/mlir/Dialect/AMX/AMXDialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
2323

24+
#define GET_TYPEDEF_CLASSES
25+
#include "mlir/Dialect/AMX/AMXTypes.h.inc"
26+
2427
#define GET_OP_CLASSES
2528
#include "mlir/Dialect/AMX/AMX.h.inc"
2629

mlir/include/mlir/Dialect/AMX/Transforms.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,20 @@ namespace mlir {
1414
class LLVMConversionTarget;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
17+
class DialectRegistry;
1718

1819
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
1920
/// intrinsics.
20-
void populateAMXLegalizeForLLVMExportPatterns(
21-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
21+
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
22+
RewritePatternSet &patterns);
2223

2324
/// Configure the target to support lowering AMX ops to ops that map to LLVM
2425
/// intrinsics.
2526
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
2627

28+
/// Register LLVM conversion interface for AMX dialect.
29+
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
30+
2731
} // namespace mlir
2832

2933
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
250250
}];
251251
}
252252

253+
//===----------------------------------------------------------------------===//
254+
// LLVMX86AMXType
255+
//===----------------------------------------------------------------------===//
256+
257+
def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
258+
let summary = "LLVM x86_amx type.";
259+
let description = [{
260+
The x86_amx type represents a value held in an AMX tile register on an x86
261+
machine. Can only be used in AMX intrinsics calls.
262+
}];
263+
}
264+
253265
#endif // LLVMTYPES_TD

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
2525
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
2626
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
27+
#include "mlir/Dialect/AMX/Transforms.h"
2728
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
2829
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
2930
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7071
registerConvertNVVMToLLVMInterface(registry);
7172
registerConvertOpenMPToLLVMInterface(registry);
7273
ub::registerConvertUBToLLVMInterface(registry);
74+
registerConvertAMXToLLVMInterface(registry);
7375

7476
// Register all transform dialect extensions.
7577
affine::registerTransformDialectExtension(registry);

0 commit comments

Comments
 (0)