@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
73
73
let cppFunctionName = "isValidTileTypeElementType";
74
74
}
75
75
76
- def AMX_TileType : AMX_Type<"Tile", "amx. tile", [ShapedTypeInterface, ValueSemantics]> {
76
+ def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
77
77
let summary = "AMX 2D tile to be used by AMX opertaions.";
78
78
79
79
let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
111
111
let skipDefaultBuilders = 1;
112
112
}
113
113
114
- def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
115
-
116
- def IsAMX2DTilePred : And<[IsAMXTilePred,
114
+ def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
117
115
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
118
116
119
- class AMX2DTileOf <list<Type> allowedTypes> :
120
- ShapedContainerType<allowedTypes, IsAMX2DTilePred , "tile",
117
+ class AMXTileOf <list<Type> allowedTypes> :
118
+ ShapedContainerType<allowedTypes, IsAMXTilePred , "tile",
121
119
"::mlir::amx::TileType">;
122
120
121
+ def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
122
+
123
+ def AMXTileF32 : AMXTileOf<[F32]>;
124
+
125
+ def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
126
+
127
+ def AMXTileI32 : AMXTileOf<[I32]>;
128
+
129
+ def AMXTileI8 : AMXTileOf<[I8]>;
130
+
123
131
//===----------------------------------------------------------------------===//
124
132
// AMX Op and IntrOp definitions.
125
133
//===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
151
159
Example:
152
160
153
161
```mlir
154
- %0 = amx.tile_zero : <16x16xbf16>
162
+ %0 = amx.tile_zero : !amx.tile <16x16xbf16>
155
163
```
156
164
}];
157
- let results = (outs
158
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
165
+ let results = (outs AnyAMXTile:$res);
159
166
let extraClassDeclaration = [{
160
167
TileType getTileType() {
161
168
return ::llvm::cast<TileType>(getRes().getType());
162
169
}
163
170
}];
164
- let assemblyFormat = "attr-dict `:` type($res)";
171
+ let assemblyFormat = "attr-dict `:` qualified( type($res) )";
165
172
let hasVerifier = 1;
166
173
}
167
174
@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
180
187
Example:
181
188
182
189
```mlir
183
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
190
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile <16x64xi8>
184
191
```
185
192
}];
186
193
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
187
194
Variadic<Index>:$indices);
188
- let results = (outs
189
- AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
195
+ let results = (outs AnyAMXTile:$res);
190
196
let extraClassDeclaration = [{
191
197
MemRefType getMemRefType() {
192
198
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
196
202
}
197
203
}];
198
204
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
199
- "type($base) `into` type($res)";
205
+ "type($base) `into` qualified( type($res) )";
200
206
let hasVerifier = 1;
201
207
}
202
208
@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
211
217
Example:
212
218
213
219
```mlir
214
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
220
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile <16x64xi8>
215
221
```
216
222
}];
217
223
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
218
224
Variadic<Index>:$indices,
219
- AMX2DTileOf<[F32, F16, BF16, I32, I8]> :$val);
225
+ AnyAMXTile :$val);
220
226
let extraClassDeclaration = [{
221
227
MemRefType getMemRefType() {
222
228
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
226
232
}
227
233
}];
228
234
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
229
- "type($base) `,` type($val)";
235
+ "type($base) `,` qualified( type($val) )";
230
236
let hasVerifier = 1;
231
237
}
232
238
@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
246
252
Example:
247
253
248
254
```mlir
249
- %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
255
+ %0 = amx.tile_mulf %a, %b, %c
256
+ : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
250
257
```
251
258
}];
252
- let arguments = (ins AMX2DTileOf<[F16, BF16]> :$lhs,
253
- AMX2DTileOf<[F16, BF16]> :$rhs,
254
- AMX2DTileOf<[F32]> :$acc);
255
- let results = (outs AMX2DTileOf<[F32]> :$res);
259
+ let arguments = (ins AMXTileF16OrBF16 :$lhs,
260
+ AMXTileF16OrBF16 :$rhs,
261
+ AMXTileF32 :$acc);
262
+ let results = (outs AMXTileF32 :$res);
256
263
let extraClassDeclaration = [{
257
264
TileType getLhsTileType() {
258
265
return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
265
272
}
266
273
}];
267
274
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
268
- "type($lhs) `,` type($rhs) `,` type($acc) ";
275
+ "qualified(type($lhs)) `,` qualified(type($rhs))"
276
+ " `,` qualified(type($acc)) ";
269
277
let hasVerifier = 1;
270
278
}
271
279
@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
284
292
Example:
285
293
286
294
```mlir
287
- %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
295
+ %0 = amx.tile_muli %a zext, %b zext, %c
296
+ : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
288
297
```
289
298
}];
290
- let arguments = (ins AMX2DTileOf<[I8]> :$lhs,
291
- AMX2DTileOf<[I8]> :$rhs,
292
- AMX2DTileOf<[I32]> :$acc,
299
+ let arguments = (ins AMXTileI8 :$lhs,
300
+ AMXTileI8 :$rhs,
301
+ AMXTileI32 :$acc,
293
302
UnitAttr:$isZextLhs,
294
303
UnitAttr:$isZextRhs
295
304
);
296
- let results = (outs AMX2DTileOf<[I32]> :$res);
305
+ let results = (outs AMXTileI32 :$res);
297
306
let extraClassDeclaration = [{
298
307
TileType getLhsTileType() {
299
308
return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
306
315
}
307
316
}];
308
317
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
309
- "type($lhs) `,` type($rhs) `,` type($acc) ";
318
+ "qualified( type($lhs)) `,` qualified( type($rhs)) `,` qualified( type($acc) ) ";
310
319
let hasVerifier = 1;
311
320
}
312
321
0 commit comments