30
30
31
31
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32
32
include "mlir/Interfaces/SideEffectInterfaces.td"
33
+ include "mlir/IR/AttrTypeBase.td"
34
+ include "mlir/IR/BuiltinTypes.td"
33
35
34
36
//===----------------------------------------------------------------------===//
35
37
// AMX dialect definition.
@@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
55
57
For details, see the Intel documentation:
56
58
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
57
59
}];
60
+ let useDefaultTypePrinterParser = 1;
58
61
}
59
62
63
+ //===----------------------------------------------------------------------===//
64
+ // AMX Tile definition.
65
+ //===----------------------------------------------------------------------===//
66
+
67
+ class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
68
+ : TypeDef<AMX_Dialect, typeName, traits> {
69
+ let mnemonic = typeMnemonic;
70
+ }
71
+
72
+ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73
+ let cppFunctionName = "isValidTileTypeElementType";
74
+ }
75
+
76
+ def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
77
+ let summary = "AMX 2D tile to be used by AMX opertaions.";
78
+
79
+ let description = [{
80
+ This type is used to represent values in AMX tile registers. All AMX operations
81
+ work on AMX tiles and these tiles cannot be used in other operations directly.
82
+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
83
+ element type for IR verification and lowering to LLVMIR dialect.
84
+ }];
85
+
86
+ let parameters = (ins
87
+ ArrayRefParameter<"int64_t">:$shape,
88
+ AMX_TileTypeElementType:$elementType
89
+ );
90
+
91
+ let builders = [
92
+ TypeBuilderWithInferredContext<(ins
93
+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
94
+ return $_get(elementType.getContext(), shape, elementType);
95
+ }]>
96
+ ];
97
+
98
+ let extraClassDeclaration = [{
99
+ /// Returns if this type is ranked (always true).
100
+ bool hasRank() const { return true; }
101
+
102
+ /// Clone this tile type with the given shape and element type. If the
103
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
104
+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
105
+ Type elementType) const {
106
+ return get(shape.value_or(getShape()), elementType);
107
+ }
108
+ }];
109
+
110
+ let hasCustomAssemblyFormat = 1;
111
+ let skipDefaultBuilders = 1;
112
+ }
113
+
114
+ def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
115
+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
116
+
117
+ class AMXTileOf<list<Type> allowedTypes> :
118
+ ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
119
+ "::mlir::amx::TileType">;
120
+
121
+ def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
122
+
123
+ def AMXTileF32 : AMXTileOf<[F32]>;
124
+
125
+ def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
126
+
127
+ def AMXTileI32 : AMXTileOf<[I32]>;
128
+
129
+ def AMXTileI8 : AMXTileOf<[I8]>;
130
+
60
131
//===----------------------------------------------------------------------===//
61
132
// AMX Op and IntrOp definitions.
62
133
//===----------------------------------------------------------------------===//
@@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
88
159
Example:
89
160
90
161
```mlir
91
- %0 = amx.tile_zero : vector <16x16xbf16>
162
+ %0 = amx.tile_zero : !amx.tile <16x16xbf16>
92
163
```
93
164
}];
94
- let results = (outs
95
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
165
+ let results = (outs AnyAMXTile:$res);
96
166
let extraClassDeclaration = [{
97
- VectorType getVectorType () {
98
- return ::llvm::cast<VectorType >(getRes().getType());
167
+ TileType getTileType () {
168
+ return ::llvm::cast<TileType >(getRes().getType());
99
169
}
100
170
}];
101
- let assemblyFormat = "attr-dict `:` type($res)";
171
+ let assemblyFormat = "attr-dict `:` qualified( type($res) )";
102
172
let hasVerifier = 1;
103
173
}
104
174
@@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
117
187
Example:
118
188
119
189
```mlir
120
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector <16x64xi8>
190
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile <16x64xi8>
121
191
```
122
192
}];
123
193
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
124
194
Variadic<Index>:$indices);
125
- let results = (outs
126
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
195
+ let results = (outs AnyAMXTile:$res);
127
196
let extraClassDeclaration = [{
128
197
MemRefType getMemRefType() {
129
198
return ::llvm::cast<MemRefType>(getBase().getType());
130
199
}
131
- VectorType getVectorType () {
132
- return ::llvm::cast<VectorType >(getRes().getType());
200
+ TileType getTileType () {
201
+ return ::llvm::cast<TileType >(getRes().getType());
133
202
}
134
203
}];
135
204
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
136
- "type($base) `into` type($res)";
205
+ "type($base) `into` qualified( type($res) )";
137
206
let hasVerifier = 1;
138
207
}
139
208
@@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
148
217
Example:
149
218
150
219
```mlir
151
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector <16x64xi8>
220
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile <16x64xi8>
152
221
```
153
222
}];
154
223
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
155
224
Variadic<Index>:$indices,
156
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]> :$val);
225
+ AnyAMXTile :$val);
157
226
let extraClassDeclaration = [{
158
227
MemRefType getMemRefType() {
159
228
return ::llvm::cast<MemRefType>(getBase().getType());
160
229
}
161
- VectorType getVectorType () {
162
- return ::llvm::cast<VectorType >(getVal().getType());
230
+ TileType getTileType () {
231
+ return ::llvm::cast<TileType >(getVal().getType());
163
232
}
164
233
}];
165
234
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
166
- "type($base) `,` type($val)";
235
+ "type($base) `,` qualified( type($val) )";
167
236
let hasVerifier = 1;
168
237
}
169
238
@@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
184
253
185
254
```mlir
186
255
%0 = amx.tile_mulf %a, %b, %c
187
- : vector <16x32xbf16>, vector <16x32xbf16>, vector <16x16xf32>
256
+ : !amx.tile <16x32xbf16>, !amx.tile <16x32xbf16>, !amx.tile <16x16xf32>
188
257
```
189
258
}];
190
- let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]> :$lhs,
191
- VectorOfRankAndType<[2], [F32, BF16]> :$rhs,
192
- VectorOfRankAndType<[2], [F32, BF16]> :$acc);
193
- let results = (outs VectorOfRankAndType<[2], [F32, BF16]> :$res);
259
+ let arguments = (ins AMXTileF16OrBF16 :$lhs,
260
+ AMXTileF16OrBF16 :$rhs,
261
+ AMXTileF32 :$acc);
262
+ let results = (outs AMXTileF32 :$res);
194
263
let extraClassDeclaration = [{
195
- VectorType getLhsVectorType () {
196
- return ::llvm::cast<VectorType >(getLhs().getType());
264
+ TileType getLhsTileType () {
265
+ return ::llvm::cast<TileType >(getLhs().getType());
197
266
}
198
- VectorType getRhsVectorType () {
199
- return ::llvm::cast<VectorType >(getRhs().getType());
267
+ TileType getRhsTileType () {
268
+ return ::llvm::cast<TileType >(getRhs().getType());
200
269
}
201
- VectorType getVectorType () {
202
- return ::llvm::cast<VectorType >(getRes().getType());
270
+ TileType getTileType () {
271
+ return ::llvm::cast<TileType >(getRes().getType());
203
272
}
204
273
}];
205
274
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
206
- "type($lhs) `,` type($rhs) `,` type($acc) ";
275
+ "qualified(type($lhs)) `,` qualified(type($rhs))"
276
+ " `,` qualified(type($acc)) ";
207
277
let hasVerifier = 1;
208
278
}
209
279
@@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [
223
293
224
294
```mlir
225
295
%0 = amx.tile_muli %a zext, %b zext, %c
226
- : vector <16x64xi8>, vector <16x64xi8>, vector <16x16xi32>
296
+ : !amx.tile <16x64xi8>, !amx.tile <16x64xi8>, !amx.tile <16x16xi32>
227
297
```
228
298
}];
229
- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]> :$lhs,
230
- VectorOfRankAndType<[2], [I32, I8]> :$rhs,
231
- VectorOfRankAndType<[2], [I32, I8]> :$acc,
299
+ let arguments = (ins AMXTileI8 :$lhs,
300
+ AMXTileI8 :$rhs,
301
+ AMXTileI32 :$acc,
232
302
UnitAttr:$isZextLhs,
233
303
UnitAttr:$isZextRhs
234
304
);
235
- let results = (outs VectorOfRankAndType<[2], [I32, I8]> :$res);
305
+ let results = (outs AMXTileI32 :$res);
236
306
let extraClassDeclaration = [{
237
- VectorType getLhsVectorType () {
238
- return ::llvm::cast<VectorType >(getLhs().getType());
307
+ TileType getLhsTileType () {
308
+ return ::llvm::cast<TileType >(getLhs().getType());
239
309
}
240
- VectorType getRhsVectorType () {
241
- return ::llvm::cast<VectorType >(getRhs().getType());
310
+ TileType getRhsTileType () {
311
+ return ::llvm::cast<TileType >(getRhs().getType());
242
312
}
243
- VectorType getVectorType () {
244
- return ::llvm::cast<VectorType >(getRes().getType());
313
+ TileType getTileType () {
314
+ return ::llvm::cast<TileType >(getRes().getType());
245
315
}
246
316
}];
247
317
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
248
- "type($lhs) `,` type($rhs) `,` type($acc) ";
318
+ "qualified( type($lhs)) `,` qualified( type($rhs)) `,` qualified( type($acc) ) ";
249
319
let hasVerifier = 1;
250
320
}
251
321
@@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
286
356
AnyInteger,
287
357
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
288
358
359
+ // Dot product of f16 tiles into f32 tile.
360
+ def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361
+ Arguments<(ins AnyInteger,
362
+ AnyInteger,
363
+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364
+
289
365
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
290
366
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
291
367
Arguments<(ins AnyInteger,
0 commit comments