Skip to content

[mlir][Vector] Add constant folding for vector.from_elements operation #145849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

yangtetris
Copy link
Contributor

Summary

This PR adds a new folding pattern for vector.from_elements that canonicalizes it to arith.constant when all input operands are constants.

Implementation Details

Leverages FoldAdaptor capabilities: Uses adaptor.getElements() to access pre-computed constant attributes, avoiding redundant pattern matching on operands.

Example Transformation

Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>

@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Yang Bai (yangtetris)

Changes

Summary

This PR adds a new folding pattern for vector.from_elements that canonicalizes it to arith.constant when all input operands are constants.

Implementation Details

Leverages FoldAdaptor capabilities: Uses adaptor.getElements() to access pre-computed constant attributes, avoiding redundant pattern matching on operands.

Example Transformation

Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector&lt;2x2xi32&gt;

After:
%v = arith.constant dense&lt;[[0, 1], [2, 3]]&gt; : vector&lt;2x2xi32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/145849.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+///   %c1 = arith.constant 1 : i32
+///   %c2 = arith.constant 2 : i32
+///   %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+///   %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nicely implemented, well done! LGTM.

@yangtetris
Copy link
Contributor Author

I found that the following MLIR code causes a crash:

%c1 = llvm.mlir.constant(1 : index) : i64
%c2 = llvm.mlir.constant(2 : index) : i64
%res = vector.from_elements %c1, %c2 : vector<2xi64>

The error is:
llvm-project/mlir/lib/IR/BuiltinAttributes.cpp:978: static mlir::DenseElementsAttr mlir::DenseElementsAttr::get(mlir::ShapedType, llvm::ArrayRefmlir::Attribute): Assertion `intAttr.getType() == eltType && "expected integer attribute type to equal element type"' failed.

It seems that this issue is the same as this one. So I committed a change that applies convertIntegerAttr before creating the dense elements attribute.

@Groverkss could you please take another look? Thanks!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG! Dropped a few comments, thanks!

}

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to add a test that crashes without using llvm.mlir.constant ops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of difficult for me as I don't know which ops with the constant trait(besides llvm.mlir.constant) allow mismatched attribute and return types. And, I don't think it's worth introducing a new custom op for this test :(

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants