@@ -238,6 +238,58 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
238
238
return %0: vector <1 x1 x2 x16 xf32 >
239
239
}
240
240
241
+ // -----
242
+
243
+ // CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
244
+ // CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
245
+ // CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
246
+
247
+ // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
248
+ // CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
249
+ // CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
250
+ // CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251
+ // CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252
+ // CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
253
+ // CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254
+ // CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
255
+ // CHECK: return %[[VAL_6]] : vector<1x8xi32>
256
+ // CHECK: }
257
+ func.func @cast_away_contraction_leading_one_dims_vec_mat (%lhs: vector <1 x1 x8 xi32 >,
258
+ %rhs: vector <1 x8 x8 xi32 >,
259
+ %acc: vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
260
+ %result = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %lhs , %rhs , %acc : vector <1 x1 x8 xi32 >, vector <1 x8 x8 xi32 > into vector <1 x8 xi32 >
261
+ return %result : vector <1 x8 xi32 >
262
+ }
263
+
264
+ // -----
265
+ // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
266
+ // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
267
+ // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
268
+
269
+ // CHECK-LABEL: not_insert_cast_for_contraction_under_mask
270
+ // CHECK: %[[MASK:.+]] = vector.constant_mask
271
+ // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
272
+ // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
273
+ // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
274
+ // CHECK: return %[[RET]] : vector<1x16x16xf32>
275
+
276
+ #contraction_accesses0 = [
277
+ affine_map <(l , i , j , k ) -> (l , i , k )>,
278
+ affine_map <(l , i , j , k ) -> (l , k , j )>,
279
+ affine_map <(l , i , j , k ) -> (l , i , j )>
280
+ ]
281
+ #contraction_trait0 = {
282
+ indexing_maps = #contraction_accesses0 ,
283
+ iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
284
+ }
285
+
286
+ func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
287
+ %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
288
+ %0 = vector.mask %mask {
289
+ vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
290
+ } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
291
+ return %0 : vector <1 x16 x16 xf32 >
292
+ }
241
293
242
294
// -----
243
295
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
@@ -663,4 +715,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
663
715
%sel = arith.select %cond , %arg0 , %arg1 : vector <1 x16 xi1 >
664
716
return %sel : vector <1 x16 xi1 >
665
717
}
666
-
0 commit comments