// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
- func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+ func.func @fold_arith_extf_into_contract(
+     %arg0: vector<64x64xf16>,
+     %arg1: vector<64x64xf16>,
+     %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
  %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
  %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
- %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+   %result = vector.contract {
+     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"],
+     kind = #vector.kind<add>}
+     %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
  return %result : vector<64x64xf32>
- }
+ }
+
+ // -----
+
+ // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ // CHECK-LABEL: func.func @fold_arith_extf_into_contract_scalable
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<[64]x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<[64]x64xf32>)
+ // CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<[64]x64xf16>, vector<64x64xf16> into vector<[64]x64xf32>
+ // CHECK-NEXT: return %[[R]] : vector<[64]x64xf32>
+ func.func @fold_arith_extf_into_contract_scalable(
+     %arg0: vector<[64]x64xf16>,
+     %arg1: vector<64x64xf16>,
+     %arg2: vector<[64]x64xf32>) -> vector<[64]x64xf32> {
+   %lhs_f32 = arith.extf %arg0 : vector<[64]x64xf16> to vector<[64]x64xf32>
+   %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+   %result = vector.contract {
+     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"],
+     kind = #vector.kind<add>}
+     %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
+   return %result : vector<[64]x64xf32>
+ }