Skip to content

[mlir][affine] Cancel exactly-matching delinearize/linearize pairs #115758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 12, 2024

Conversation

krzysz00
Copy link
Contributor

If we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization.

Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by disjoint, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out.

This commit adds canonicalization patterns for these simple cancelations.

If we linearize values (with an assertion tha they are disjoint) and
then delinearize that linear index with th exact same basis, we know
that these operations are exact inverses of each other and can be
replaced with the original inputs to the linearization.

Similarly, if we take a linear index, delinearize it with some bases,
and then re-linearize it with that same basis (noting that the outputs
of the delinearization are guaranteed to by `disjoint`, even if this
is not asserted on the linearize_index operation), the
re-linearization is the inverse of the delinearization, so those two
operations can also be canceled out.

This commit adds canonicalization patterns for these simple
cancelations.
@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

Changes

If we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization.

Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by disjoint, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out.

This commit adds canonicalization patterns for these simple cancelations.


Full diff: https://github.com/llvm/llvm-project/pull/115758.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+63-2)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+96)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..c9d9202ae3cf1a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` `
+    $linear_index `into`
     custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
     attr-dict `:` type($multi_index)
   }];
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..d73d808753ba54 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4625,11 +4625,39 @@ struct DropDelinearizeOneBasisElement
   }
 };
 
+/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
+/// disjoint` and the two operations have the same basis, replace the
+/// delinearizeation results with the inputs of the `affine.linearize_index`
+/// since they are exact inverses of each other.
+///
+/// The `disjoint` flag is needed on the `affine.linearize_index` because
+/// otherwise, there is no guarantee that the inputs to the linearization are
+/// in-bounds the way the outputs of the delinearization would be.
+struct CancelDelinearizeOfLinearizeDisjointExact
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return failure();
+
+    if (!linearizeOp.getDisjoint() ||
+        linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return failure();
+
+    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis,
+                  CancelDelinearizeOfLinearizeDisjointExact>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4751,12 +4779,45 @@ struct DropLinearizeOneBasisElement final
     return success();
   }
 };
+
+/// Cancel out linearize_index(delinearize_index(x, B), B).
+///
+/// That is, rewrite
+/// ```
+/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
+/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
+/// %bN)
+/// ```
+/// to replacing `%y` with `%x`.
+struct CancelLinearizeOfDelinearizeExact final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto delinearizeOp = linearizeOp.getMultiIndex()
+                             .front()
+                             .getDefiningOp<affine::AffineDelinearizeIndexOp>();
+    if (!delinearizeOp)
+      return failure();
+
+    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+      return failure();
+
+    if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+      return failure();
+
+    rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement, CancelLinearizeOfDelinearizeExact>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..99c115ba782c01 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1526,6 +1526,59 @@ func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index,
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @delinearize_non_loop_like
 // CHECK-NOT: affine.delinearize
 func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index {
@@ -1577,3 +1630,46 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
+//       CHECK:     return %[[LIN]]
+func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
+  return %1 : index
+}

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Maybe leave some time for folks to weigh in if they want (like a day).

} // namespace

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
DropLinearizeOneBasisElement>(context);
DropLinearizeOneBasisElement, CancelLinearizeOfDelinearizeExact>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please maintain these alphabetically

if (!delinearizeOp)
return failure();

if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Better to return a notifyMatchFailure to know why it didnt match.

@krzysz00 krzysz00 merged commit 49f90e7 into llvm:main Nov 12, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants