Skip to content

Commit f40efe1

Browse files
committed
review comments
1 parent d551d9d commit f40efe1

File tree

2 files changed

+20
-48
lines changed

2 files changed

+20
-48
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -399,15 +399,16 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
399399
}
400400
}
401401

402-
// Check if the transpose effects outer unit dims only. Such transposes do
403-
// not materially effect the underlying vector and can be omitted.
402+
// Checks if only the outer, unit dimensions (of size 1) are permuted.
403+
// Such transposes do not materially effect the underlying vector and can
404+
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
404405
bool tranposeNonOuterUnitDims = false;
405-
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
406-
if (perm[i] != i && i != (int64_t)perm.size() - 1) {
407-
if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
408-
1) {
409-
tranposeNonOuterUnitDims = true;
410-
}
406+
for (auto [index, dim] :
407+
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
408+
if (dim != static_cast<int64_t>(index) &&
409+
operands[it.index()].getType().cast<ShapedType>().getDimSize(
410+
index) != 1) {
411+
tranposeNonOuterUnitDims = true;
411412
}
412413
}
413414

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -244,53 +244,24 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
244244
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
245245
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
246246

247-
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
248-
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
249-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
250-
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251-
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252-
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
253-
// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254-
// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
255-
// CHECK: return %[[VAL_6]] : vector<1x8xi32>
247+
// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
248+
// CHECK-SAME: %[[LHS:.*]]: vector<1x1x8xi32>,
249+
// CHECK-SAME: %[[RHS:.*]]: vector<1x8x8xi32>,
250+
// CHECK-SAME: %[[ACC:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
251+
// CHECK: %[[EXT_LHS:.*]] = vector.extract %[[LHS]][0] : vector<1x8xi32> from vector<1x1x8xi32>
252+
// CHECK: %[[EXT_ACC:.*]] = vector.extract %[[ACC]][0] : vector<8xi32> from vector<1x8xi32>
253+
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS]], %[[RHS]], %[[EXT_ACC]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
254+
// CHECK: %[[BROADCAST_RES:.*]] = vector.broadcast %[[RES]] : vector<8xi32> to vector<1x8xi32>
255+
// CHECK: return %[[BROADCAST_RES]] : vector<1x8xi32>
256256
// CHECK: }
257-
func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
257+
// CHECK-NOT vector.transpose
258+
func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
258259
%rhs: vector<1x8x8xi32>,
259260
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
260261
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
261262
return %result : vector<1x8xi32>
262263
}
263264

264-
// -----
265-
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
266-
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
267-
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
268-
269-
// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
270-
// CHECK: %[[MASK:.+]] = vector.constant_mask
271-
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
272-
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
273-
// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
274-
// CHECK: return %[[RET]] : vector<1x16x16xf32>
275-
276-
#contraction_accesses0 = [
277-
affine_map<(l, i, j, k) -> (l, i, k)>,
278-
affine_map<(l, i, j, k) -> (l, k, j)>,
279-
affine_map<(l, i, j, k) -> (l, i, j)>
280-
]
281-
#contraction_trait0 = {
282-
indexing_maps = #contraction_accesses0,
283-
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
284-
}
285-
286-
func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
287-
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
288-
%0 = vector.mask %mask {
289-
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
290-
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
291-
return %0 : vector<1x16x16xf32>
292-
}
293-
294265
// -----
295266
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
296267
func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {

0 commit comments

Comments
 (0)