@@ -79,11 +79,11 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
79
79
// CHECK: "amx.tilestored64"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
80
80
// Test function @strides: takes three 16x32 bf16 memrefs -- %arg0 with the
// default layout, %arg1 with strided<[64, 1]>, and %arg2 with strided<[?, 1]>
// (a dynamic leading stride).
// A tile is loaded from each memref at index [%0, %0] (with %0 = constant 0),
// and the tiles are then stored back rotated across the arguments:
// %3 -> %arg0, %1 -> %arg1, %2 -> %arg2.
// In this diff the '-' lines are the old form, where tile_load/tile_store use
// vector<16x32xbf16>; the '+' lines are the new form using !amx.tile<16x32xbf16>.
// Operands, indices and memref types are otherwise the same on both sides.
// NOTE(review): the bare numeric lines appear to be old/new line-number residue
// from the rendered diff view, and the extra spaces inside the types look like
// extraction artifacts -- confirm against the real test file before applying.
func.func @strides (%arg0: memref <16 x32 xbf16 >, %arg1: memref <16 x32 xbf16 , strided <[64 , 1 ]>>, %arg2: memref <16 x32 xbf16 , strided <[?, 1 ]>>) {
81
81
%0 = arith.constant 0 : index
82
- %1 = amx.tile_load %arg0 [%0 , %0 ] : memref <16 x32 xbf16 > into vector <16 x32 xbf16 >
83
- %2 = amx.tile_load %arg1 [%0 , %0 ] : memref <16 x32 xbf16 , strided <[64 , 1 ]>> into vector <16 x32 xbf16 >
84
- %3 = amx.tile_load %arg2 [%0 , %0 ] : memref <16 x32 xbf16 , strided <[?, 1 ]>> into vector <16 x32 xbf16 >
85
- amx.tile_store %arg0 [%0 , %0 ], %3 : memref <16 x32 xbf16 >, vector <16 x32 xbf16 >
86
- amx.tile_store %arg1 [%0 , %0 ], %1 : memref <16 x32 xbf16 , strided <[64 , 1 ]>>, vector <16 x32 xbf16 >
87
- amx.tile_store %arg2 [%0 , %0 ], %2 : memref <16 x32 xbf16 , strided <[?, 1 ]>>, vector <16 x32 xbf16 >
82
+ %1 = amx.tile_load %arg0 [%0 , %0 ] : memref <16 x32 xbf16 > into !amx.tile <16 x32 xbf16 >
83
+ %2 = amx.tile_load %arg1 [%0 , %0 ] : memref <16 x32 xbf16 , strided <[64 , 1 ]>> into !amx.tile <16 x32 xbf16 >
84
+ %3 = amx.tile_load %arg2 [%0 , %0 ] : memref <16 x32 xbf16 , strided <[?, 1 ]>> into !amx.tile <16 x32 xbf16 >
85
+ amx.tile_store %arg0 [%0 , %0 ], %3 : memref <16 x32 xbf16 >, !amx.tile <16 x32 xbf16 >
86
+ amx.tile_store %arg1 [%0 , %0 ], %1 : memref <16 x32 xbf16 , strided <[64 , 1 ]>>, !amx.tile <16 x32 xbf16 >
87
+ amx.tile_store %arg2 [%0 , %0 ], %2 : memref <16 x32 xbf16 , strided <[?, 1 ]>>, !amx.tile <16 x32 xbf16 >
88
88
return
89
89
}
0 commit comments