Skip to content

Commit b1a735b

Browse files
authored
[mlir][math] expand-math pass assumes the static shaped type (#128299)
In the process of `expand-math` pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type. Fixes #128275
1 parent e350485 commit b1a735b

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
222222
// if (x > y) then incr = 1 else incr = 0
223223
// y = y + incr <= replace this op with the ceilf op.
224224
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
225+
// Creating constants assumes the static shaped type.
226+
auto shapedType = dyn_cast<ShapedType>(op.getType());
227+
if (shapedType && !shapedType.hasStaticShape())
228+
return failure();
229+
225230
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
226231
Value operand = op.getOperand();
227232
Type opType = operand.getType();

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,3 +761,29 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
761761
%float_result = math.rsqrt %float : tensor<5x8xf32>
762762
return %float_result : tensor<5x8xf32>
763763
}
764+
765+
// -----
766+
767+
// CHECK-LABEL: func.func @non_static_shape_ceil_op
768+
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>)
769+
// CHECK-SAME: -> tensor<?xf32>
770+
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<?xf32>
771+
// CHECK: return %[[CEIL]] : tensor<?xf32>
772+
773+
func.func @non_static_shape_ceil_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
774+
%a = math.ceil %arg : tensor<?xf32>
775+
return %a: tensor<?xf32>
776+
}
777+
778+
// -----
779+
780+
// CHECK-LABEL: func.func @unranked_ceil_op
781+
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
782+
// CHECK-SAME: -> tensor<*xf32>
783+
// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<*xf32>
784+
// CHECK: return %[[CEIL]] : tensor<*xf32>
785+
786+
func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
787+
%a = math.ceil %arg : tensor<*xf32>
788+
return %a: tensor<*xf32>
789+
}

0 commit comments

Comments
 (0)