Skip to content

[mlir][math] Rsqrt math expand pass expects static shaped operand #129006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 28, 2025

Conversation

Lewuathe
Copy link
Member

Similar to the issue reported in
#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation.

See: #128299

@llvmbot
Copy link
Member

llvmbot commented Feb 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Kai Sasaki (Lewuathe)

Changes

Similar to the issue reported in
#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation.

See: #128299


Full diff: https://github.com/llvm/llvm-project/pull/129006.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+5)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+26)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bb592c667549c..7b5350ca26b60 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
 
   auto operand = op.getOperand();
   auto operandTy = operand.getType();
+  // Operand type must be shatic shaped type to create const float.
+  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
+  if (shapedOperandType && !shapedOperandType.hasStaticShape())
+    return failure();
+
   auto eTy = getElementTypeOrSelf(operandTy);
   if (!isa<FloatType>(eTy))
     return failure();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 946a411e4cc4b..8743efec5ecb4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
   %a = math.ceil %arg : tensor<*xf32>
   return %a: tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @non_static_shape_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME:     -> tensor<?xf32>
+// CHECK:          %[[CEIL:.*]] = math.rsqrt %[[ARG]] : tensor<?xf32>
+// CHECK:          return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_rsqrt_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+  %a = math.rsqrt %arg : tensor<?xf32>
+  return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:    func.func @unranked_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME:     -> tensor<*xf32>
+// CHECK:          %[[CEIL:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32>
+// CHECK:          return %[[CEIL]] : tensor<*xf32>
+
+func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+  %a = math.rsqrt %arg : tensor<*xf32>
+  return %a: tensor<*xf32>
+}

Similar to the issue reported in
https://github.com/llvm/llvm-project/pull/128299/files, ExpandMath
pattern for rsqrt expects the static shaped operands. Otherwise, it
crashes due to the assertion violation.
@Lewuathe Lewuathe force-pushed the rsqrt-in-expand-math-expect-static-shape branch from 4d899de to d01daf2 Compare February 27, 2025 06:02
Copy link
Contributor

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was fast. Thanks once more!

@Lewuathe Lewuathe merged commit 55f2547 into llvm:main Feb 28, 2025
11 checks passed
@Lewuathe Lewuathe deleted the rsqrt-in-expand-math-expect-static-shape branch February 28, 2025 04:37
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
…vm#129006)

Similar to the issue reported in

llvm#128299 (review),
ExpandMath pattern for rsqrt expects the static shaped operands.
Otherwise, it crashes due to the assertion violation.

See: llvm#128299
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants