Skip to content

[mlir][math] expand-math pass assumes the static shaped type #128299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

Lewuathe
Copy link
Member

In the process of expand-math pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type.

Fixes #128275

In the process of `expand-math` pass, the conversion of ceil op assumes
the static shaped type as input as it needs create 0 and 1 constant
values whose type is aligned with the op type.

Fixes llvm#128275
@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2025

@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Kai Sasaki (Lewuathe)

Changes

In the process of expand-math pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type.

Fixes #128275


Full diff: https://github.com/llvm/llvm-project/pull/128299.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+5)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+22)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 23356d752146d..67e8dbba989b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
 //      if (x > y) then incr = 1 else incr = 0
 //      y = y + incr   <= replace this op with the ceilf op.
 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  // Creating constants assumes the statis shaped type.
+  auto shapedType = dyn_cast<ShapedType>(op.getType());
+  if (shapedType && !shapedType.hasStaticShape())
+    return failure();
+
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1fdfb854325b4..4e249ec510afa 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
   %float_result = math.rsqrt %float : tensor<5x8xf32>
   return %float_result : tensor<5x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL  func.func @non_static_shape_ceil_op
+// CHECK:     %[[IDX:.*]] = index.constant 0
+// CHECK:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
+// CHECK:     %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
+// CHECK:     %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
+// CHECK:     %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
+// CHECK:     vector.print %[[DIM]] : index
+// CHECK:         return
+
+func.func @non_static_shape_ceil_op() {
+  %idx0 = index.constant 0
+  %cst_90 = arith.constant 1.000000e+00 : f32
+  %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
+  %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
+  %112 = math.ceil %cast_93 : tensor<?xf32>
+  %dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
+  vector.print %dim_233 : index
+  return
+}

@Lewuathe Lewuathe force-pushed the convert-ceil-op-assume-static-shaped-type branch from e01b883 to a436c01 Compare February 23, 2025 05:39
Copy link
Contributor

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! You may likewise also want to test unranked types:

func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
  %a = math.ceil %arg : tensor<*xf32>
  return %a: tensor<*xf32>
}

Also I see a constant created in rsqrt and we don't check for static types. I guess the same issue would apply there?

@Lewuathe
Copy link
Member Author

Also I see a constant created in rsqrt and we don't check for static types. I guess the same issue would apply there?

Thanks for reporting that part. Let me take a look into the issue and will work on in another pull request if necessary.

@Lewuathe Lewuathe requested a review from cferry-AMD February 25, 2025 00:59
Copy link
Contributor

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for this fix!

@Lewuathe Lewuathe merged commit b1a735b into llvm:main Feb 26, 2025
11 checks passed
@Lewuathe Lewuathe deleted the convert-ceil-op-assume-static-shaped-type branch February 26, 2025 01:30
Lewuathe added a commit that referenced this pull request Feb 28, 2025
…29006)

Similar to the issue reported in

#128299 (review),
ExpandMath pattern for rsqrt expects the static shaped operands.
Otherwise, it crashes due to the assertion violation.

See: #128299
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 28, 2025
…operand (#129006)

Similar to the issue reported in

llvm/llvm-project#128299 (review),
ExpandMath pattern for rsqrt expects the static shaped operands.
Otherwise, it crashes due to the assertion violation.

See: llvm/llvm-project#128299
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
…vm#129006)

Similar to the issue reported in

llvm#128299 (review),
ExpandMath pattern for rsqrt expects the static shaped operands.
Otherwise, it crashes due to the assertion violation.

See: llvm#128299
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Crash when using --test-expand-math
5 participants