[mlir][tensor] remove tensor.insert constant folding out of canonicalization #142671

Merged 1 commit into llvm:main from the tensor-canon-fixes branch on Jun 5, 2025

Conversation

asraa
Contributor

@asraa asraa commented Jun 3, 2025

Follow-ups from #142458, in particular the concern that indiscriminately folding tensor constants can bloat the IR, since these constants can be arbitrarily large.

@llvmbot
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (asraa)

Changes

Follow ups from #142458


Full diff: https://github.com/llvm/llvm-project/pull/142671.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/Tensor.h (+3)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3-4)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (-16)
  • (added) mlir/test/Dialect/Tensor/insert-after-constant.mlir (+14)
  • (modified) mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (+13)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index e8e1342ef36fd..447f5b906cad1 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -176,6 +176,9 @@ void populateFoldConstantExtractSlicePatterns(
           return false;
         });
 
+/// Patterns to fold inserts into a constant into a new constant.
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns);
+
 /// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
 /// source tensor.
 void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c0885a3763827..35d0b16628417 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 12e8b257ce9f1..c051ef4ae6bb0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1653,7 +1653,7 @@ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
       return failure();
 
     // Pattern requires constant indices
-    SmallVector<uint64_t, 8> indices;
+    SmallVector<uint64_t> indices;
     for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
       auto indiceAttr = dyn_cast<Attribute>(indice);
       if (!indiceAttr)
@@ -1717,9 +1717,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
-void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<InsertOpConstantFold>(context);
+void populateFoldInsertAfterConstant(RewritePatternSet &patterns) {
+  patterns.add<InsertOpConstantFold>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..f033a43c0dc24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
   return %ins_1 : tensor<4xf32>
 }
 
-
-// -----
-
-func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
-  // Fold an insert into a splat.
-  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
-  // CHECK-LITERAL:
-  // CHECK-NEXT: return %[[C4]]
-  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c4_i32 = arith.constant 4 : i32
-  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
-  return %inserted : tensor<2x2xi32>
-}
-
 // -----
 
 // CHECK-LABEL: func @extract_from_tensor.cast
diff --git a/mlir/test/Dialect/Tensor/insert-after-constant.mlir b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
new file mode 100644
index 0000000000000..73f49ac6eba78
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/insert-after-constant.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-insert-after-constant %s | FileCheck %s
+
+func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
+  // Fold an insert into a splat.
+  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
+  // CHECK-LITERAL:
+  // CHECK-NEXT: return %[[C4]]
+  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4_i32 = arith.constant 4 : i32
+  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
+  return %inserted : tensor<2x2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 0e191c32f009e..f2750c5e9a0de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -82,6 +82,11 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of extract from collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testFoldInsertAfterConstant{
+      *this, "test-fold-insert-after-constant",
+      llvm::cl::desc("Test folding of insert of a constant"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -143,6 +148,12 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldInsertAfterConstantPattern(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldInsertAfterConstant(patterns);
+  (void)applyPatternsGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -393,6 +404,8 @@ void TestTensorTransforms::runOnOperation() {
   }
   if (testFoldExtractFromCollapseShape)
     applyFoldExtractFromCollapseShapePatterns(rootOp);
+  if (testFoldInsertAfterConstant)
+    applyFoldInsertAfterConstantPattern(rootOp);
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();
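
For readers who want to see what the relocated pattern computes, here is a minimal, self-contained C++ sketch of the arithmetic exercised by the insert-after-constant.mlir test above. It does not use any MLIR API: `linearize` and `foldInsertIntoConstant` are hypothetical names introduced only for illustration, and the row-major element order is an assumption of this model rather than something stated in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an MLIR function): row-major linear offset of
// `indices` within a statically shaped tensor of shape `shape`.
std::size_t linearize(const std::vector<std::int64_t> &shape,
                      const std::vector<std::uint64_t> &indices) {
  assert(shape.size() == indices.size() && "rank mismatch");
  std::size_t offset = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    assert(indices[d] < static_cast<std::uint64_t>(shape[d]) &&
           "index out of bounds");
    offset = offset * static_cast<std::size_t>(shape[d]) +
             static_cast<std::size_t>(indices[d]);
  }
  return offset;
}

// Hypothetical model of the fold: copy the constant's elements and overwrite
// the single element addressed by the (constant) indices.
std::vector<std::int32_t>
foldInsertIntoConstant(std::vector<std::int32_t> elements,
                       const std::vector<std::int64_t> &shape,
                       const std::vector<std::uint64_t> &indices,
                       std::int32_t scalar) {
  elements[linearize(shape, indices)] = scalar;
  return elements;
}
```

Applied to the test input, the elements `{1, 2, 3, 4}` with shape `{2, 2}`, indices `{1, 0}` and scalar `4` give `{1, 2, 4, 4}`, which corresponds to the `dense<[[1, 2], [4, 4]]>` constant that the CHECK line expects. Note that the model copies every element of the constant, which is the cost the review discussion is concerned about for large tensors.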

@asraa asraa force-pushed the tensor-canon-fixes branch from b67677f to a82d723 Compare June 3, 2025 21:49
@asraa asraa force-pushed the tensor-canon-fixes branch from a82d723 to 50dee1f Compare June 3, 2025 22:07
@asraa
Contributor Author

asraa commented Jun 3, 2025

cc @joker-eph: this removes the pattern from canonicalization and addresses the comments from that PR.


github-actions bot commented Jun 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@joker-eph joker-eph changed the title [mlir][tensor] move tensor insert canonicalization to pattern [mlir][tensor] move tensor.insert constant folding out of canonicalization Jun 4, 2025
Signed-off-by: Asra Ali <[email protected]>

fix

Signed-off-by: Asra Ali <[email protected]>

fix loop

Signed-off-by: Asra Ali <[email protected]>

remove pattern

Signed-off-by: Asra Ali <[email protected]>
@asraa asraa force-pushed the tensor-canon-fixes branch from 50dee1f to 8720026 Compare June 5, 2025 15:27
@asraa asraa changed the title [mlir][tensor] move tensor.insert constant folding out of canonicalization [mlir][tensor] remove tensor.insert constant folding out of canonicalization Jun 5, 2025
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Sorry to add to the trouble here. While I am all for removing things from canonicalization, this is a regression in functionality. Downstream users who might be implicitly relying on this pattern (because it is a canonicalization, it kicks in whether you ask for it or not) have no way to recover other than copying it into their downstream projects.

If we haven't decided on a place to put these patterns (maybe a place where all the constant folding patterns get added), we should probably just leave things as they are.

@j2kun
Contributor

j2kun commented Jun 5, 2025

@MaheshRavishankar for what it's worth, this pattern was added only 3 days ago (#142458), and so I don't expect anyone is depending on it yet. If anything, it would have broken downstream projects and this would revert the breakage :) That PR was merged before Mehdi had a chance to comment on it, so this is a fix-forward means to address those comments.

@j2kun
Contributor

j2kun commented Jun 5, 2025

I will give it another hour or so for people to raise any further objections, then merge before the end of my working day.

@MaheshRavishankar MaheshRavishankar dismissed their stale review June 5, 2025 21:20

My concern was addressed. Thanks!

Collaborator

@joker-eph joker-eph left a comment

Thanks!

Copying my comment from #142671 (comment)

It's probably safer to remove the pattern for now, and maybe think about introducing a "tensor-constant-folding" pass separately. I suspect such a pass would want options such as a maximum size for folding a constant, or some other way to choose the folding strategy.
We can also think about having these patterns in canonicalization, but limited to "small" constants (the difficulty being how to define what is small enough to be universally safe).
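
As a rough illustration of the "maximum size" option mentioned above, the gating decision could be modelled as below. This is a standalone C++ sketch under stated assumptions, not an upstream design: `FoldOptions`, `maxElements`, `numElements`, `shouldFoldConstant` and the default of 1024 elements are all hypothetical, and a real pass would presumably expose the limit as a pass option.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical options for a size-aware constant-folding pass.
struct FoldOptions {
  // Illustrative default only; choosing a universally safe value is exactly
  // the open question raised in the review.
  std::size_t maxElements = 1024;
};

// Number of elements in a statically shaped tensor (1 for rank 0).
std::size_t numElements(const std::vector<std::int64_t> &shape) {
  std::size_t count = 1;
  for (std::int64_t dim : shape) {
    assert(dim >= 0 && "this sketch only handles static shapes");
    count *= static_cast<std::size_t>(dim);
  }
  return count;
}

// Returns true if a constant of this shape is small enough to fold under the
// given options; larger constants are left untouched.
bool shouldFoldConstant(const std::vector<std::int64_t> &shape,
                        const FoldOptions &options) {
  return numElements(shape) <= options.maxElements;
}
```

With the illustrative default, a `2x2` constant would be folded while a `1024x1024` constant would be skipped unless the caller raises `maxElements`. Counting elements rather than bytes is a simplification; a byte-based limit that accounts for the element type would be an equally plausible choice.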

@j2kun
Contributor

j2kun commented Jun 5, 2025

Once we get all our patterns working out of tree, I will try adding this pass upstream with some sensible options

@j2kun j2kun merged commit c66b72f into llvm:main Jun 5, 2025
11 checks passed
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
…ization (llvm#142671)

Follow ups from llvm#142458
In particular concerns that indiscriminately folding tensor constants
can lead to bloating the IR as these can be arbitrarily large.

Signed-off-by: Asra Ali <[email protected]>
DhruvSrivastavaX pushed a commit to DhruvSrivastavaX/lldb-for-aix that referenced this pull request Jun 12, 2025
…ization (llvm#142671)

Follow ups from llvm#142458
In particular concerns that indiscriminately folding tensor constants
can lead to bloating the IR as these can be arbitrarily large.

Signed-off-by: Asra Ali <[email protected]>