@@ -240,57 +240,16 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
240
240
241
241
// -----
242
242
243
- // CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
244
- // CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
245
- // CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
246
-
247
- // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
248
- // CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
249
- // CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
250
- // CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251
- // CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252
- // CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
253
- // CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254
- // CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
255
- // CHECK: return %[[VAL_6]] : vector<1x8xi32>
256
- // CHECK: }
257
- func.func @cast_away_contraction_leading_one_dims_vec_mat (%lhs: vector <1 x1 x8 xi32 >,
243
+ // CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
244
+ // CHECK-NOT vector.transpose
245
+ // CHECK: vector.contract
246
+ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims (%lhs: vector <1 x1 x8 xi32 >,
258
247
%rhs: vector <1 x8 x8 xi32 >,
259
248
%acc: vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
260
249
%result = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %lhs , %rhs , %acc : vector <1 x1 x8 xi32 >, vector <1 x8 x8 xi32 > into vector <1 x8 xi32 >
261
250
return %result : vector <1 x8 xi32 >
262
251
}
263
252
264
- // -----
265
- // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
266
- // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
267
- // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
268
-
269
- // CHECK-LABEL: not_insert_cast_for_contraction_under_mask
270
- // CHECK: %[[MASK:.+]] = vector.constant_mask
271
- // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
272
- // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
273
- // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
274
- // CHECK: return %[[RET]] : vector<1x16x16xf32>
275
-
276
- #contraction_accesses0 = [
277
- affine_map <(l , i , j , k ) -> (l , i , k )>,
278
- affine_map <(l , i , j , k ) -> (l , k , j )>,
279
- affine_map <(l , i , j , k ) -> (l , i , j )>
280
- ]
281
- #contraction_trait0 = {
282
- indexing_maps = #contraction_accesses0 ,
283
- iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
284
- }
285
-
286
- func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
287
- %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
288
- %0 = vector.mask %mask {
289
- vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
290
- } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
291
- return %0 : vector <1 x16 x16 xf32 >
292
- }
293
-
294
253
// -----
295
254
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
296
255
func.func @cast_away_extract_strided_slice_leading_one_dims (%arg0: vector <1 x8 x8 xf16 >) -> vector <1 x1 x8 xf16 > {
0 commit comments