-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Fixes arith.sub folder crash on dynamically shaped tensors #118908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesWe can't create a constant for a value with dynamic shape. Fixes #118772 Full diff: https://github.com/llvm/llvm-project/pull/118908.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f445231b80fdf..5a79add16569f4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
// subi(x,x) -> 0
- if (getOperand(0) == getOperand(1))
- return Builder(getContext()).getZeroAttr(getType());
+ if (getOperand(0) == getOperand(1)) {
+ auto tensorType = dyn_cast<TensorType>(getType());
+ // We can't generate a constant with a dynamic shaped tensor.
+ if (!tensorType || tensorType.hasStaticShape())
+ return Builder(getContext()).getZeroAttr(getType());
+ }
// subi(x,0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 69df83d42f543e..8c22506b9afcf1 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -869,6 +869,17 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
return %add2 : index
}
+
+// CHECK-LABEL: @foldSubXX
+// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[sub:.+]] = arith.subi
+// CHECK: return %[[c0]], %[[sub]]
+func.func @foldSubXX(%dyn : tensor<?x?xi32>, %static : tensor<10xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
+ %static_sub = arith.subi %static, %static : tensor<10xi32>
+ %dyn_sub = arith.subi %dyn, %dyn : tensor<?x?xi32>
+ return %static_sub, %dyn_sub : tensor<10xi32>, tensor<?x?xi32>
+}
+
// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
|
if (getOperand(0) == getOperand(1)) | ||
return Builder(getContext()).getZeroAttr(getType()); | ||
if (getOperand(0) == getOperand(1)) { | ||
auto tensorType = dyn_cast<TensorType>(getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe check for ShapedType
instead? In case we decide to introduce dynamically-sized vectors one day.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, we already kinda have them with scalabale vectors. Can you add test for them too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point for using ShapedType.
The scalable vectors aren't an issue I believe since you can create a constant. They act as if they are statically shaped from this point of view.
469ab15
to
02793a7
Compare
We can't create a constant for a value with dynamic shape. Fixes llvm#118772
02793a7
to
57bab05
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
We can't create a constant for a value with dynamic shape.
Fixes #118772