Skip to content

Commit 0a554e6

Browse files
committed
[mlir] Folding and canonicalization of shape.cstr_eq
In the case of all inputs being constant and equal, cstr_eq will be replaced with a true_witness. Differential Revision: https://reviews.llvm.org/D80303
1 parent 6aab709 commit 0a554e6

File tree

4 files changed

+91
-1
lines changed

4 files changed

+91
-1
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
555555
let hasFolder = 1;
556556
}
557557

558-
def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
558+
def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
559559
let summary = "Determines if all input shapes are equal";
560560
let description = [{
561561
Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -572,6 +572,9 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
572572
let results = (outs Shape_WitnessType:$result);
573573

574574
let assemblyFormat = "$inputs attr-dict";
575+
576+
let hasCanonicalizer = 1;
577+
let hasFolder = 1;
575578
}
576579

577580
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,27 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
290290
return nullptr;
291291
}
292292

293+
//===----------------------------------------------------------------------===//
294+
// CstrEqOp
295+
//===----------------------------------------------------------------------===//
296+
297+
void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
298+
MLIRContext *context) {
299+
// If inputs are equal, return passing witness
300+
patterns.insert<CstrEqEqOps>(context);
301+
}
302+
303+
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
304+
if (llvm::all_of(operands,
305+
[&](Attribute a) { return a && a == operands[0]; }))
306+
return BoolAttr::get(true, getContext());
307+
308+
// Because a failing witness result here represents an eventual assertion
309+
// failure, we do not try to replace it with a constant witness. Similarly, we
310+
// cannot if there are any non-const inputs.
311+
return nullptr;
312+
}
313+
293314
//===----------------------------------------------------------------------===//
294315
// ConstSizeOp
295316
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td"
22

33
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
44

5+
def AllInputShapesEq : Constraint<CPred< [{
6+
llvm::all_of($0, [&](mlir::Value val) {
7+
return $0[0] == val;
8+
})
9+
}]>>;
10+
511
// Canonicalization patterns.
612
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
713
(Shape_ConstWitnessOp ConstBoolAttrTrue),
814
[(EqualBinaryOperands $lhs, $rhs)]>;
15+
16+
def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
17+
(Shape_ConstWitnessOp ConstBoolAttrTrue),
18+
[(AllInputShapesEq $shapes)]>;

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,62 @@ func @not_const(%arg0: !shape.shape) -> !shape.size {
213213
return %0 : !shape.size
214214
}
215215

216+
217+
// -----
218+
// cstr_eq with non-constant but known equal shapes can be removed.
219+
// CHECK-LABEL: func @f
220+
func @f(%arg0 : !shape.shape) {
221+
// CHECK-NEXT: shape.const_witness true
222+
// CHECK-NEXT: consume.witness
223+
// CHECK-NEXT: return
224+
%0 = shape.cstr_eq %arg0, %arg0, %arg0
225+
"consume.witness"(%0) : (!shape.witness) -> ()
226+
return
227+
}
228+
229+
// -----
230+
// cstr_eq with equal const_shapes can be folded
231+
// CHECK-LABEL: func @f
232+
func @f() {
233+
// CHECK-NEXT: shape.const_witness true
234+
// CHECK-NEXT: consume.witness
235+
// CHECK-NEXT: return
236+
%cs0 = shape.const_shape [0, 1]
237+
%cs1 = shape.const_shape [0, 1]
238+
%cs2 = shape.const_shape [0, 1]
239+
%0 = shape.cstr_eq %cs0, %cs1, %cs2
240+
"consume.witness"(%0) : (!shape.witness) -> ()
241+
return
242+
}
243+
244+
// -----
245+
// cstr_eq with unequal const_shapes cannot be folded
246+
// CHECK-LABEL: func @f
247+
func @f() {
248+
// CHECK-NEXT: shape.const_shape
249+
// CHECK-NEXT: shape.const_shape
250+
// CHECK-NEXT: shape.cstr_eq
251+
// CHECK-NEXT: consume.witness
252+
// CHECK-NEXT: return
253+
%cs0 = shape.const_shape [0, 1]
254+
%cs1 = shape.const_shape [3, 1]
255+
%0 = shape.cstr_eq %cs0, %cs1
256+
"consume.witness"(%0) : (!shape.witness) -> ()
257+
return
258+
}
259+
260+
// -----
261+
// cstr_eq without const_shapes cannot be folded
262+
// CHECK-LABEL: func @f
263+
func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
264+
// CHECK-NEXT: shape.cstr_eq
265+
// CHECK-NEXT: consume.witness
266+
// CHECK-NEXT: return
267+
%0 = shape.cstr_eq %arg0, %arg1
268+
"consume.witness"(%0) : (!shape.witness) -> ()
269+
return
270+
}
271+
216272
// -----
217273
// assuming_all with known passing witnesses can be folded
218274
// CHECK-LABEL: func @f

0 commit comments

Comments
 (0)