Skip to content

Commit 55f2547

Browse files
authored
[mlir][math] Rsqrt math expand pass expects static shaped operand (#129006)
Similar to the issue reported in #128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation. See: #128299
1 parent 746d8b0 commit 55f2547

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
646646

647647
auto operand = op.getOperand();
648648
auto operandTy = operand.getType();
649+
// Operand type must be shatic shaped type to create const float.
650+
auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
651+
if (shapedOperandType && !shapedOperandType.hasStaticShape())
652+
return failure();
653+
649654
auto eTy = getElementTypeOrSelf(operandTy);
650655
if (!isa<FloatType>(eTy))
651656
return failure();

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
787787
%a = math.ceil %arg : tensor<*xf32>
788788
return %a: tensor<*xf32>
789789
}
790+
791+
// -----
792+
793+
// CHECK-LABEL: func.func @non_static_shape_rsqrt_op
794+
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>)
795+
// CHECK-SAME: -> tensor<?xf32>
796+
// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<?xf32>
797+
// CHECK: return %[[RSQRT]] : tensor<?xf32>
798+
799+
func.func @non_static_shape_rsqrt_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
800+
%a = math.rsqrt %arg : tensor<?xf32>
801+
return %a: tensor<?xf32>
802+
}
803+
804+
// -----
805+
806+
// CHECK-LABEL: func.func @unranked_rsqrt_op
807+
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
808+
// CHECK-SAME: -> tensor<*xf32>
809+
// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32>
810+
// CHECK: return %[[RSQRT]] : tensor<*xf32>
811+
812+
func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
813+
%a = math.rsqrt %arg : tensor<*xf32>
814+
return %a: tensor<*xf32>
815+
}

0 commit comments

Comments
 (0)