-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][affine] Cancel exactly-matching delinearize/linearize pairs #115758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][affine] Cancel exactly-matching delinearize/linearize pairs #115758
Conversation
If we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization. Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by `disjoint`, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out. This commit adds canonicalization patterns for these simple cancelations.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) ChangesIf we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization. Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by This commit adds canonicalization patterns for these simple cancelations. Full diff: https://github.com/llvm/llvm-project/pull/115758.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..c9d9202ae3cf1a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1090,7 +1090,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
let results = (outs Variadic<Index>:$multi_index);
let assemblyFormat = [{
- $linear_index `into` ` `
+ $linear_index `into`
custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
attr-dict `:` type($multi_index)
}];
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..d73d808753ba54 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4625,11 +4625,39 @@ struct DropDelinearizeOneBasisElement
}
};
+/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
+/// disjoint` and the two operations have the same basis, replace the
+/// delinearizeation results with the inputs of the `affine.linearize_index`
+/// since they are exact inverses of each other.
+///
+/// The `disjoint` flag is needed on the `affine.linearize_index` because
+/// otherwise, there is no guarantee that the inputs to the linearization are
+/// in-bounds the way the outputs of the delinearization would be.
+struct CancelDelinearizeOfLinearizeDisjointExact
+ : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto linearizeOp = delinearizeOp.getLinearIndex()
+ .getDefiningOp<affine::AffineLinearizeIndexOp>();
+ if (!linearizeOp)
+ return failure();
+
+ if (!linearizeOp.getDisjoint() ||
+ linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+ return failure();
+
+ rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+ return success();
+ }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+ patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis,
+ CancelDelinearizeOfLinearizeDisjointExact>(context);
}
//===----------------------------------------------------------------------===//
@@ -4751,12 +4779,45 @@ struct DropLinearizeOneBasisElement final
return success();
}
};
+
+/// Cancel out linearize_index(delinearize_index(x, B), B).
+///
+/// That is, rewrite
+/// ```
+/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
+/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
+/// %bN)
+/// ```
+/// to replacing `%y` with `%x`.
+struct CancelLinearizeOfDelinearizeExact final
+ : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto delinearizeOp = linearizeOp.getMultiIndex()
+ .front()
+ .getDefiningOp<affine::AffineDelinearizeIndexOp>();
+ if (!delinearizeOp)
+ return failure();
+
+ if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+ return failure();
+
+ if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ return failure();
+
+ rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+ return success();
+ }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
- DropLinearizeOneBasisElement>(context);
+ DropLinearizeOneBasisElement, CancelLinearizeOfDelinearizeExact>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..99c115ba782c01 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1526,6 +1526,59 @@ func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index,
// -----
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK: return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+ %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 4, %[[ARG4]])
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_exact(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+ %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+ %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_delinearize_linearize_different_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (%[[ARG3]], 4, %[[ARG4]])
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (%[[ARG3]], 8, %[[ARG4]])
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+ %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+ %1:3 = affine.delinearize_index %0 into (%arg3, 8, %arg4)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: func @delinearize_non_loop_like
// CHECK-NOT: affine.delinearize
func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index {
@@ -1577,3 +1630,46 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
%ret = affine.linearize_index [%arg0] by (%arg1) : index
return %ret : index
}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: return %[[ARG0]]
+func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: return %[[LIN]]
+func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
+// CHECK: return %[[LIN]]
+func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
+ return %1 : index
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Maybe leave some time for folks to weigh in if they want (like a day).
} // namespace | ||
|
||
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns( | ||
RewritePatternSet &patterns, MLIRContext *context) { | ||
patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero, | ||
DropLinearizeOneBasisElement>(context); | ||
DropLinearizeOneBasisElement, CancelLinearizeOfDelinearizeExact>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Please maintain these alphabetically
if (!delinearizeOp) | ||
return failure(); | ||
|
||
if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. Better to return a notifyMatchFailure
to know why it didnt match.
If we linearize values (with an assertion tha they are disjoint) and then delinearize that linear index with th exact same basis, we know that these operations are exact inverses of each other and can be replaced with the original inputs to the linearization.
Similarly, if we take a linear index, delinearize it with some bases, and then re-linearize it with that same basis (noting that the outputs of the delinearization are guaranteed to by
disjoint
, even if this is not asserted on the linearize_index operation), the re-linearization is the inverse of the delinearization, so those two operations can also be canceled out.This commit adds canonicalization patterns for these simple cancelations.