Skip to content

Commit 7975d0b

Browse files
committed
only tranpose non leading unit dims
1 parent c511c90 commit 7975d0b

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
398398
transposeResults.push_back(targetExpr);
399399
}
400400
}
401+
402+
// Check if the transpose effects outer unit dims only. Such transposes do
403+
// not materially effect the underlying vector and can be omitted.
404+
bool tranposeNonOuterUnitDims = false;
405+
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
406+
if (perm[i] != i && i != (int64_t)perm.size() - 1) {
407+
if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
408+
1) {
409+
tranposeNonOuterUnitDims = true;
410+
}
411+
}
412+
}
413+
401414
// Do the tranpose now if needed so that we can drop the
402415
// correct dim using extract later.
403416
if (tranposeNeeded) {
404417
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
405418
contractOp.getContext());
406-
operands[it.index()] = rewriter.create<vector::TransposeOp>(
407-
loc, operands[it.index()], perm);
419+
if (tranposeNonOuterUnitDims) {
420+
operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
421+
loc, operands[it.index()], perm);
422+
}
408423
}
409424
}
410425
// We have taken care to have the dim to be dropped be

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,58 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
238238
return %0: vector<1x1x2x16xf32>
239239
}
240240

241+
// -----
242+
243+
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
244+
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
245+
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
246+
247+
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
248+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
249+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
250+
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251+
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252+
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
253+
// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254+
// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
255+
// CHECK: return %[[VAL_6]] : vector<1x8xi32>
256+
// CHECK: }
257+
func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
258+
%rhs: vector<1x8x8xi32>,
259+
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
260+
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
261+
return %result : vector<1x8xi32>
262+
}
263+
264+
// -----
265+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
266+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
267+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
268+
269+
// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
270+
// CHECK: %[[MASK:.+]] = vector.constant_mask
271+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
272+
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
273+
// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
274+
// CHECK: return %[[RET]] : vector<1x16x16xf32>
275+
276+
#contraction_accesses0 = [
277+
affine_map<(l, i, j, k) -> (l, i, k)>,
278+
affine_map<(l, i, j, k) -> (l, k, j)>,
279+
affine_map<(l, i, j, k) -> (l, i, j)>
280+
]
281+
#contraction_trait0 = {
282+
indexing_maps = #contraction_accesses0,
283+
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
284+
}
285+
286+
func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
287+
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
288+
%0 = vector.mask %mask {
289+
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
290+
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
291+
return %0 : vector<1x16x16xf32>
292+
}
241293

242294
// -----
243295
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
@@ -663,4 +715,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
663715
%sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
664716
return %sel : vector<1x16xi1>
665717
}
666-

0 commit comments

Comments
 (0)