Skip to content

Commit 66fed33

Browse files
authored
[mlir][vector] Update castAwayContractionLeadingOneDim to omit transposes solely on leading unit dims. (#85694)
Updates `castAwayContractionLeadingOneDim` to check for leading unit dimensions before inserting `vector.transpose` ops. Currently `castAwayContractionLeadingOneDim` removes all leading unit dims based on the accumulator and transpose any subsequent operands to match the accumulator indexing. This does not take into account if the transpose is strictly necessary, for instance when given this vector-matrix contract: ```mlir %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32> ``` Passing this through `castAwayContractionLeadingOneDim` pattern produces the following: ```mlir %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x1x8xi32> to vector<1x1x8xi32> %1 = vector.extract %0[0] : vector<1x8xi32> from vector<1x1x8xi32> %2 = vector.extract %arg2[0] : vector<8xi32> from vector<1x8xi32> %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %arg1, %2 : vector<1x8xi32>, vector<1x8x8xi32> into vector<8xi32> %4 = vector.broadcast %3 : vector<8xi32> to vector<1x8xi32> ``` The `vector.transpose` introduced does not affect the underlying data layout (effectively a no op), but it cannot be folded automatically. This change avoids inserting transposes when only leading unit dimensions are involved. Fixes #85691
1 parent c511c90 commit 66fed33

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
398398
transposeResults.push_back(targetExpr);
399399
}
400400
}
401+
402+
// Checks if only the outer, unit dimensions (of size 1) are permuted.
403+
// Such transposes do not materially effect the underlying vector and can
404+
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
405+
bool transposeNonOuterUnitDims = false;
406+
auto operandShape = operands[it.index()].getType().cast<ShapedType>();
407+
for (auto [index, dim] :
408+
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
409+
if (dim != static_cast<int64_t>(index) &&
410+
operandShape.getDimSize(index) != 1) {
411+
transposeNonOuterUnitDims = true;
412+
break;
413+
}
414+
}
415+
401416
// Do the tranpose now if needed so that we can drop the
402417
// correct dim using extract later.
403418
if (tranposeNeeded) {
404419
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
405420
contractOp.getContext());
406-
operands[it.index()] = rewriter.create<vector::TransposeOp>(
407-
loc, operands[it.index()], perm);
421+
if (transposeNonOuterUnitDims) {
422+
operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
423+
loc, operands[it.index()], perm);
424+
}
408425
}
409426
}
410427
// We have taken care to have the dim to be dropped be

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
238238
return %0: vector<1x1x2x16xf32>
239239
}
240240

241+
// -----
242+
243+
// CHECK-LABEL: func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
244+
// CHECK-NOT vector.transpose
245+
// CHECK: vector.contract
246+
func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
247+
%rhs: vector<1x8x8xi32>,
248+
%acc: vector<1x8xi32>) -> vector<1x8xi32> {
249+
%result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
250+
return %result : vector<1x8xi32>
251+
}
241252

242253
// -----
243254
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
@@ -663,4 +674,3 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>,
663674
%sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
664675
return %sel : vector<1x16xi1>
665676
}
666-

0 commit comments

Comments
 (0)