Skip to content

Commit bd1ccfe

Browse files
committed
[mlir] Add a new RewritePattern::hasBoundedRewriteRecursion hook.
Summary: Some pattern rewriters, like dialect conversion, prohibit the unbounded recursion(or reapplication) of patterns on generated IR. Most patterns are not written with recursive application in mind, so will generally explode the stack if uncaught. This revision adds a hook to RewritePattern, `hasBoundedRewriteRecursion`, to signal that the pattern can safely be applied to the generated IR of a previous application of the same pattern. This allows for establishing a contract between the pattern and rewriter that the pattern knows and can handle the potential recursive application. Differential Revision: https://reviews.llvm.org/D77782
1 parent b96558f commit bd1ccfe

File tree

6 files changed

+60
-35
lines changed

6 files changed

+60
-35
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ class RewritePattern : public Pattern {
131131
return failure();
132132
}
133133

134+
/// Returns true if this pattern is known to result in recursive application,
135+
/// i.e. this pattern may generate IR that also matches this pattern, but is
136+
/// known to bound the recursion. This signals to a rewriter that it is safe
137+
/// to apply this pattern recursively to generated IR.
138+
virtual bool hasBoundedRewriteRecursion() const { return false; }
139+
134140
/// Return a list of operations that may be generated when rewriting an
135141
/// operation instance with this pattern.
136142
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -789,23 +789,10 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
789789
Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
790790
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
791791
// smaller rank.
792-
InsertStridedSliceOp insertStridedSliceOp =
793-
rewriter.create<InsertStridedSliceOp>(
794-
loc, extractedSource, extractedDest,
795-
getI64SubArray(op.offsets(), /* dropFront=*/1),
796-
getI64SubArray(op.strides(), /* dropFront=*/1));
797-
// Call matchAndRewrite recursively from within the pattern. This
798-
// circumvents the current limitation that a given pattern cannot
799-
// be called multiple times by the PatternRewrite infrastructure (to
800-
// avoid infinite recursion, but in this case, infinite recursion
801-
// cannot happen because the rank is strictly decreasing).
802-
// TODO(rriddle, nicolasvasilache) Implement something like a hook for
803-
// a potential function that must decrease and allow the same pattern
804-
// multiple times.
805-
auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
806-
(void)success;
807-
assert(succeeded(success) && "Unexpected failure");
808-
extractedSource = insertStridedSliceOp;
792+
extractedSource = rewriter.create<InsertStridedSliceOp>(
793+
loc, extractedSource, extractedDest,
794+
getI64SubArray(op.offsets(), /* dropFront=*/1),
795+
getI64SubArray(op.strides(), /* dropFront=*/1));
809796
}
810797
// 4. Insert the extractedSource into the res vector.
811798
res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -814,6 +801,9 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
814801
rewriter.replaceOp(op, res);
815802
return success();
816803
}
804+
/// This pattern creates recursive InsertStridedSliceOp, but the recursion is
805+
/// bounded as the rank is strictly decreasing.
806+
bool hasBoundedRewriteRecursion() const final { return true; }
817807
};
818808

819809
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1068,28 +1058,19 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
10681058
off += stride, ++idx) {
10691059
Value extracted = extractOne(rewriter, loc, op.vector(), off);
10701060
if (op.offsets().getValue().size() > 1) {
1071-
StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
1061+
extracted = rewriter.create<StridedSliceOp>(
10721062
loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
10731063
getI64SubArray(op.sizes(), /* dropFront=*/1),
10741064
getI64SubArray(op.strides(), /* dropFront=*/1));
1075-
// Call matchAndRewrite recursively from within the pattern. This
1076-
// circumvents the current limitation that a given pattern cannot
1077-
// be called multiple times by the PatternRewrite infrastructure (to
1078-
// avoid infinite recursion, but in this case, infinite recursion
1079-
// cannot happen because the rank is strictly decreasing).
1080-
// TODO(rriddle, nicolasvasilache) Implement something like a hook for
1081-
// a potential function that must decrease and allow the same pattern
1082-
// multiple times.
1083-
auto success = matchAndRewrite(stridedSliceOp, rewriter);
1084-
(void)success;
1085-
assert(succeeded(success) && "Unexpected failure");
1086-
extracted = stridedSliceOp;
10871065
}
10881066
res = insertOne(rewriter, loc, extracted, res, idx);
10891067
}
10901068
rewriter.replaceOp(op, {res});
10911069
return success();
10921070
}
1071+
/// This pattern creates recursive StridedSliceOp, but the recursion is
1072+
/// bounded as the rank is strictly decreasing.
1073+
bool hasBoundedRewriteRecursion() const final { return true; }
10931074
};
10941075

10951076
} // namespace

mlir/lib/Transforms/DialectConversion.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,10 +1256,9 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
12561256
});
12571257

12581258
// Ensure that we don't cycle by not allowing the same pattern to be
1259-
// applied twice in the same recursion stack.
1260-
// TODO(riverriddle) We could eventually converge, but that requires more
1261-
// complicated analysis.
1262-
if (!appliedPatterns.insert(pattern).second) {
1259+
// applied twice in the same recursion stack if it is not known to be safe.
1260+
if (!pattern->hasBoundedRewriteRecursion() &&
1261+
!appliedPatterns.insert(pattern).second) {
12631262
LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied"));
12641263
return failure();
12651264
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ func @create_block() {
143143
return
144144
}
145145

146+
// CHECK-LABEL: @bounded_recursion
147+
func @bounded_recursion() {
148+
// CHECK: test.recursive_rewrite 0
149+
test.recursive_rewrite 3
150+
return
151+
}
152+
146153
// -----
147154

148155
func @fail_to_convert_illegal_op() -> i32 {

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,12 @@ def TestRewriteOp : TEST_Op<"rewrite">,
10611061
Arguments<(ins AnyType)>, Results<(outs AnyType)>;
10621062
def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
10631063

1064+
// Check that patterns can specify bounded recursion when rewriting.
1065+
def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
1066+
let arguments = (ins I64Attr:$depth);
1067+
let assemblyFormat = "$depth attr-dict";
1068+
}
1069+
10641070
//===----------------------------------------------------------------------===//
10651071
// Test Type Legalization
10661072
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,28 @@ struct TestNonRootReplacement : public RewritePattern {
360360
return success();
361361
}
362362
};
363+
364+
//===----------------------------------------------------------------------===//
365+
// Recursive Rewrite Testing
366+
/// This pattern is applied to the same operation multiple times, but has a
367+
/// bounded recursion.
368+
struct TestBoundedRecursiveRewrite
369+
: public OpRewritePattern<TestRecursiveRewriteOp> {
370+
using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
371+
372+
LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
373+
PatternRewriter &rewriter) const final {
374+
// Decrement the depth of the op in-place.
375+
rewriter.updateRootInPlace(op, [&] {
376+
op.setAttr("depth",
377+
rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
378+
});
379+
return success();
380+
}
381+
382+
/// The conversion target handles bounding the recursion of this pattern.
383+
bool hasBoundedRewriteRecursion() const final { return true; }
384+
};
363385
} // namespace
364386

365387
namespace {
@@ -414,7 +436,7 @@ struct TestLegalizePatternDriver
414436
TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
415437
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
416438
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
417-
TestNonRootReplacement>(&getContext());
439+
TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
418440
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
419441
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
420442
converter);
@@ -449,6 +471,10 @@ struct TestLegalizePatternDriver
449471
op->getAttrOfType<UnitAttr>("test.recursively_legal"));
450472
});
451473

474+
// Mark the bound recursion operation as dynamically legal.
475+
target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
476+
[](TestRecursiveRewriteOp op) { return op.depth() == 0; });
477+
452478
// Handle a partial conversion.
453479
if (mode == ConversionMode::Partial) {
454480
(void)applyPartialConversion(getOperation(), target, patterns,

0 commit comments

Comments
 (0)