@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }

 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+// CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+// CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}