@@ -460,3 +460,33 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
460
460
vector.transfer_write %cast , %arg3 [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <16 x16 xf32 >, memref <16 x16 xf32 >
461
461
return
462
462
}
463
+
464
+ // -----
465
+
466
+ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>  // Indexing map for the first (LHS) operand of the vector.contract below.
467
+ #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>  // Indexing map for the second (RHS) operand of the vector.contract below.
468
+ #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>  // Indexing map for the accumulator/result; d2 does not appear here (it is the reduction dim).
469
+
470
+ // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
471
+ // CHECK-LABEL: func @fold_transpose_into_transfer_read(
472
+ // CHECK-SAME: %[[ALLOC:.+]]: memref<64x128xf16>
473
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
474
+ // CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
475
+ // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true], permutation_map = #[[$MAP]]}
476
+ // CHECK: %[[EXTF1:.+]] = arith.extf %[[READ]]
477
+ // CHECK-NOT: vector.transpose
478
+ // CHECK: %[[RESULT:.+]] = vector.contract
479
+ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector: vector<32x128xf16>, %alloc2: memref<32x64xf32>) {  // Test input: read -> extf -> transpose feeding the RHS of a contraction.
480
+   %c0 = arith.constant 0 : index
481
+   %cst = arith.constant 0.000000e+00 : f16  // Padding value passed to the transfer_read.
482
+   %init = arith.constant dense<0.000000e+00> : vector<32x64xf32>  // Zero accumulator for the contraction.
483
+   %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
484
+   %1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
485
+   %2 = arith.extf %vector : vector<32x128xf16> to vector<32x128xf32>
486
+   %3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>  // The FileCheck directives above expect no vector.transpose to remain before the contract.
487
+   %4 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %init : vector<32x128xf32>, vector<128x64xf32> into vector<32x64xf32>
488
+   vector.transfer_write %4, %alloc2[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32>
489
+   return
490
+ }
491
+
492
+ // -----
0 commit comments