-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Affine] Split off delinearize parts that depend on last component #117015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Affine] Split off delinearize parts that depend on last component #117015
Conversation
If we have `%0 = affine.linearize_index disjoint [%a, %b] by (A, B)` followed by `%1:3 = affine.delinearize_index %0 into (A, B1, B2)`, where B = B1 * B2 (or some more complex product), we can simplify this to `%0 = affine.linearize_index disjoint [%a] by (A)`, `%1a:1 = affine.delinearize_index %0 into (A)`, and `%1b:2 = affine.delinearize_index %b into (B1, B2)`. This, and more complex cases, prevent us from adding terms together only to divide them away from each other.
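@llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: If we have `%0 = affine.linearize_index disjoint [%a, %b] by (A, B)` followed by `%1:3 = affine.delinearize_index %0 into (A, B1, B2)`, where B = B1 * B2 (or some more complex product), we can simplify this to `%0 = affine.linearize_index disjoint [%a] by (A)`, `%1a:1 = affine.delinearize_index %0 into (A)`, and `%1b:2 = affine.delinearize_index %b into (B1, B2)`. This, and more complex cases, prevent us from adding terms together only to divide them away from each other. Full diff: https://github.com/llvm/llvm-project/pull/117015.diff 2 Files Affected: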
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..b13331abc32ada 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
return success();
}
};
+
+/// If the input to a delinearization is a disjoint linearization whose last
+/// basis element is static, and the last k > 1 (static) components of the
+/// delinearization basis multiply to exactly that element, break the pair into
+/// two parts, peeling off the last input to the linearization.
+///
+/// For example:
+/// %0 = affine.linearize_index disjoint [%z, %y, %x] by (3, 2, 32) : index
+/// %1:4 = affine.delinearize_index %0 into (2, 3, 8, 4) : index, ...
+/// becomes
+/// %0 = affine.linearize_index disjoint [%z, %y] by (3, 2) : index
+/// %1:2 = affine.delinearize_index %0 into (2, 3) : index, index
+/// %2:2 = affine.delinearize_index %x into (8, 4) : index, index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+ : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto linearizeOp = delinearizeOp.getLinearIndex()
+ .getDefiningOp<affine::AffineLinearizeIndexOp>();
+ if (!linearizeOp)
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "index doesn't come from linearize");
+
+ if (!linearizeOp.getDisjoint()) // only linearizations marked disjoint are rewritten
+ return rewriter.notifyMatchFailure(linearizeOp,
+ "linearize isn't disjoint");
+
+ int64_t target = linearizeOp.getStaticBasis().back(); // NOTE(review): back() assumes a non-empty basis - confirm the op verifier guarantees it
+ if (ShapedType::isDynamic(target))
+ return rewriter.notifyMatchFailure(
+ linearizeOp, "linearize ends with dynamic basis value");
+
+ int64_t sizeToSplit = 1; // running product of the trailing delinearize basis elements
+ size_t elemsToSplit = 0; // number of trailing basis elements folded into sizeToSplit
+ ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+ for (int64_t basisElem : llvm::reverse(basis)) { // walk the basis from its last element backwards
+ if (ShapedType::isDynamic(basisElem))
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "dynamic basis element while scanning for split");
+ sizeToSplit *= basisElem;
+ elemsToSplit += 1;
+
+ if (sizeToSplit > target) // product skipped past target without hitting it exactly
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "overshot last argument size");
+ if (sizeToSplit == target)
+ break;
+ }
+
+ if (sizeToSplit < target) // ran out of basis elements before the product reached target
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "product of known basis elements doesn't exceed last "
+ "linearize argument");
+
+ if (elemsToSplit < 2) // a single element equal to target leaves nothing to split
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "don't have a non-trivial basis product");
+
+ Value linearizeWithoutBack = // the linearization minus its last index and last (static) basis element
+ rewriter.create<affine::AffineLinearizeIndexOp>(
+ linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+ linearizeOp.getDynamicBasis(),
+ linearizeOp.getStaticBasis().drop_back(),
+ linearizeOp.getDisjoint());
+ auto delinearizeWithoutSplitPart = // delinearizes the shortened linearization by the leading, un-split basis elements
+ rewriter.create<affine::AffineDelinearizeIndexOp>(
+ delinearizeOp.getLoc(), linearizeWithoutBack,
+ delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+ delinearizeOp.hasOuterBound());
+ auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>( // delinearizes the peeled-off last linearize input by the trailing elements
+ delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+ basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+ SmallVector<Value> results = llvm::to_vector( // leading results first, then the peeled-off results
+ llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+ delinearizeBack.getResults()));
+ rewriter.replaceOp(delinearizeOp, results);
+
+ return success();
+ }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns
- .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
- context);
+ .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
+ SplitDelinearizeSpanningLastLinearizeArg>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..efeea7eb2af530 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
// -----
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't match the last basis element before
+// overshooting it, don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The whole delinearize basis (4 * 8 = 32) multiplies to less than the final linearize basis element (64), so nothing is split.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+ %1:3 = affine.delinearize_index %0 into (4, 8)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
|
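@llvm/pr-subscribers-mlir-affine Author: Krzysztof Drewniak (krzysz00) Changes: If we have `%0 = affine.linearize_index disjoint [%a, %b] by (A, B)` followed by `%1:3 = affine.delinearize_index %0 into (A, B1, B2)`, where B = B1 * B2 (or some more complex product), we can simplify this to `%0 = affine.linearize_index disjoint [%a] by (A)`, `%1a:1 = affine.delinearize_index %0 into (A)`, and `%1b:2 = affine.delinearize_index %b into (B1, B2)`. This, and more complex cases, prevent us from adding terms together only to divide them away from each other. Full diff: https://github.com/llvm/llvm-project/pull/117015.diff 2 Files Affected: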
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..b13331abc32ada 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4694,13 +4694,96 @@ struct CancelDelinearizeOfLinearizeDisjointExact
return success();
}
};
+
+/// If the input to a delinearization is a disjoint linearization whose last
+/// basis element is static, and the last k > 1 (static) components of the
+/// delinearization basis multiply to exactly that element, break the pair into
+/// two parts, peeling off the last input to the linearization.
+///
+/// For example:
+/// %0 = affine.linearize_index disjoint [%z, %y, %x] by (3, 2, 32) : index
+/// %1:4 = affine.delinearize_index %0 into (2, 3, 8, 4) : index, ...
+/// becomes
+/// %0 = affine.linearize_index disjoint [%z, %y] by (3, 2) : index
+/// %1:2 = affine.delinearize_index %0 into (2, 3) : index, index
+/// %2:2 = affine.delinearize_index %x into (8, 4) : index, index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+ : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto linearizeOp = delinearizeOp.getLinearIndex()
+ .getDefiningOp<affine::AffineLinearizeIndexOp>();
+ if (!linearizeOp)
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "index doesn't come from linearize");
+
+ if (!linearizeOp.getDisjoint()) // only linearizations marked disjoint are rewritten
+ return rewriter.notifyMatchFailure(linearizeOp,
+ "linearize isn't disjoint");
+
+ int64_t target = linearizeOp.getStaticBasis().back(); // NOTE(review): back() assumes a non-empty basis - confirm the op verifier guarantees it
+ if (ShapedType::isDynamic(target))
+ return rewriter.notifyMatchFailure(
+ linearizeOp, "linearize ends with dynamic basis value");
+
+ int64_t sizeToSplit = 1; // running product of the trailing delinearize basis elements
+ size_t elemsToSplit = 0; // number of trailing basis elements folded into sizeToSplit
+ ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+ for (int64_t basisElem : llvm::reverse(basis)) { // walk the basis from its last element backwards
+ if (ShapedType::isDynamic(basisElem))
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "dynamic basis element while scanning for split");
+ sizeToSplit *= basisElem;
+ elemsToSplit += 1;
+
+ if (sizeToSplit > target) // product skipped past target without hitting it exactly
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "overshot last argument size");
+ if (sizeToSplit == target)
+ break;
+ }
+
+ if (sizeToSplit < target) // ran out of basis elements before the product reached target
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "product of known basis elements doesn't exceed last "
+ "linearize argument");
+
+ if (elemsToSplit < 2) // a single element equal to target leaves nothing to split
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "don't have a non-trivial basis product");
+
+ Value linearizeWithoutBack = // the linearization minus its last index and last (static) basis element
+ rewriter.create<affine::AffineLinearizeIndexOp>(
+ linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+ linearizeOp.getDynamicBasis(),
+ linearizeOp.getStaticBasis().drop_back(),
+ linearizeOp.getDisjoint());
+ auto delinearizeWithoutSplitPart = // delinearizes the shortened linearization by the leading, un-split basis elements
+ rewriter.create<affine::AffineDelinearizeIndexOp>(
+ delinearizeOp.getLoc(), linearizeWithoutBack,
+ delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+ delinearizeOp.hasOuterBound());
+ auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>( // delinearizes the peeled-off last linearize input by the trailing elements
+ delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+ basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+ SmallVector<Value> results = llvm::to_vector( // leading results first, then the peeled-off results
+ llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+ delinearizeBack.getResults()));
+ rewriter.replaceOp(delinearizeOp, results);
+
+ return success();
+ }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns
- .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
- context);
+ .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis,
+ SplitDelinearizeSpanningLastLinearizeArg>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..efeea7eb2af530 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1777,6 +1777,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
// -----
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+// CHECK: %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+// CHECK: %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+// CHECK: return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't match the last basis element before
+// overshooting it, don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+ %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+ : index, index, index, index
+ return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The whole delinearize basis (4 * 8 = 32) multiplies to less than the final linearize basis element (64), so nothing is split.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+ %1:3 = affine.delinearize_index %0 into (4, 8)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// CHECK-LABEL: @linearize_unit_basis_disjoint
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing PR @krzysz00 ! Approving with two requests! Thanks!
if (elemsToSplit < 2) | ||
return rewriter.notifyMatchFailure( | ||
delinearizeOp, "don't have a non-trivial basis product"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, maybe add a lit test to test this path as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's implicit in an existing test that permutes but I'll go add another one if it isn't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... Ok, yeah, the case where the final components agree exactly gets sent down a different pattern, is what's going on
Co-authored-by: Abhishek Varma <[email protected]>
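If we have `%0 = affine.linearize_index disjoint [%a, %b] by (A, B)` followed by `%1:3 = affine.delinearize_index %0 into (A, B1, B2)`,
where B = B1 * B2 (or some more complex product), we can simplify this to `%0 = affine.linearize_index disjoint [%a] by (A)`, `%1a:1 = affine.delinearize_index %0 into (A)`, and `%1b:2 = affine.delinearize_index %b into (B1, B2)`.
This, and more complex cases, prevent us from adding terms together only to divide them away from each other.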