Skip to content

Commit 655e08c

Browse files
committed
[mlir] Canonicalization of shape.assuming
Summary: This will inline the region to a shape.assuming in the case that the input witness is found to be statically true. Differential Revision: https://reviews.llvm.org/D80302
1 parent 0a554e6 commit 655e08c

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ def Shape_AssumingOp : Shape_Op<"assuming",
509509

510510
let printer = [{ return ::print(p, *this); }];
511511
let parser = [{ return ::parse$cppClass(parser, result); }];
512+
513+
let hasCanonicalizer = 1;
512514
}
513515

514516
def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,44 @@ static void print(OpAsmPrinter &p, AssumingOp op) {
159159
p.printOptionalAttrDict(op.getAttrs());
160160
}
161161

162+
namespace {
163+
// Removes AssumingOp with a passing witness and inlines the region.
164+
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
165+
using OpRewritePattern<AssumingOp>::OpRewritePattern;
166+
167+
LogicalResult matchAndRewrite(AssumingOp op,
168+
PatternRewriter &rewriter) const override {
169+
auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
170+
if (!witness || !witness.passingAttr())
171+
return failure();
172+
173+
auto *blockBeforeAssuming = rewriter.getInsertionBlock();
174+
auto *assumingBlock = op.getBody();
175+
auto initPosition = rewriter.getInsertionPoint();
176+
auto *blockAfterAssuming =
177+
rewriter.splitBlock(blockBeforeAssuming, initPosition);
178+
179+
// Remove the AssumingOp and AssumingYieldOp.
180+
auto &yieldOp = assumingBlock->back();
181+
rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
182+
rewriter.replaceOp(op, yieldOp.getOperands());
183+
rewriter.eraseOp(&yieldOp);
184+
185+
// Merge blocks together as there was no branching behavior from the
186+
// AssumingOp.
187+
rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
188+
rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
189+
return success();
190+
}
191+
};
192+
}; // namespace
193+
194+
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
195+
MLIRContext *context) {
196+
// If taking a passing witness, inline region
197+
patterns.insert<AssumingWithTrue>(context);
198+
}
199+
162200
//===----------------------------------------------------------------------===//
163201
// AssumingAllOp
164202
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,42 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
324324
return %1 : !shape.shape
325325
}
326326

327+
// -----
328+
// assuming with a known passing witness can be removed
329+
// CHECK-LABEL: func @f
330+
func @f() {
331+
// CHECK-NEXT: source
332+
// CHECK-NEXT: sink
333+
// CHECK-NEXT: return
334+
%0 = shape.const_witness true
335+
%1 = shape.assuming %0 -> index {
336+
%2 = "test.source"() : () -> (index)
337+
shape.assuming_yield %2 : index
338+
}
339+
"test.sink"(%1) : (index) -> ()
340+
return
341+
}
342+
343+
// -----
344+
// assuming without a known passing passing witness cannot be removed
345+
// CHECK-LABEL: func @f
346+
func @f() {
347+
// CHECK-NEXT: test.source
348+
// CHECK-NEXT: shape.assuming
349+
// CHECK-NEXT: test.source
350+
// CHECK-NEXT: shape.assuming_yield
351+
// CHECK-NEXT: }
352+
// CHECK-NEXT: test.sink
353+
// CHECK-NEXT: return
354+
%0 = "test.source"() : () -> (!shape.witness)
355+
%1 = shape.assuming %0 -> index {
356+
%2 = "test.source"() : () -> (index)
357+
shape.assuming_yield %2 : index
358+
}
359+
"test.sink"(%1) : (index) -> ()
360+
return
361+
}
362+
327363
// -----
328364
// Broadcastable with broadcastable constant shapes can be removed.
329365
// CHECK-LABEL: func @f

0 commit comments

Comments
 (0)