@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
30
30
}
31
31
32
32
// -----
33
+ // CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
34
+ // CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
35
+ // CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
36
+
37
+ // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
38
+ // CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
39
+ // CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
40
+ // CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
41
+ // CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
42
+ // CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
43
+ // CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
44
+ // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
45
+ // CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
46
+ // CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
47
+ // CHECK: return %[[RES]] : vector<1x16x16xf32>
48
+
49
+ #contraction_accesses0 = [
50
+ affine_map <(l , i , j , k ) -> (l , i , k )>,
51
+ affine_map <(l , i , j , k ) -> (l , k , j )>,
52
+ affine_map <(l , i , j , k ) -> (l , i , j )>
53
+ ]
54
+ #contraction_trait0 = {
55
+ indexing_maps = #contraction_accesses0 ,
56
+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
57
+ }
58
+
59
+ func.func @cast_away_contraction_leading_one_dim_under_const_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
60
+ %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
61
+ %0 = vector.mask %mask {
62
+ vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
63
+ } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
64
+ return %0 : vector <1 x16 x16 xf32 >
65
+ }
66
+
67
+ // -----
68
+ // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
69
+ // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
70
+ // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
71
+
72
+ // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
73
+ // CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
74
+ // CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
75
+ // CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
76
+ // CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
77
+ // CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
78
+ // CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
79
+ // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
80
+ // CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
81
+ // CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
82
+ // CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
83
+
84
+ #contraction_accesses0 = [
85
+ affine_map <(l , i , j , k ) -> (l , i , k )>,
86
+ affine_map <(l , i , j , k ) -> (l , k , j )>,
87
+ affine_map <(l , i , j , k ) -> (l , i , j )>
88
+ ]
89
+ #contraction_trait0 = {
90
+ indexing_maps = #contraction_accesses0 ,
91
+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
92
+ }
93
+
94
+ func.func @cast_away_contraction_leading_one_dim_under_mask (
95
+ %arg0: vector <1 x16 x8 xf32 >,
96
+ %arg1: vector <1 x8 x16 xf32 >,
97
+ %arg2: vector <1 x16 x16 xf32 >,
98
+ %mask: vector <1 x16 x16 x8 xi1 >) -> vector <1 x16 x16 xf32 > {
99
+ %0 = vector.mask %mask {
100
+ vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
101
+ } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
102
+ return %0: vector <1 x16 x16 xf32 >
103
+ }
104
+
105
+ // -----
106
+
33
107
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
34
108
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
35
109
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
164
238
return %0: vector <1 x1 x2 x16 xf32 >
165
239
}
166
240
167
- // -----
168
-
169
- // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170
- // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171
- // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
172
-
173
- // CHECK-LABEL: not_insert_cast_for_contraction_under_mask
174
- // CHECK: %[[MASK:.+]] = vector.constant_mask
175
- // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
176
- // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
177
- // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
178
- // CHECK: return %[[RET]] : vector<1x16x16xf32>
179
-
180
- #contraction_accesses0 = [
181
- affine_map <(l , i , j , k ) -> (l , i , k )>,
182
- affine_map <(l , i , j , k ) -> (l , k , j )>,
183
- affine_map <(l , i , j , k ) -> (l , i , j )>
184
- ]
185
- #contraction_trait0 = {
186
- indexing_maps = #contraction_accesses0 ,
187
- iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
188
- }
189
-
190
- func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
191
- %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
192
- %0 = vector.mask %mask {
193
- vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
194
- } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
195
- return %0 : vector <1 x16 x16 xf32 >
196
- }
197
241
198
242
// -----
199
243
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
0 commit comments