Skip to content

Commit b975e3b

Browse files
committed
[MLIR] Add canoncalization for shape.is_broadcastable
Canonicalize `is_broadcastable` to constant true if fewer than 2 unique shape operands. Eliminate redundant operands, otherwise. Differential Revision: https://reviews.llvm.org/D98361
1 parent 2224221 commit b975e3b

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,10 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
277277
};
278278
}];
279279

280+
let hasCanonicalizer = 1;
281+
280282
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
281283
let verifier = [{ return ::verify(*this); }];
282-
283284
}
284285

285286
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,44 @@ static LogicalResult verify(IsBroadcastableOp op) {
779779
return success();
780780
}
781781

782+
namespace {
783+
struct IsBroadcastableCanonicalizationPattern
784+
: public OpRewritePattern<IsBroadcastableOp> {
785+
using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
786+
787+
LogicalResult matchAndRewrite(IsBroadcastableOp op,
788+
PatternRewriter &rewriter) const override {
789+
// Find unique operands.
790+
SmallVector<Value, 2> unique;
791+
for (Value v : op.getOperands()) {
792+
if (!llvm::is_contained(unique, v))
793+
unique.push_back(v);
794+
}
795+
796+
// Can always broadcast fewer than two shapes.
797+
if (unique.size() < 2) {
798+
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
799+
rewriter.getBoolAttr(true));
800+
return success();
801+
}
802+
803+
// Reduce op to equivalent with unique operands.
804+
if (unique.size() < op.getNumOperands()) {
805+
rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
806+
unique);
807+
return success();
808+
}
809+
810+
return failure();
811+
}
812+
};
813+
} // namespace
814+
815+
void IsBroadcastableOp::getCanonicalizationPatterns(
816+
OwningRewritePatternList &patterns, MLIRContext *context) {
817+
patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
818+
}
819+
782820
//===----------------------------------------------------------------------===//
783821
// RankOp
784822
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,3 +1069,28 @@ func @fold_tensor.cast_of_const_shape_returned_dynamic(%arg: i1) -> tensor<?xind
10691069
%1 = tensor.cast %0 : tensor<1xindex> to tensor<?xindex>
10701070
return %1 : tensor<?xindex>
10711071
}
1072+
1073+
// -----
1074+
1075+
// CHECK-LABEL: @is_broadcastable_on_same_shape
1076+
func @is_broadcastable_on_same_shape(%shape : !shape.shape) -> i1 {
1077+
// CHECK-NOT: is_broadcastable
1078+
// CHECK: %[[RES:.*]] = constant true
1079+
// CHECK: return %[[RES]]
1080+
%0 = shape.is_broadcastable %shape, %shape, %shape
1081+
: !shape.shape, !shape.shape, !shape.shape
1082+
return %0 : i1
1083+
}
1084+
1085+
// -----
1086+
1087+
// CHECK-LABEL: @is_broadcastable_on_duplicate_shapes
1088+
// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
1089+
func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
1090+
-> i1 {
1091+
// CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
1092+
// CHECK: return %[[RES]]
1093+
%0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
1094+
!shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
1095+
return %0 : i1
1096+
}

0 commit comments

Comments
 (0)