Skip to content

Commit 3624e3a

Browse files
committed
[mlir][math] expand-math pass assumes the static shaped type
In the process of `expand-math` pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type. Fixes #128275
1 parent 60f3fdd commit 3624e3a

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
222222
// if (x > y) then incr = 1 else incr = 0
223223
// y = y + incr <= replace this op with the ceilf op.
224224
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
225+
// Creating constants assumes the statis shaped type.
226+
auto shapedType = dyn_cast<ShapedType>(op.getType());
227+
if (shapedType && !shapedType.hasStaticShape())
228+
return failure();
229+
225230
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
226231
Value operand = op.getOperand();
227232
Type opType = operand.getType();

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
761761
%float_result = math.rsqrt %float : tensor<5x8xf32>
762762
return %float_result : tensor<5x8xf32>
763763
}
764+
765+
// -----
766+
767+
// CHECK-LABEL func.func @non_static_shape_ceil_op
768+
// CHECK: %[[IDX:.*]] = index.constant 0
769+
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32>
770+
// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor<?xf32>
771+
// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor<?xf32>
772+
// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor<?xf32>
773+
// CHECK: vector.print %[[DIM]] : index
774+
// CHECK: return
775+
776+
func.func @non_static_shape_ceil_op() {
777+
%idx0 = index.constant 0
778+
%cst_90 = arith.constant 1.000000e+00 : f32
779+
%from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32>
780+
%cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor<?xf32>
781+
%112 = math.ceil %cast_93 : tensor<?xf32>
782+
%dim_233 = tensor.dim %112, %idx0 : tensor<?xf32>
783+
vector.print %dim_233 : index
784+
return
785+
}

0 commit comments

Comments
 (0)