Skip to content

Commit 0375983

Browse files
committed
[MLIR] fix shape.broadcast canonicalization when all shape operands are empty
1 parent 6c4e70f commit 0375983

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,14 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
698698
auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
699699
isPotentiallyNonEmptyShape);
700700

701+
// Replace the op with an empty shape constant if all operands are reduced
702+
// to empty shapes.
703+
if (newOperands.empty()) {
704+
rewriter.replaceOpWithNewOp<ConstShapeOp>(
705+
op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
706+
return success();
707+
}
708+
701709
// Reduce op to equivalent without empty shape operands.
702710
if (newOperands.size() < op.getNumOperands()) {
703711
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s
2+
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=0" %s | FileCheck %s
23

34
// CHECK-LABEL: func @f
45
func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> {
@@ -134,6 +135,21 @@ func.func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
134135

135136
// -----
136137

138+
// All operands are known empty shapes.
139+
// CHECK-LABEL: @all_empty
140+
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<i1>)
141+
func.func @all_empty(%arg0: tensor<f32>, %arg1: tensor<i1>) -> tensor<0xindex> {
142+
// CHECK: %[[CST:.*]] = shape.const_shape [] : tensor<0xindex>
143+
// CHECK: return %[[CST]] : tensor<0xindex>
144+
%1 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
145+
%2 = shape.shape_of %arg1 : tensor<i1> -> tensor<0xindex>
146+
%3 = shape.const_shape [] : tensor<0xindex>
147+
%4 = shape.broadcast %1, %2, %3 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex>
148+
return %4 : tensor<0xindex>
149+
}
150+
151+
// -----
152+
137153
// Partial folding.
138154
// CHECK-LABEL: @partial_folding
139155
// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)

0 commit comments

Comments
 (0)