Skip to content

Commit 163ea73

Browse files
committed
review comments
1 parent 5f038e1 commit 163ea73

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,16 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
400400
}
401401
}
402402

403-
// Check if the transpose effects outer unit dims only. Such transposes do
404-
// not materially effect the underlying vector and can be omitted.
403+
// Checks if only the outer, unit dimensions (of size 1) are permuted.
404+
// Such transposes do not materially effect the underlying vector and can
405+
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
405406
bool tranposeNonOuterUnitDims = false;
406-
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
407-
if (perm[i] != i && i != (int64_t)perm.size() - 1) {
408-
if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
409-
1) {
410-
tranposeNonOuterUnitDims = true;
411-
}
407+
for (auto [index, dim] :
408+
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
409+
if (dim != static_cast<int64_t>(index) &&
410+
operands[it.index()].getType().cast<ShapedType>().getDimSize(
411+
index) != 1) {
412+
tranposeNonOuterUnitDims = true;
412413
}
413414
}
414415

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,17 +170,18 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
170170
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171171
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
172172

173-
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
174-
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
175-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
176-
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177-
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178-
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
179-
// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180-
// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
181-
// CHECK: return %[[VAL_6]] : vector<1x8xi32>
173+
// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(
174+
// CHECK-SAME: %[[LHS:.*]]: vector<1x1x8xi32>,
175+
// CHECK-SAME: %[[RHS:.*]]: vector<1x8x8xi32>,
176+
// CHECK-SAME: %[[ACC:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177+
// CHECK: %[[EXT_LHS:.*]] = vector.extract %[[LHS]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178+
// CHECK: %[[EXT_ACC:.*]] = vector.extract %[[ACC]][0] : vector<8xi32> from vector<1x8xi32>
179+
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[EXT_LHS]], %[[RHS]], %[[EXT_ACC]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180+
// CHECK: %[[BROADCAST_RES:.*]] = vector.broadcast %[[RES]] : vector<8xi32> to vector<1x8xi32>
181+
// CHECK: return %[[BROADCAST_RES]] : vector<1x8xi32>
182182
// CHECK: }
183-
func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
183+
// CHECK-NOT vector.transpose
184+
func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
184185
%rhs: vector<1x8x8xi32>,
185186
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
186187
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>

0 commit comments

Comments
 (0)