Skip to content

[mlir][linalg] Enable CollapseLinalgDimensions to collapse ops with C… #70653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,35 @@ def LinalgStructuredInterface
return;
}]
>,
    InterfaceMethod<
      /*desc=*/[{
        Returns true if the indexing map matching `opOperand` is a
        canonicalized identity map, i.e. an identity map in which a zero
        constant result is permitted wherever the corresponding dimension
        of the operand's shape has extent 1.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isCanonicalizedIdentityMap",
      /*args=*/(ins "OpOperand*": $opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        auto indexingMap = $_op.getMatchingIndexingMap(opOperand);
        return indexingMap.isCanonicalizedIdentity(getShape(opOperand));
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true if every indexing map of this specific linalg
        operation is a canonicalized identity map, as defined by
        `isCanonicalizedIdentityMap`.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"hasOnlyCanonicalizedIdentityMaps",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::all_of(this->getOperation()->getOpOperands(),[&](OpOperand &opOperand){
          return $_op.isCanonicalizedIdentityMap(&opOperand);
        });
      }]
    >,
//===------------------------------------------------------------------===//
// Linalg generalization hooks.
//===------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1043,8 +1043,9 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op,
/// range of the specified indexing map.
bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
/// Return `true` if all sequences of dimensions specified in `dimSequences` are
/// contiguous in all the ranges of the `maps`.
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
/// contiguous in all the ranges of the indexing maps of the `op`.
template <typename LinalgType>
bool areDimSequencesPreserved(LinalgType op,
ArrayRef<ReassociationIndices> dimSequences);

/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ class AffineMap {
/// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const;

/// Returns true if this affine map is a canonicalized identity.
/// Otherwise return false.
/// A canonicalized identity affine map corresponds to an identity
/// affine function on the dimensional identifiers, which may
/// include zero constant expressions in the affine map results.
/// Each such zero constant must correspond to a dimension of
/// extent 1.
/// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d1, d2, d3, d4)>
/// is considered a canonicalized identity if `shape[0] == 1`.
bool isCanonicalizedIdentity(ArrayRef<int64_t> shape) const;

/// Returns true if this affine map is a minor identity up to broadcasted
/// dimensions which are indicated by value 0 in the result. If
/// `broadcastedDims` is not null, it will be populated with the indices of
Expand Down
58 changes: 41 additions & 17 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,12 +1054,14 @@ bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
// 3. No element of sequence found. Return true.
return true;
}

template <typename LinalgType>
bool mlir::linalg::areDimSequencesPreserved(
ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
return llvm::all_of(maps, [&](AffineMap map) {
LinalgType op, ArrayRef<ReassociationIndices> dimSequences) {
return llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) {
return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
return isDimSequencePreserved(map, dimSequence);
return op.isCanonicalizedIdentityMap(&opOperand) ||
isDimSequencePreserved(op.getMatchingIndexingMap(&opOperand),
dimSequence);
});
});
}
Expand Down Expand Up @@ -1320,17 +1322,31 @@ getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,

/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
template <typename LinalgType>
static AffineMap
getCollapsedOpIndexingMap(AffineMap indexingMap,
getCollapsedOpIndexingMap(LinalgType op, OpOperand &opOperand,
const CollapsingInfo &collapsingInfo) {
auto indexingMap = op.getMatchingIndexingMap(&opOperand);
MLIRContext *context = indexingMap.getContext();
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
assert((op.isCanonicalizedIdentityMap(&opOperand) ||
indexingMap.isProjectedPermutation()) &&
"expected indexing map to be projected permutation or canonicalized "
"identity");
SmallVector<AffineExpr> resultExprs;
auto origOpToCollapsedOpMapping =
collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
unsigned dim;
for (auto pair : llvm::enumerate(indexingMap.getResults())) {
AffineExpr expr = pair.value();
auto constExprt = expr.dyn_cast<AffineConstantExpr>();
if (constExprt) {
assert(!constExprt.getValue() &&
"expected zero constants in canonicalized identity");
dim = pair.index();
} else {
dim = expr.cast<AffineDimExpr>().getPosition();
}

// If the dim is not the first of the collapsed dim, do nothing.
if (origOpToCollapsedOpMapping[dim].second != 0)
continue;
Expand All @@ -1354,9 +1370,17 @@ getOperandReassociation(AffineMap indexingMap,
collapsingInfo.getOrigOpToCollapsedOpMapping();
auto collapsedOpToOrigOpMapping =
collapsingInfo.getCollapsedOpToOrigOpMapping();
unsigned dim;
while (counter < indexingMap.getNumResults()) {
unsigned dim =
indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
AffineExpr expr = indexingMap.getResult(counter);
auto constExprt = expr.dyn_cast<AffineConstantExpr>();
if (constExprt) {
assert(!constExprt.getValue() &&
"expected zero constants in canonicalized identity");
dim = counter;
} else {
dim = expr.cast<AffineDimExpr>().getPosition();
}
// This is the start of a collapsed dimension of the iteration that
// is guaranteed to be preserved in the indexing map. The number of folded
// dims is obtained from the collapsed op to original op mapping.
Expand Down Expand Up @@ -1480,10 +1504,11 @@ Operation *createCollapsedOp(LinalgType op,
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);

// Get the indexing maps.
auto indexingMaps =
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
});
auto indexingMaps = llvm::to_vector(
llvm::map_range(op->getOpOperands(), [&](OpOperand &opOperand) {
return getCollapsedOpIndexingMap<LinalgType>(op, opOperand,
collapsingInfo);
}));

Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
Expand Down Expand Up @@ -1659,8 +1684,7 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
return failure();

// Check if the specified list of dimensions to collapse is a valid list.
if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
collapsableIterationDims)) {
if (!areDimSequencesPreserved<LinalgType>(op, collapsableIterationDims)) {
return rewriter.notifyMatchFailure(
op, "specified dimensions cannot be collapsed");
}
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/IR/AffineMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ bool AffineMap::isMinorIdentity() const {
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
}

/// Returns true if this map is an identity on its dimensional identifiers,
/// where a result may instead be the constant 0 provided the corresponding
/// entry of `shape` is 1 (a unit dimension).
bool AffineMap::isCanonicalizedIdentity(ArrayRef<int64_t> shape) const {
  // An identity requires one result per dimension, and `shape` must describe
  // exactly those dimensions.
  if (getNumDims() != getNumResults() || getNumDims() != shape.size())
    return false;
  for (unsigned idx = 0, e = getNumResults(); idx != e; ++idx) {
    AffineExpr res = getResult(idx);
    if (auto constExpr = res.dyn_cast<AffineConstantExpr>()) {
      // A zero constant is only admissible in a unit-extent position.
      if (constExpr.getValue() == 0 && shape[idx] == 1)
        continue;
      return false;
    }
    auto dimExpr = res.dyn_cast<AffineDimExpr>();
    if (!dimExpr || dimExpr.getPosition() != idx)
      return false;
  }
  return true;
}

/// Returns true if this affine map is a minor identity up to broadcasted
/// dimensions which are indicated by value 0 in the result.
bool AffineMap::isMinorIdentityWithBroadcasting(
Expand Down
32 changes: 32 additions & 0 deletions mlir/test/Dialect/Linalg/collapse-dim.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,35 @@ func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: me
linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
return
}

// -----

// CHECK-LABEL: func.func @collapse_canonicalized_identity(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x1x4096xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_2]] : tensor<2x2x4096xf32>) outs(%[[VAL_3]] : tensor<2x2x4096xf32>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
// CHECK: %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
// CHECK: linalg.yield %[[VAL_7]] : f32
// CHECK: } -> tensor<2x2x4096xf32>
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_9:.*]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x4096xf32> into tensor<2x2x1x4096xf32>
// CHECK: return %[[VAL_8]] : tensor<2x2x1x4096xf32>
// CHECK: }


// Test that collapsing works when an input indexing map is a canonicalized
// identity: the first map yields the constant 0 for d2, which is legal here
// because d2 corresponds to the unit dimension of tensor<2x2x1x4096xf32>.
// The unit dimension is expected to be folded together with d3 ([2, 3]).
func.func @collapse_canonicalized_identity(
    %arg0: tensor<2x2x1x4096xf32>, %arg1: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
  %0 = linalg.generic {
    indexing_maps = [
      // Canonicalized identity: 0 in place of d2 (extent-1 dimension).
      affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<2x2x1x4096xf32>) outs(%arg1 : tensor<2x2x1x4096xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %1 = arith.addf %arg3, %arg4 : f32
    linalg.yield %1 : f32
  } -> tensor<2x2x1x4096xf32>
  return %0 : tensor<2x2x1x4096xf32>
}