Skip to content

[mlir][Affine] Extend linearize/delinearize cancelation to partial tails #116872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2024

Conversation

krzysz00
Copy link
Contributor

xisting patterns would cancel out the linearize_index / delinearize_index pairs that had the exact same basis, like

%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...

This commit extends the canonicalization to handle instances where the entire basis doesn't match, as in

%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:3 = affine.delinearize_index %0 into (XY, Y, Z) : index, ...

where we can replace the last two results of the delinearize_index operation with the last two inputs of the linearize_index, creating a more canonical (fewer total computations to perform) result.

xisting patterns would cancel out the linearize_index /
delinearize_index pairs that had the exact same basis, like

    %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...

This commit extends the canonicalization to handle instances where the
entire basis doesn't match, as in

    %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:3 = affine.delinearize_index %0 into (XY, Y, Z) : index, ...

where we can replace the last two results of the delinearize_index
operation with the last two inputs of the linearize_index, creating a
more canonical (fewer total computations to perform) result.
@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

xisting patterns would cancel out the linearize_index / delinearize_index pairs that had the exact same basis, like

%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...

This commit extends the canonicalization to handle instances where the entire basis doesn't match, as in

%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:3 = affine.delinearize_index %0 into (XY, Y, Z) : index, ...

where we can replace the last two results of the delinearize_index operation with the last two inputs of the linearize_index, creating a more canonical (fewer total computations to perform) result.


Full diff: https://github.com/llvm/llvm-project/pull/116872.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+45-11)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+18)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..67d7da622a3550 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis
 };
 
 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
-/// disjoint` and the two operations have the same basis, replace the
-/// delinearizeation results with the inputs of the `affine.linearize_index`
-/// since they are exact inverses of each other.
+/// disjoint` and the two operations end with the same basis elements,
+/// cancel those parts of the operations out because they are inverses
+/// of each other.
+///
+/// If the operations have the same basis, cancel them entirely.
 ///
 /// The `disjoint` flag is needed on the `affine.linearize_index` because
 /// otherwise, there is no guarantee that the inputs to the linearization are
 /// in-bounds the way the outputs of the delinearization would be.
-struct CancelDelinearizeOfLinearizeDisjointExact
+struct CancelDelinearizeOfLinearizeDisjointExactTail
     : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -4685,12 +4687,45 @@ struct CancelDelinearizeOfLinearizeDisjointExact
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "index doesn't come from linearize");
 
-    if (!linearizeOp.getDisjoint() ||
-        linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+    ValueRange linearizeIns = linearizeOp.getMultiIndex();
+    // Note: we use the full basis so we don't lose outer bounds later.
+    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
+    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
+    size_t numMatches = 0;
+    for (auto [linSize, delinSize] : llvm::zip(
+             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
+      if (linSize != delinSize)
+        break;
+      ++numMatches;
+    }
+
+    if (numMatches == 0)
       return rewriter.notifyMatchFailure(
-          linearizeOp, "not disjoint or basis doesn't match delinearize");
+          delinearizeOp, "final basis element doesn't match linearize");
+
+    // The easy case: everything lines up and the basis match sup completely.
+    if (numMatches == linearizeBasis.size() &&
+        numMatches == delinearizeBasis.size() &&
+        linearizeIns.size() == delinearizeOp.getNumResults()) {
+      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+      return success();
+    }
 
-    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
+        linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
+        ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
+        linearizeOp.getDisjoint());
+    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), newLinearize,
+        ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
+        delinearizeOp.hasOuterBound());
+    SmallVector<Value> mergedResults(newDelinearize.getResults());
+    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
+                         linearizeIns.take_back(numMatches).end());
+    rewriter.replaceOp(delinearizeOp, mergedResults);
     return success();
   }
 };
@@ -4698,9 +4733,8 @@ struct CancelDelinearizeOfLinearizeDisjointExact
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns
-      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
-          context);
+  patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
+                  DropUnitExtentBasis>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..5384977151b47f 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0:
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (8, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // Without `disjoint`, the cancelation isn't guaranteed to be the identity.
 // CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,

@krzysz00 krzysz00 requested a review from Groverkss November 20, 2024 17:46
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! THanks! I think this better demonstrates the power of these operations....

@krzysz00 krzysz00 merged commit 0ac889b into llvm:main Nov 21, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants