30
30
31
31
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32
32
include "mlir/Interfaces/SideEffectInterfaces.td"
33
+ include "mlir/IR/AttrTypeBase.td"
34
+ include "mlir/IR/BuiltinTypes.td"
33
35
34
36
//===----------------------------------------------------------------------===//
35
37
// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
55
57
For details, see the Intel documentation:
56
58
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
57
59
}];
60
+ let useDefaultTypePrinterParser = 1;
58
61
}
59
62
63
+ //===----------------------------------------------------------------------===//
64
+ // AMX Tile definition.
65
+ //===----------------------------------------------------------------------===//
66
+
67
+ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68
+ : TypeDef<AMX_Dialect, typeName, traits> {
69
+ let mnemonic = typeMnemonic;
70
+ }
71
+
72
+ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73
+ let cppFunctionName = "isValidTileTypeElementType";
74
+ }
75
+
76
+ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
77
+ let summary = "AMX 2D tile to be used by AMX opertaions.";
78
+
79
+ let description = [{
80
+ This type is used to represent values in AMX tile registers. All AMX operations
81
+ work on AMX tiles and these tiles cannot be used in other operations directly.
82
+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83
+ element type for IR verification and lowering to LLVMIR dialect.
84
+ }];
85
+
86
+ let parameters = (ins
87
+ ArrayRefParameter<"int64_t">:$shape,
88
+ AMX_TileTypeElementType:$elementType
89
+ );
90
+
91
+ let builders = [
92
+ TypeBuilderWithInferredContext<(ins
93
+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94
+ return $_get(elementType.getContext(), shape, elementType);
95
+ }]>
96
+ ];
97
+
98
+ let extraClassDeclaration = [{
99
+ /// Returns if this type is ranked (always true).
100
+ bool hasRank() const { return true; }
101
+
102
+ /// Clone this tile type with the given shape and element type. If the
103
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
104
+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105
+ Type elementType) const {
106
+ return get(shape.value_or(getShape()), elementType);
107
+ }
108
+ }];
109
+
110
+ let hasCustomAssemblyFormat = 1;
111
+ let skipDefaultBuilders = 1;
112
+ }
113
+
114
+ def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
115
+
116
+ def IsAMX2DTilePred : And<[IsAMXTilePred,
117
+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
118
+
119
+ class AMX2DTileOf<list<Type> allowedTypes> :
120
+ ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
121
+ "::mlir::amx::TileType">;
122
+
60
123
//===----------------------------------------------------------------------===//
61
124
// AMX Op and IntrOp definitions.
62
125
//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88
151
Example:
89
152
90
153
```mlir
91
- %0 = amx.tile_zero : vector <16x16xbf16>
154
+ %0 = amx.tile_zero : <16x16xbf16>
92
155
```
93
156
}];
94
157
let results = (outs
95
- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$res);
158
+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$res);
96
159
let extraClassDeclaration = [{
97
- VectorType getVectorType () {
98
- return ::llvm::cast<VectorType >(getRes().getType());
160
+ TileType getTileType () {
161
+ return ::llvm::cast<TileType >(getRes().getType());
99
162
}
100
163
}];
101
164
let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117
180
Example:
118
181
119
182
```mlir
120
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector <16x64xi8>
183
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
121
184
```
122
185
}];
123
186
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124
187
Variadic<Index>:$indices);
125
188
let results = (outs
126
- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$res);
189
+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$res);
127
190
let extraClassDeclaration = [{
128
191
MemRefType getMemRefType() {
129
192
return ::llvm::cast<MemRefType>(getBase().getType());
130
193
}
131
- VectorType getVectorType () {
132
- return ::llvm::cast<VectorType >(getRes().getType());
194
+ TileType getTileType () {
195
+ return ::llvm::cast<TileType >(getRes().getType());
133
196
}
134
197
}];
135
198
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148
211
Example:
149
212
150
213
```mlir
151
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector <16x64xi8>
214
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
152
215
```
153
216
}];
154
217
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155
218
Variadic<Index>:$indices,
156
- VectorOfRankAndType<[2], [F32 , BF16, I32, I8]>:$val);
219
+ AMX2DTileOf<[F32, F16 , BF16, I32, I8]>:$val);
157
220
let extraClassDeclaration = [{
158
221
MemRefType getMemRefType() {
159
222
return ::llvm::cast<MemRefType>(getBase().getType());
160
223
}
161
- VectorType getVectorType () {
162
- return ::llvm::cast<VectorType >(getVal().getType());
224
+ TileType getTileType () {
225
+ return ::llvm::cast<TileType >(getVal().getType());
163
226
}
164
227
}];
165
228
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
183
246
Example:
184
247
185
248
```mlir
186
- %0 = amx.tile_mulf %a, %b, %c
187
- : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
249
+ %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
188
250
```
189
251
}];
190
- let arguments = (ins VectorOfRankAndType<[2], [F32 , BF16]>:$lhs,
191
- VectorOfRankAndType<[2], [F32 , BF16]>:$rhs,
192
- VectorOfRankAndType<[2], [ F32, BF16 ]>:$acc);
193
- let results = (outs VectorOfRankAndType<[2], [ F32, BF16 ]>:$res);
252
+ let arguments = (ins AMX2DTileOf<[F16 , BF16]>:$lhs,
253
+ AMX2DTileOf<[F16 , BF16]>:$rhs,
254
+ AMX2DTileOf<[ F32]>:$acc);
255
+ let results = (outs AMX2DTileOf<[ F32]>:$res);
194
256
let extraClassDeclaration = [{
195
- VectorType getLhsVectorType () {
196
- return ::llvm::cast<VectorType >(getLhs().getType());
257
+ TileType getLhsTileType () {
258
+ return ::llvm::cast<TileType >(getLhs().getType());
197
259
}
198
- VectorType getRhsVectorType () {
199
- return ::llvm::cast<VectorType >(getRhs().getType());
260
+ TileType getRhsTileType () {
261
+ return ::llvm::cast<TileType >(getRhs().getType());
200
262
}
201
- VectorType getVectorType () {
202
- return ::llvm::cast<VectorType >(getRes().getType());
263
+ TileType getTileType () {
264
+ return ::llvm::cast<TileType >(getRes().getType());
203
265
}
204
266
}];
205
267
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
222
284
Example:
223
285
224
286
```mlir
225
- %0 = amx.tile_muli %a zext, %b zext, %c
226
- : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
287
+ %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
227
288
```
228
289
}];
229
- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
230
- VectorOfRankAndType<[2], [I32, I8]>:$rhs,
231
- VectorOfRankAndType<[2], [ I32, I8 ]>:$acc,
290
+ let arguments = (ins AMX2DTileOf<[ I8]>:$lhs,
291
+ AMX2DTileOf<[ I8]>:$rhs,
292
+ AMX2DTileOf<[ I32]>:$acc,
232
293
UnitAttr:$isZextLhs,
233
294
UnitAttr:$isZextRhs
234
295
);
235
- let results = (outs VectorOfRankAndType<[2], [ I32, I8 ]>:$res);
296
+ let results = (outs AMX2DTileOf<[ I32]>:$res);
236
297
let extraClassDeclaration = [{
237
- VectorType getLhsVectorType () {
238
- return ::llvm::cast<VectorType >(getLhs().getType());
298
+ TileType getLhsTileType () {
299
+ return ::llvm::cast<TileType >(getLhs().getType());
239
300
}
240
- VectorType getRhsVectorType () {
241
- return ::llvm::cast<VectorType >(getRhs().getType());
301
+ TileType getRhsTileType () {
302
+ return ::llvm::cast<TileType >(getRhs().getType());
242
303
}
243
- VectorType getVectorType () {
244
- return ::llvm::cast<VectorType >(getRes().getType());
304
+ TileType getTileType () {
305
+ return ::llvm::cast<TileType >(getRes().getType());
245
306
}
246
307
}];
247
308
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286
347
AnyInteger,
287
348
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288
349
350
+ // Dot product of f16 tiles into f32 tile.
351
+ def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
352
+ Arguments<(ins AnyInteger,
353
+ AnyInteger,
354
+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
355
+
289
356
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290
357
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291
358
Arguments<(ins AnyInteger,
0 commit comments