Skip to content

Commit 79f4143

Browse files
[mlir][Transforms] Dialect conversion: Move hasRewrite to expensive checks (#119848)
The dialect conversion has various checks that detect incorrect API usage in patterns. One of these checks turned out to be quite expensive (N*M complexity where N is the number of block rewrites and M is the total number of rewrites) in NVIDIA-internal workloads: Checking that a block is not converted multiple times. This check iterates over the stack of all rewrites, which can be large. We saw `hasRewrite` being called around 45000 times with an average rewrite stack size of 500000. This PR moves the check to `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. For consistency reasons, the other `hasRewrite`-based check is also moved there.
1 parent 473e251 commit 79f4143

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
714714
};
715715
} // namespace
716716

717+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
717718
/// Return "true" if there is an operation rewrite that matches the specified
718719
/// rewrite type and operation among the given rewrites.
719720
template <typename RewriteTy, typename R>
@@ -724,7 +725,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
724725
});
725726
}
726727

727-
#ifndef NDEBUG
728728
/// Return "true" if there is a block rewrite that matches the specified
729729
/// rewrite type and block among the given rewrites.
730730
template <typename RewriteTy, typename R>
@@ -734,7 +734,7 @@ static bool hasRewrite(R &&rewrites, Block *block) {
734734
return rewriteTy && rewriteTy->getBlock() == block;
735735
});
736736
}
737-
#endif // NDEBUG
737+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
738738

739739
//===----------------------------------------------------------------------===//
740740
// ConversionPatternRewriterImpl
@@ -1292,9 +1292,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12921292
ConversionPatternRewriter &rewriter, Block *block,
12931293
const TypeConverter *converter,
12941294
TypeConverter::SignatureConversion &signatureConversion) {
1295+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
12951296
// A block cannot be converted multiple times.
1296-
assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
1297-
"block was already converted");
1297+
if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
1298+
llvm::report_fatal_error("block was already converted");
1299+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1300+
12981301
OpBuilder::InsertionGuard g(rewriter);
12991302

13001303
// If no arguments are being changed or added, there is nothing to do.
@@ -2236,9 +2239,9 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22362239
ConversionPatternRewriter &rewriter,
22372240
RewriterState &curState) {
22382241
auto &impl = rewriter.getImpl();
2239-
2240-
#ifndef NDEBUG
22412242
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2243+
2244+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22422245
// Check that the root was either replaced or updated in place.
22432246
auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
22442247
auto replacedRoot = [&] {
@@ -2247,9 +2250,9 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22472250
auto updatedRootInPlace = [&] {
22482251
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
22492252
};
2250-
assert((replacedRoot() || updatedRootInPlace()) &&
2251-
"expected pattern to replace the root operation");
2252-
#endif // NDEBUG
2253+
if (!replacedRoot() && !updatedRootInPlace())
2254+
llvm::report_fatal_error("expected pattern to replace the root operation");
2255+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22532256

22542257
// Legalize each of the actions registered during application.
22552258
RewriterState newState = impl.getCurrentState();

0 commit comments

Comments
 (0)