@@ -244,53 +244,24 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
244
244
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
245
245
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
246
246
247
- // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat (
248
- // CHECK-SAME: %[[VAL_0 :.*]]: vector<1x1x8xi32>,
249
- // CHECK-SAME: %[[VAL_1 :.*]]: vector<1x8x8xi32>,
250
- // CHECK-SAME: %[[VAL_2 :.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251
- // CHECK: %[[VAL_3 :.*]] = vector.extract %[[VAL_0 ]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252
- // CHECK: %[[VAL_4 :.*]] = vector.extract %[[VAL_2 ]][0] : vector<8xi32> from vector<1x8xi32>
253
- // CHECK: %[[VAL_5 :.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3 ]], %[[VAL_1 ]], %[[VAL_4 ]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254
- // CHECK: %[[VAL_6 :.*]] = vector.broadcast %[[VAL_5 ]] : vector<8xi32> to vector<1x8xi32>
255
- // CHECK: return %[[VAL_6 ]] : vector<1x8xi32>
247
+ // CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims (
248
+ // CHECK-SAME: %[[LHS :.*]]: vector<1x1x8xi32>,
249
+ // CHECK-SAME: %[[RHS :.*]]: vector<1x8x8xi32>,
250
+ // CHECK-SAME: %[[ACC :.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251
+ // CHECK: %[[EXT_LHS :.*]] = vector.extract %[[LHS ]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252
+ // CHECK: %[[EXT_ACC :.*]] = vector.extract %[[ACC ]][0] : vector<8xi32> from vector<1x8xi32>
253
+ // CHECK: %[[RES :.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS ]], %[[RHS ]], %[[EXT_ACC ]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254
+ // CHECK: %[[BROADCAST_RES :.*]] = vector.broadcast %[[RES ]] : vector<8xi32> to vector<1x8xi32>
255
+ // CHECK: return %[[BROADCAST_RES ]] : vector<1x8xi32>
256
256
// CHECK: }
257
- func.func @cast_away_contraction_leading_one_dims_vec_mat (%lhs: vector <1 x1 x8 xi32 >,
257
+ // CHECK-NOT vector.transpose
258
+ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims (%lhs: vector <1 x1 x8 xi32 >,
258
259
%rhs: vector <1 x8 x8 xi32 >,
259
260
%acc: vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
260
261
%result = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 )>, affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %lhs , %rhs , %acc : vector <1 x1 x8 xi32 >, vector <1 x8 x8 xi32 > into vector <1 x8 xi32 >
261
262
return %result : vector <1 x8 xi32 >
262
263
}
263
264
264
- // -----
265
- // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
266
- // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
267
- // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
268
-
269
- // CHECK-LABEL: not_insert_cast_for_contraction_under_mask
270
- // CHECK: %[[MASK:.+]] = vector.constant_mask
271
- // CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
272
- // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
273
- // CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
274
- // CHECK: return %[[RET]] : vector<1x16x16xf32>
275
-
276
- #contraction_accesses0 = [
277
- affine_map <(l , i , j , k ) -> (l , i , k )>,
278
- affine_map <(l , i , j , k ) -> (l , k , j )>,
279
- affine_map <(l , i , j , k ) -> (l , i , j )>
280
- ]
281
- #contraction_trait0 = {
282
- indexing_maps = #contraction_accesses0 ,
283
- iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]
284
- }
285
-
286
- func.func @not_insert_cast_for_contraction_under_mask (%arg0: vector <1 x16 x8 xf32 >, %arg1: vector <1 x8 x16 xf32 >, %arg2: vector <1 x16 x16 xf32 >) -> vector <1 x16 x16 xf32 > {
287
- %mask = vector.constant_mask [1 , 15 , 15 , 8 ] : vector <1 x16 x16 x8 xi1 >
288
- %0 = vector.mask %mask {
289
- vector.contract #contraction_trait0 %arg0 , %arg1 , %arg2 : vector <1 x16 x8 xf32 >, vector <1 x8 x16 xf32 > into vector <1 x16 x16 xf32 >
290
- } : vector <1 x16 x16 x8 xi1 > -> vector <1 x16 x16 xf32 >
291
- return %0 : vector <1 x16 x16 xf32 >
292
- }
293
-
294
265
// -----
295
266
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
296
267
func.func @cast_away_extract_strided_slice_leading_one_dims (%arg0: vector <1 x8 x8 xf16 >) -> vector <1 x1 x8 xf16 > {
0 commit comments