Skip to content

Commit 3a07dfc

Browse files
committed
Fix review comments.
Signed-off-by: Ilya Enkovich <[email protected]>
1 parent 2a76b2d commit 3a07dfc

File tree

6 files changed

+97
-86
lines changed

6 files changed

+97
-86
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ class LLVMTypeConverter : public TypeConverter {
259259
/// Convert a 1D vector type into an LLVM vector type.
260260
FailureOr<Type> convertVectorType(VectorType type) const;
261261

262-
/// Convert AMX tile type x86_amx type.
262+
/// Convert an AMX tile type to the x86_amx type.
263263
Type convertAMXTileType(amx::TileType type) const;
264264

265265
/// Options for customizing the llvm lowering.

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
7373
let cppFunctionName = "isValidTileTypeElementType";
7474
}
7575

76-
def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
76+
def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
7777
let summary = "AMX 2D tile to be used by AMX opertaions.";
7878

7979
let description = [{
@@ -111,15 +111,23 @@ def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSeman
111111
let skipDefaultBuilders = 1;
112112
}
113113

114-
def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
115-
116-
def IsAMX2DTilePred : And<[IsAMXTilePred,
114+
def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
117115
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
118116

119-
class AMX2DTileOf<list<Type> allowedTypes> :
120-
ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
117+
class AMXTileOf<list<Type> allowedTypes> :
118+
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
121119
"::mlir::amx::TileType">;
122120

121+
def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
122+
123+
def AMXTileF32 : AMXTileOf<[F32]>;
124+
125+
def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
126+
127+
def AMXTileI32 : AMXTileOf<[I32]>;
128+
129+
def AMXTileI8 : AMXTileOf<[I8]>;
130+
123131
//===----------------------------------------------------------------------===//
124132
// AMX Op and IntrOp definitions.
125133
//===----------------------------------------------------------------------===//
@@ -151,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
151159
Example:
152160

153161
```mlir
154-
%0 = amx.tile_zero : <16x16xbf16>
162+
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
155163
```
156164
}];
157-
let results = (outs
158-
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
165+
let results = (outs AnyAMXTile:$res);
159166
let extraClassDeclaration = [{
160167
TileType getTileType() {
161168
return ::llvm::cast<TileType>(getRes().getType());
162169
}
163170
}];
164-
let assemblyFormat = "attr-dict `:` type($res)";
171+
let assemblyFormat = "attr-dict `:` qualified(type($res))";
165172
let hasVerifier = 1;
166173
}
167174

@@ -180,13 +187,12 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
180187
Example:
181188

182189
```mlir
183-
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
190+
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
184191
```
185192
}];
186193
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
187194
Variadic<Index>:$indices);
188-
let results = (outs
189-
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
195+
let results = (outs AnyAMXTile:$res);
190196
let extraClassDeclaration = [{
191197
MemRefType getMemRefType() {
192198
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -196,7 +202,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
196202
}
197203
}];
198204
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
199-
"type($base) `into` type($res)";
205+
"type($base) `into` qualified(type($res))";
200206
let hasVerifier = 1;
201207
}
202208

@@ -211,12 +217,12 @@ def TileStoreOp : AMX_Op<"tile_store"> {
211217
Example:
212218

213219
```mlir
214-
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
220+
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
215221
```
216222
}];
217223
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
218224
Variadic<Index>:$indices,
219-
AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
225+
AnyAMXTile:$val);
220226
let extraClassDeclaration = [{
221227
MemRefType getMemRefType() {
222228
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -226,7 +232,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
226232
}
227233
}];
228234
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
229-
"type($base) `,` type($val)";
235+
"type($base) `,` qualified(type($val))";
230236
let hasVerifier = 1;
231237
}
232238

@@ -246,13 +252,14 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
246252
Example:
247253

248254
```mlir
249-
%0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
255+
%0 = amx.tile_mulf %a, %b, %c
256+
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
250257
```
251258
}];
252-
let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
253-
AMX2DTileOf<[F16, BF16]>:$rhs,
254-
AMX2DTileOf<[F32]>:$acc);
255-
let results = (outs AMX2DTileOf<[F32]>:$res);
259+
let arguments = (ins AMXTileF16OrBF16:$lhs,
260+
AMXTileF16OrBF16:$rhs,
261+
AMXTileF32:$acc);
262+
let results = (outs AMXTileF32:$res);
256263
let extraClassDeclaration = [{
257264
TileType getLhsTileType() {
258265
return ::llvm::cast<TileType>(getLhs().getType());
@@ -265,7 +272,8 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
265272
}
266273
}];
267274
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
268-
"type($lhs) `,` type($rhs) `,` type($acc) ";
275+
"qualified(type($lhs)) `,` qualified(type($rhs))"
276+
" `,` qualified(type($acc)) ";
269277
let hasVerifier = 1;
270278
}
271279

@@ -284,16 +292,17 @@ def TileMulIOp : AMX_Op<"tile_muli", [
284292
Example:
285293

286294
```mlir
287-
%0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
295+
%0 = amx.tile_muli %a zext, %b zext, %c
296+
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
288297
```
289298
}];
290-
let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
291-
AMX2DTileOf<[I8]>:$rhs,
292-
AMX2DTileOf<[I32]>:$acc,
299+
let arguments = (ins AMXTileI8:$lhs,
300+
AMXTileI8:$rhs,
301+
AMXTileI32:$acc,
293302
UnitAttr:$isZextLhs,
294303
UnitAttr:$isZextRhs
295304
);
296-
let results = (outs AMX2DTileOf<[I32]>:$res);
305+
let results = (outs AMXTileI32:$res);
297306
let extraClassDeclaration = [{
298307
TileType getLhsTileType() {
299308
return ::llvm::cast<TileType>(getLhs().getType());
@@ -306,7 +315,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [
306315
}
307316
}];
308317
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
309-
"type($lhs) `,` type($rhs) `,` type($acc) ";
318+
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
310319
let hasVerifier = 1;
311320
}
312321

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,12 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
158158
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
159159
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
160160
adaptor.getLhs(), adaptor.getRhs());
161-
else
161+
else if (aType.getElementType().isF16())
162162
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpfp16ps>(
163163
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
164164
adaptor.getLhs(), adaptor.getRhs());
165+
else
166+
llvm_unreachable("Unexpected element type for amx.mulf");
165167
return success();
166168
}
167169
};

mlir/test/Dialect/AMX/invalid.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,45 @@
44

55
func.func @rowheight() {
66
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
7-
%0 = amx.tile_zero : <17x16xbf16>
7+
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
88
}
99

1010
// -----
1111

1212
func.func @colwidth() {
1313
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
14-
%0 = amx.tile_zero : <16x65xi8>
14+
%0 = amx.tile_zero : !amx.tile<16x65xi8>
1515
}
1616

1717
// -----
1818

1919
func.func @col4bytemultiple() {
2020
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
21-
%0 = amx.tile_zero : <16x5xi8>
21+
%0 = amx.tile_zero : !amx.tile<16x5xi8>
2222
}
2323

2424
// -----
2525

2626
func.func @memtilesize(%arg0: memref<?x?xf32>) {
2727
%0 = arith.constant 0 : index
2828
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
29-
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into <16x17xf32>
29+
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
3030
}
3131

3232
// -----
3333

3434
func.func @memindexsize(%arg0: memref<?x?xf32>) {
3535
%0 = arith.constant 0 : index
3636
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
37-
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into <16x16xf32>
37+
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
3838
}
3939

4040
// -----
4141

4242
func.func @multsize() {
43-
%0 = amx.tile_zero : <8x8xbf16>
44-
%1 = amx.tile_zero : <8x8xbf16>
45-
%2 = amx.tile_zero : <4x4xf32>
43+
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
44+
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
45+
%2 = amx.tile_zero : !amx.tile<4x4xf32>
4646
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
47-
%3 = amx.tile_mulf %0, %1, %2 : <8x8xbf16>, <8x8xbf16>, <4x4xf32>
47+
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
4848
}

mlir/test/Dialect/AMX/legalize-for-llvm.mlir

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414
// CHECK: amx.tilestored64
1515
func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
1616
%0 = arith.constant 0 : index
17-
%1 = amx.tile_zero : <16x64xi8>
18-
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
19-
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into <16x16xi32>
20-
%4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
21-
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
22-
%5 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
23-
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, <16x16xi32>
24-
%6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
25-
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, <16x16xi32>
26-
%7 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
27-
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, <16x16xi32>
17+
%1 = amx.tile_zero : !amx.tile<16x64xi8>
18+
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
19+
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
20+
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
21+
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
22+
%5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
23+
amx.tile_store %arg1[%0, %0], %5 : memref<?x?xi32>, !amx.tile<16x16xi32>
24+
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
25+
amx.tile_store %arg1[%0, %0], %6 : memref<?x?xi32>, !amx.tile<16x16xi32>
26+
%7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
27+
amx.tile_store %arg1[%0, %0], %7 : memref<?x?xi32>, !amx.tile<16x16xi32>
2828
return
2929
}
3030

@@ -36,11 +36,11 @@ func.func @muli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi32>) {
3636
// CHECK: amx.tilestored64
3737
func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
3838
%0 = arith.constant 0 : index
39-
%1 = amx.tile_zero : <16x32xbf16>
40-
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
41-
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
42-
%4 = amx.tile_mulf %1, %2, %3 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
43-
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
39+
%1 = amx.tile_zero : !amx.tile<16x32xbf16>
40+
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
41+
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
42+
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
43+
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
4444
return
4545
}
4646

@@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
5252
// CHECK: amx.tilestored64
5353
func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
5454
%0 = arith.constant 0 : index
55-
%1 = amx.tile_zero : <16x32xf16>
56-
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into <16x32xf16>
57-
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
58-
%4 = amx.tile_mulf %1, %2, %3 : <16x32xf16>, <16x32xf16>, <16x16xf32>
59-
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, <16x16xf32>
55+
%1 = amx.tile_zero : !amx.tile<16x32xf16>
56+
%2 = amx.tile_load %arg0[%0, %0] : memref<?x?xf16> into !amx.tile<16x32xf16>
57+
%3 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
58+
%4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32>
59+
amx.tile_store %arg1[%0, %0], %4 : memref<?x?xf32>, !amx.tile<16x16xf32>
6060
return
6161
}
6262

mlir/test/Dialect/AMX/roundtrip.mlir

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,49 @@
11
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
22

33
// CHECK-LABEL: tzero
4-
// CHECK: amx.tile_zero : <16x16xbf16>
5-
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, <16x16xbf16>
4+
// CHECK: amx.tile_zero : !amx.tile<16x16xbf16>
5+
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xbf16>, !amx.tile<16x16xbf16>
66
func.func @tzero(%arg0: memref<?x?xbf16>) {
77
%0 = arith.constant 0 : index
8-
%1 = amx.tile_zero : <16x16xbf16>
9-
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, <16x16xbf16>
8+
%1 = amx.tile_zero : !amx.tile<16x16xbf16>
9+
amx.tile_store %arg0[%0, %0], %1 : memref<?x?xbf16>, !amx.tile<16x16xbf16>
1010
return
1111
}
1212

1313
// CHECK-LABEL: tmulf
14-
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into <16x32xbf16>
15-
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into <16x16xf32>
16-
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
17-
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, <16x16xf32>
14+
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
15+
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32> into !amx.tile<16x16xf32>
16+
// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
17+
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xf32>, !amx.tile<16x16xf32>
1818
func.func @tmulf(%arg0: memref<?x?xbf16>, %arg1: memref<?x?xf32>) {
1919
%0 = arith.constant 0 : index
20-
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into <16x32xbf16>
21-
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into <16x16xf32>
22-
%3 = amx.tile_mulf %1, %1, %2 : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
23-
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, <16x16xf32>
20+
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xbf16> into !amx.tile<16x32xbf16>
21+
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xf32> into !amx.tile<16x16xf32>
22+
%3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
23+
amx.tile_store %arg1[%0, %0], %3 : memref<?x?xf32>, !amx.tile<16x16xf32>
2424
return
2525
}
2626

2727
// CHECK-LABEL: tmuli
28-
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
29-
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into <16x64xi8>
30-
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into <16x16xi32>
31-
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : <16x64xi8>, <16x64xi8>, <16x16xi32>
32-
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, <16x16xi32>
28+
// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
29+
// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi8> into !amx.tile<16x64xi8>
30+
// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xi32> into !amx.tile<16x16xi32>
31+
// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
32+
// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref<?x?xi32>, !amx.tile<16x16xi32>
3333
// Verify the parsing/printing of the sign-extension annotation.
3434
// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
3535
// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
3636
// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
3737
func.func @tmuli(%arg0: memref<?x?xi8>, %arg1: memref<?x?xi8>, %arg2: memref<?x?xi32>) {
3838
%0 = arith.constant 0 : index
39-
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into <16x64xi8>
40-
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into <16x64xi8>
41-
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into <16x16xi32>
42-
%4 = amx.tile_muli %1 zext, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
43-
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, <16x16xi32>
39+
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
40+
%2 = amx.tile_load %arg1[%0, %0] : memref<?x?xi8> into !amx.tile<16x64xi8>
41+
%3 = amx.tile_load %arg2[%0, %0] : memref<?x?xi32> into !amx.tile<16x16xi32>
42+
%4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
43+
amx.tile_store %arg2[%0, %0], %4 : memref<?x?xi32>, !amx.tile<16x16xi32>
4444
// Verify the various `zext` combinations.
45-
%5 = amx.tile_muli %1, %2 zext, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
46-
%6 = amx.tile_muli %1 zext, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
47-
%7 = amx.tile_muli %1, %2, %3 : <16x64xi8>, <16x64xi8>, <16x16xi32>
45+
%5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
46+
%6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
47+
%7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
4848
return
4949
}

0 commit comments

Comments
 (0)