Skip to content

[mlir][linalg] Enable CollapseLinalgDimensions to collapse ops with C… #70653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

amirBish
Copy link
Contributor

@amirBish amirBish commented Oct 30, 2023

…anonicalized Identity maps

Supporting collapsion of linalg ops which have
canonicalized identity maps matched for their
OpOperands.

Canonnicalized Identity is an identity affine map
which include zero constants corresponded to the
values of 1 of the Operand's shape.

a common use case for this support would be the
usage of CollapseLinalgDimensions after Tosa-To-Linalg ,
since the later generates linalg.generic ops with canonicalized
identity maps (and the rewrite pattern would fail matching,
since it supports only projected permutes indexing maps).

…anonicalized Identity maps

Supporting collapsion of linalg ops which have
canonicalized identity maps matched for their
OpOperands.

Canonnicalized Identity is an identity affine map
which include zero constants corresponded to the
values of `1` of the Operand's shape.

a common use case for this support would be the
usage of CollapseLinalgDimensions after Tosa-To-Linalg
, since the later generates linalg.generic ops with
canonicalized identity maps (and the rewrite pattern
would fail matching, since it supports only projected
permutes indexing maps).
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir labels Oct 30, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2023

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

Changes

…anonicalized Identity maps

Supporting collapsion of linalg ops which have
canonicalized identity maps matched for their
OpOperands.

Canonnicalized Identity is an identity affine map
which include zero constants corresponded to the
values of 1 of the Operand's shape.

a common use case for this support would be the
usage of CollapseLinalgDimensions after Tosa-To-Linalg ,
since the later generates linalg.generic ops with canonicalized
identity maps (and the rewrite pattern would fail matching,
since it supports only projected permutes indexing maps).


Full diff: https://github.com/llvm/llvm-project/pull/70653.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+29)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+3-2)
  • (modified) mlir/include/mlir/IR/AffineMap.h (+11)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+41-17)
  • (modified) mlir/lib/IR/AffineMap.cpp (+17)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69ca888a8acdbe0..31efa35540b25e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -668,6 +668,35 @@ def LinalgStructuredInterface
         return;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if the indexing map which matches the OpOperand
+        is considered as a canonicalized identity.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isCanonicalizedIdentityMap",
+      /*args=*/(ins "OpOperand*": $opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+          auto indexingMap = $_op.getMatchingIndexingMap(opOperand);
+          return indexingMap.isCanonicalizedIdentity(getShape(opOperand));
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if all of the indexing maps of the specefic linalg operation
+        are considered as canonicalized identity.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasOnlyCanonicalizedIdentityMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+          return llvm::all_of(this->getOperation()->getOpOperands(),[&](OpOperand &opOperand){
+            return $_op.isCanonicalizedIdentityMap(&opOperand);
+          });
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fbe2923c710aabb..b7c769ed3560ee8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1043,8 +1043,9 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op,
 /// range of the specified indexing map.
 bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 /// Return `true` if all sequences of dimensions specified in `dimSequences` are
-/// contiguous in all the ranges of the `maps`.
-bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
+/// contiguous in all the ranges of the indexing maps of the `op`.
+template <typename LinalgType>
+bool areDimSequencesPreserved(LinalgType op,
                               ArrayRef<ReassociationIndices> dimSequences);
 
 /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..d446e1500845406 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -131,6 +131,17 @@ class AffineMap {
   /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
   bool isMinorIdentity() const;
 
+  /// Returns true if this affine map is a canonicalized identity.
+  /// Otherwise return false.
+  /// A canonicalized identity affine map corresponds to an identity
+  /// affine function on the dimensional identifiers. which may
+  /// include zero constant expressions in the affine map results.
+  /// These zero constants should be corresponded to dimesnions with
+  /// value 1.
+  /// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d1, d2, d3, d4)>
+  /// is considered a canonicalized identity if `shape[0] == 1`.
+  bool isCanonicalizedIdentity(ArrayRef<int64_t> shape) const;
+
   /// Returns true if this affine map is a minor identity up to broadcasted
   /// dimensions which are indicated by value 0 in the result. If
   /// `broadcastedDims` is not null, it will be populated with the indices of
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 32d38a21e4e00f4..e2bdbebb831e5c4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1054,12 +1054,14 @@ bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
   // 3. No element of sequence found. Return true.
   return true;
 }
-
+template <typename LinalgType>
 bool mlir::linalg::areDimSequencesPreserved(
-    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
-  return llvm::all_of(maps, [&](AffineMap map) {
+    LinalgType op, ArrayRef<ReassociationIndices> dimSequences) {
+  return llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) {
     return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
-      return isDimSequencePreserved(map, dimSequence);
+      return op.isCanonicalizedIdentityMap(&opOperand) ||
+             isDimSequencePreserved(op.getMatchingIndexingMap(&opOperand),
+                                    dimSequence);
     });
   });
 }
@@ -1320,17 +1322,31 @@ getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
 
 /// Compute the indexing map in the collapsed op that corresponds to the given
 /// `indexingMap` of the original operation.
+template <typename LinalgType>
 static AffineMap
-getCollapsedOpIndexingMap(AffineMap indexingMap,
+getCollapsedOpIndexingMap(LinalgType op, OpOperand &opOperand,
                           const CollapsingInfo &collapsingInfo) {
+  auto indexingMap = op.getMatchingIndexingMap(&opOperand);
   MLIRContext *context = indexingMap.getContext();
-  assert(indexingMap.isProjectedPermutation() &&
-         "expected indexing map to be projected permutation");
+  assert((op.isCanonicalizedIdentityMap(&opOperand) ||
+          indexingMap.isProjectedPermutation()) &&
+         "expected indexing map to be projected permutation or canonicalized "
+         "identity");
   SmallVector<AffineExpr> resultExprs;
   auto origOpToCollapsedOpMapping =
       collapsingInfo.getOrigOpToCollapsedOpMapping();
-  for (auto expr : indexingMap.getResults()) {
-    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+  unsigned dim;
+  for (auto pair : llvm::enumerate(indexingMap.getResults())) {
+    AffineExpr expr = pair.value();
+    auto constExprt = expr.dyn_cast<AffineConstantExpr>();
+    if (constExprt) {
+      assert(!constExprt.getValue() &&
+             "expected zero constants in canonicalized identity");
+      dim = pair.index();
+    } else {
+      dim = expr.cast<AffineDimExpr>().getPosition();
+    }
+
     // If the dim is not the first of the collapsed dim, do nothing.
     if (origOpToCollapsedOpMapping[dim].second != 0)
       continue;
@@ -1354,9 +1370,17 @@ getOperandReassociation(AffineMap indexingMap,
       collapsingInfo.getOrigOpToCollapsedOpMapping();
   auto collapsedOpToOrigOpMapping =
       collapsingInfo.getCollapsedOpToOrigOpMapping();
+  unsigned dim;
   while (counter < indexingMap.getNumResults()) {
-    unsigned dim =
-        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+    AffineExpr expr = indexingMap.getResult(counter);
+    auto constExprt = expr.dyn_cast<AffineConstantExpr>();
+    if (constExprt) {
+      assert(!constExprt.getValue() &&
+             "expected zero constants in canonicalized identity");
+      dim = counter;
+    } else {
+      dim = expr.cast<AffineDimExpr>().getPosition();
+    }
     // This is the start of a collapsed dimensions of the iteration that
     // is gauranteed to be preserved in the indexing map. The number of folded
     // dims is obtained from the collapsed op to original op mapping.
@@ -1480,10 +1504,11 @@ Operation *createCollapsedOp(LinalgType op,
       getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
 
   // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
-        return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+  auto indexingMaps = llvm::to_vector(
+      llvm::map_range(op->getOpOperands(), [&](OpOperand &opOperand) {
+        return getCollapsedOpIndexingMap<LinalgType>(op, opOperand,
+                                                     collapsingInfo);
+      }));
 
   Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
       loc, resultTypes, inputOperands, outputOperands, indexingMaps,
@@ -1659,8 +1684,7 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
       return failure();
 
     // Check if the specified list of dimensions to collapse is a valid list.
-    if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
-                                  collapsableIterationDims)) {
+    if (!areDimSequencesPreserved<LinalgType>(op, collapsableIterationDims)) {
       return rewriter.notifyMatchFailure(
           op, "specified dimensions cannot be collapsed");
     }
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..a10ffb7bdd2b3b0 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -133,6 +133,23 @@ bool AffineMap::isMinorIdentity() const {
              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
 }
 
+bool AffineMap::isCanonicalizedIdentity(ArrayRef<int64_t> shape) const {
+  if (getNumDims() != getNumResults())
+    return false;
+  if (getNumDims() != shape.size())
+    return false;
+  for (auto [index, result] : llvm::enumerate(getResults())) {
+    auto constExpr = result.dyn_cast<AffineConstantExpr>();
+    if (constExpr && !constExpr.getValue() && shape[index] == 1)
+      continue;
+
+    auto expr = result.dyn_cast<AffineDimExpr>();
+    if (!expr || expr.getPosition() != index)
+      return false;
+  }
+  return true;
+}
+
 /// Returns true if this affine map is a minor identity up to broadcasted
 /// dimensions which are indicated by value 0 in the result.
 bool AffineMap::isMinorIdentityWithBroadcasting(
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f53387477..ed375ce703b41ff 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -153,3 +153,35 @@ func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: me
   linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
   return
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @collapse_canonicalized_identity(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<2x2x1x4096xf32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x1x4096xf32> into tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_2]] : tensor<2x2x4096xf32>) outs(%[[VAL_3]] : tensor<2x2x4096xf32>) {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
+// CHECK:             %[[VAL_7:.*]] = arith.addf %[[VAL_5]], %[[VAL_6]] : f32
+// CHECK:             linalg.yield %[[VAL_7]] : f32
+// CHECK:           } -> tensor<2x2x4096xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_9:.*]] {{\[\[}}0], [1], [2, 3]] : tensor<2x2x4096xf32> into tensor<2x2x1x4096xf32>
+// CHECK:           return %[[VAL_8]] : tensor<2x2x1x4096xf32>
+// CHECK:         }
+
+
+func.func @collapse_canonicalized_identity(
+    %arg0: tensor<2x2x1x4096xf32>, %arg1: tensor<2x2x1x4096xf32>) -> tensor<2x2x1x4096xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, d3)>,
+        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  ins(%arg0 : tensor<2x2x1x4096xf32>) outs(%arg1 : tensor<2x2x1x4096xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    %1 = arith.addf %arg3, %arg4 : f32
+    linalg.yield %1 : f32
+  } -> tensor<2x2x1x4096xf32>
+  return %0 : tensor<2x2x1x4096xf32>
+}
\ No newline at end of file

@amirBish
Copy link
Contributor Author

Adding @aniragil @AviadCo @amrami as subscribers.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill review this, but in general IMO having 0s in indexing maps is an anti-pattern... Have you tried using FoldUnitExtentDims pass... it is meant to convert the numpy-style broadcasting into more canonical Linalg broadcasting.

@nicolasvasilache
Copy link
Contributor

Ill review this, but in general IMO having 0s in indexing maps is an anti-pattern... Have you tried using FoldUnitExtentDims pass... it is meant to convert the numpy-style broadcasting into more canonical Linalg broadcasting.

+1 this seems undesirable atm

@amirBish
Copy link
Contributor Author

amirBish commented Oct 31, 2023

@MaheshRavishankar @nicolasvasilache Thanks for the quick response, searched for such a pattern and missed this one :( , I agree usage of FoldUnitExtentDims would solve my problem.
would close this PR, thanks again.

@amirBish amirBish closed this Oct 31, 2023
@nicolasvasilache
Copy link
Contributor

nicolasvasilache commented Oct 31, 2023

@amirBish I am sorry that it was hard to find .. please do not hesitate to ping us proactively on discord or discourse to help accelerate thinking. I would also suggest you skim through the various transform. transform dialect operations which should give you an understanding of what exists and can help make discovery easier.

I sometimes run the following command to help me skim through existing things:

git grep -o -e "[[:space:]]transform\.[a-z_\.]\+" mlir/test/ | awk {'print $2'} | sort | uniq

@amirBish
Copy link
Contributor Author

@nicolasvasilache Great, thanks for the tips :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants