-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] remove tensor.insert constant folding out of canonicalization #142671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: None (asraa) ChangesFollow ups from #142458 Full diff: https://github.com/llvm/llvm-project/pull/142671.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index e8e1342ef36fd..447f5b906cad1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,9 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
+/// Patterns to fold inserts into a constant into a new constant.
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns);
+
/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
/// source tensor.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
- let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..c051ef4ae6bb0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1653,7 +1653,7 @@ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
return failure();
// Pattern requires constant indices
- SmallVector<uint64_t, 8> indices;
+ SmallVector<uint64_t> indices;
for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
auto indiceAttr = dyn_cast<Attribute>(indice);
if (!indiceAttr)
@@ -1717,9 +1717,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<InsertOpConstantFold>(context);
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns) {
+ patterns.add<InsertOpConstantFold>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
- // Fold an insert into a splat.
- // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
- // CHECK-LITERAL:
- // CHECK-NEXT: return %[[C4]]
- %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4_i32 = arith.constant 4 : i32
- %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
- return %inserted : tensor<2x2xi32>
-}
-
// -----
// CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/insert-after-constant.mlir b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
new file mode 100644
index 0000000000000..73f49ac6eba78
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-insert-after-constant %s | FileCheck %s
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+ // Fold an insert into a splat.
+ // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+ // CHECK-LITERAL:
+ // CHECK-NEXT: return %[[C4]]
+ %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4_i32 = arith.constant 4 : i32
+ %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+ return %inserted : tensor<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 0e191c32f009e..f2750c5e9a0de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -82,6 +82,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of extract from collapse_shape"),
llvm::cl::init(false)};
+ Option<bool> testFoldInsertAfterConstant{
+ *this, "test-fold-insert-after-constant",
+ llvm::cl::desc("Test folding of insert of a constant"),
+ llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -143,6 +148,12 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
+static void applyFoldInsertAfterConstantPattern(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateFoldInsertAfterConstant(patterns);
+ (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -393,6 +404,8 @@ void TestTensorTransforms::runOnOperation() {
}
if (testFoldExtractFromCollapseShape)
applyFoldExtractFromCollapseShapePatterns(rootOp);
+ if (testFoldInsertAfterConstant)
+ applyFoldInsertAfterConstantPattern(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
|
cc @joker-eph this removes the pattern from canonicalization and fixes those comments from the PR |
✅ With the latest revision this PR passed the C/C++ code formatter. |
Signed-off-by: Asra Ali <[email protected]> fix Signed-off-by: Asra Ali <[email protected]> fix loop Signed-off-by: Asra Ali <[email protected]> remove pattern Signed-off-by: Asra Ali <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to add to the trouble here. While I am all for removing things from canonicalizations, this is a regression in functionality, and downstream users who might be implicitly relying on this (cause it is a canonicalization and it kicks in whether you ask for it or not) do not have a way to recover than copying it into their downstream projects.
If we havent decided on a place to put these patterns (maybe a place where all the constant folding patterns get added), we should probably just leave things as is.
@MaheshRavishankar for what it's worth, this pattern was added only 3 days ago (#142458), and so I don't expect anyone is depending on it yet. If anything, it would have broken downstream projects and this would revert the breakage :) That PR was merged before Mehdi had a chance to comment on it, so this is a fix-forward means to address those comments. |
I will give it another hour or so for people to raise any further objections, then merge before the end of my working day. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Copying my comment from #142671 (comment)
It's probably safer to remove the pattern for now, and maybe think about introducing a "tensor-constant-folding" pass separately. I suspect such pass would want options about a maximum size for folding a constant or other option for the strategy to adopt.
We can also think about having these patterns in canonicalization, but limited to "small" constants (the difficulty being how to define how much small is small enough to be universally safe).
Once we get all our patterns working out of tree, I will try adding this pass upstream with some sensible options |
…ization (llvm#142671) Follow ups from llvm#142458 In particular concerns that indiscriminately folding tensor constants can lead to bloating the IR as these can be arbitrarily large. Signed-off-by: Asra Ali <[email protected]>
…ization (llvm#142671) Follow ups from llvm#142458 In particular concerns that indiscriminately folding tensor constants can lead to bloating the IR as these can be arbitrarily large. Signed-off-by: Asra Ali <[email protected]>
Follow ups from #142458
In particular concerns that indiscriminately folding tensor constants can lead to bloating the IR as these can be arbitrarily large.