Skip to content

Commit e124e18

Browse files
committed
[MLIR] fix shape.broadcast canonicalization when all shape operands are empty
1 parent 398f3b3 commit e124e18

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,12 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
699699
isPotentiallyNonEmptyShape);
700700

701701
// Reduce op to equivalent without empty shape operands.
702+
if (newOperands.empty()) {
703+
rewriter.replaceOpWithNewOp<ConstShapeOp>(
704+
op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
705+
return success();
706+
}
707+
702708
if (newOperands.size() < op.getNumOperands()) {
703709
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
704710
op->getAttrs());

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s
2+
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=1" %s | FileCheck %s
23

34
// CHECK-LABEL: func @f
45
func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
@@ -134,6 +135,21 @@ func.func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
134135

135136
// -----
136137

138+
// All operands are known empty shapes.
139+
// CHECK-LABEL: @all_empty
140+
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<i1>)
141+
func.func @all_empty(%arg0: tensor<f32>, %arg1: tensor<i1>) -> tensor<0xindex> {
142+
// CHECK: %[[CST:.*]] = shape.const_shape [] : tensor<0xindex>
143+
// CHECK: return %[[CST]] : tensor<0xindex>
144+
%1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
145+
%2 = shape.shape_of %arg1 : tensor<i1> -> tensor<0xindex>
146+
%3 = shape.const_shape [] : tensor<0xindex>
147+
%4 = shape.broadcast %1, %2, %3 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
148+
return %4 : tensor<0xindex>
149+
}
150+
151+
// -----
152+
137153
// Partial folding.
138154
// CHECK-LABEL: @partial_folding
139155
// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)

0 commit comments

Comments (0)