Skip to content

Commit f12639d

Browse files
author
Mahesh Ravishankar
committed
[mlir][Linalg] Avoid collapsing dimensions of linalg op that arent foldable.
The collapsing dimensions transformation is limited to only those cases where the sequence of dimensions are contiguous in all the ranges of the indexing maps of the operation. Add this check before applying the transformation. Differential Revision: https://reviews.llvm.org/D150176
1 parent 92663cd commit f12639d

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,8 +901,21 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op,
901901
const ControlSplitReductionFn &controlSplitReductionFn,
902902
bool useAlloc = false);
903903

904-
/// Collapses dimensions of linalg.generic operation. It also collapses inputs
905-
/// before the op and expands outputs after the op.
904+
/// Return `true` if a given sequence of dimensions are contiguous in the
905+
/// range of the specified indexing map.
906+
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
907+
/// Return `true` if all sequences of dimensions specified in `dimSequences` are
908+
/// contiguous in all the ranges of the `maps`.
909+
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
910+
ArrayRef<ReassociationIndices> dimSequences);
911+
912+
/// Collapses dimensions of linalg.generic operation. A precondition to
913+
/// calling this method is that for each list in `foldedIterationDim`, the
914+
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
915+
/// the `genericOp`. This can be checked using `areDimSequencePreserved` method.
916+
/// When valid, the method also collapses the operands of the op. Returns
917+
/// replacement values of the results of the original `genericOp` by inserting
918+
/// reshapes to get back values of compatible types.
906919
FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
907920
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
908921
RewriterBase &rewriter);

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,8 +1004,8 @@ getDomainReassociation(AffineMap indexingMap,
10041004
/// For a given `dimSequence`, check if the sequence is conserved in the
10051005
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
10061006
/// Non-existence of the sequence returns true as well.
1007-
static bool isDimSequencePreserved(AffineMap indexingMap,
1008-
ReassociationIndicesRef dimSequence) {
1007+
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
1008+
ReassociationIndicesRef dimSequence) {
10091009
assert(!dimSequence.empty() &&
10101010
"expected non-empty list for dimension sequence");
10111011
assert(indexingMap.isProjectedPermutation() &&
@@ -1045,6 +1045,15 @@ static bool isDimSequencePreserved(AffineMap indexingMap,
10451045
return true;
10461046
}
10471047

1048+
bool mlir::linalg::areDimSequencesPreserved(
1049+
ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
1050+
return llvm::all_of(maps, [&](AffineMap map) {
1051+
return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1052+
return isDimSequencePreserved(map, dimSequence);
1053+
});
1054+
});
1055+
}
1056+
10481057
// Return the list of dimensions of the iteration domain that can be
10491058
// collapsed to allow for fusion with the a producer that is an expand_shape
10501059
// operation. If all dimensions created by expansion can be collapsed in the
@@ -1592,6 +1601,13 @@ class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
15921601
if (collapsableIterationDims.empty())
15931602
return failure();
15941603

1604+
// Check if the specified list of dimensions to collapse is a valid list.
1605+
if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
1606+
collapsableIterationDims)) {
1607+
return rewriter.notifyMatchFailure(
1608+
genericOp, "specified dimensions cannot be collapsed");
1609+
}
1610+
15951611
std::optional<SmallVector<Value>> replacements =
15961612
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
15971613
rewriter);

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,20 @@ func.func @collapse_parallel(
5353
// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
5454
// CHECK: } -> tensor<2x32x40960xf32>
5555
// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
56+
57+
// -----
58+
59+
#map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
60+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
61+
func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41xf32>) -> tensor<3x1x57x41xf32> {
62+
%0 = linalg.generic {
63+
indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
64+
ins(%arg0 : tensor<41x3x1x57xf32>) outs(%arg1 : tensor<3x1x57x41xf32>) {
65+
^bb0(%in: f32, %out: f32):
66+
linalg.yield %in : f32
67+
} -> tensor<3x1x57x41xf32>
68+
return %0 : tensor<3x1x57x41xf32>
69+
}
70+
// CHECK-LABEL: func @uncollapsable(
71+
// CHECK: linalg.generic
72+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]

0 commit comments

Comments
 (0)