Skip to content

Commit b1ff4b0

Browse files
committed
only tranpose non leading unit dims
1 parent c48d818 commit b1ff4b0

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,28 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
399399
transposeResults.push_back(targetExpr);
400400
}
401401
}
402+
403+
// Check if the transpose effects outer unit dims only. Such transposes do
404+
// not materially effect the underlying vector and can be omitted.
405+
bool tranposeNonOuterUnitDims = false;
406+
for (int64_t i = 0; i < (int64_t)perm.size(); ++i) {
407+
if (perm[i] != i && i != (int64_t)perm.size() - 1) {
408+
if (operands[it.index()].getType().cast<ShapedType>().getDimSize(i) !=
409+
1) {
410+
tranposeNonOuterUnitDims = true;
411+
}
412+
}
413+
}
414+
402415
// Do the tranpose now if needed so that we can drop the
403416
// correct dim using extract later.
404417
if (tranposeNeeded) {
405418
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
406419
contractOp.getContext());
407-
operands[it.index()] = rewriter.create<vector::TransposeOp>(
408-
contractOp.getLoc(), operands[it.index()], perm);
420+
if (tranposeNonOuterUnitDims) {
421+
operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
422+
contractOp.getLoc(), operands[it.index()], perm);
423+
}
409424
}
410425
}
411426
// We have taken care to have the dim to be dropped be

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,28 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
166166

167167
// -----
168168

169+
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
170+
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
171+
// CHECK-DAG: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
172+
173+
// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dims_vec_mat(
174+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x8xi32>,
175+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x8xi32>,
176+
// CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
177+
// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_0]][0] : vector<1x8xi32> from vector<1x1x8xi32>
178+
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_2]][0] : vector<8xi32> from vector<1x8xi32>
179+
// CHECK: %[[VAL_5:.*]] = vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_3]], %[[VAL_1]], %[[VAL_4]] : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32>
180+
// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_5]] : vector<8xi32> to vector<1x8xi32>
181+
// CHECK: return %[[VAL_6]] : vector<1x8xi32>
182+
// CHECK: }
183+
func.func @cast_away_contraction_leading_one_dims_vec_mat(%lhs: vector<1x1x8xi32>,
184+
%rhs: vector<1x8x8xi32>,
185+
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
186+
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
187+
return %result : vector<1x8xi32>
188+
}
189+
190+
// -----
169191
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170192
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171193
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

0 commit comments

Comments
 (0)