-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] expand-math pass assumes the static shaped type #128299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][math] expand-math pass assumes the static shaped type #128299
Conversation
In the process of `expand-math` pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type. Fixes llvm#128275
@llvm/pr-subscribers-mlir-cf @llvm/pr-subscribers-mlir Author: Kai Sasaki (Lewuathe) ChangesIn the process of Fixes #128275 Full diff: https://github.com/llvm/llvm-project/pull/128299.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
// if (x > y) then incr = 1 else incr = 0
// y = y + incr <= replace this op with the ceilf op.
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+ // Creating constants assumes the statis shaped type.
+ auto shapedType = dyn_cast<ShapedType>(op.getType());
+ if (shapedType && !shapedType.hasStaticShape())
+ return failure();
+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
%float_result = math.rsqrt %float : tensor<5x8xf32>
return %float_result : tensor<5x8xf32>
}
+
+// -----
+
+// CHECK-LABEL func.func @non_static_shape_ceil_op
+// CHECK: %[[IDX:.*]] = index.constant 0
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK: vector.print %[[DIM]] : index
+// CHECK: return
+
+func.func @non_static_shape_ceil_op() {
+ %idx0 = index.constant 0
+ %cst_90 = arith.constant 1.000000e+00 : f32
+ %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+ %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+ %112 = math.ceil %cast_93 : tensor<?xf32>
+ %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+ vector.print %dim_233 : index
+ return
+}
|
e01b883
to
a436c01
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! You may likewise also want to test unranked types:
func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
%a = math.ceil %arg : tensor<*xf32>
return %a: tensor<*xf32>
}
Also I see a constant created in rsqrt
and we don't check for static types. I guess the same issue would apply there?
Thanks for reporting that part. Let me take a look into the issue and will work on in another pull request if necessary. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you for this fix!
…29006) Similar to the issue reported in #128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation. See: #128299
…operand (#129006) Similar to the issue reported in llvm/llvm-project#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation. See: llvm/llvm-project#128299
…vm#129006) Similar to the issue reported in llvm#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation. See: llvm#128299
In the process of
expand-math
pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type.Fixes #128275