Skip to content

Commit 577b5eb

Browse files
[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API.
1 parent a622b21 commit 577b5eb

File tree

3 files changed

+118
-105
lines changed

3 files changed

+118
-105
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace mlir {
2424
// Forward declarations.
2525
class Attribute;
2626
class Block;
27+
struct ConversionConfig;
2728
class ConversionPatternRewriter;
2829
class MLIRContext;
2930
class Operation;
@@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
770771
/// Conversion pattern rewriters must not be used outside of dialect
771772
/// conversions. They apply some IR rewrites in a delayed fashion and could
772773
/// bring the IR into an inconsistent state when used standalone.
773-
explicit ConversionPatternRewriter(MLIRContext *ctx);
774+
explicit ConversionPatternRewriter(MLIRContext *ctx,
775+
const ConversionConfig &config);
774776

775777
// Hide unsupported pattern rewriter API.
776778
using OpBuilder::setListener;
@@ -1070,6 +1072,30 @@ class PDLConversionConfig final {
10701072

10711073
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
10721074

1075+
//===----------------------------------------------------------------------===//
1076+
// ConversionConfig
1077+
//===----------------------------------------------------------------------===//
1078+
1079+
/// Dialect conversion configuration.
1080+
struct ConversionConfig {
1081+
/// An optional callback used to notify about match failure diagnostics during
1082+
/// the conversion. Diagnostics are only reported to this callback may only be
1083+
/// available in debug mode.
1084+
function_ref<void(Diagnostic &)> notifyCallback = nullptr;
1085+
1086+
/// Partial conversion only. All operations that are found not to be
1087+
/// legalizable are placed in this set. (Note that if there is an op
1088+
/// explicitly marked as illegal, the conversion terminates and the set will
1089+
/// not necessarily be complete.)
1090+
DenseSet<Operation *> *unlegalizedOps = nullptr;
1091+
1092+
/// Analysis conversion only. All operations that are found to be legalizable
1093+
/// are placed in this set. Note that no actual rewrites are applied to the
1094+
/// IR during an analysis conversion and only pre-existing operations are
1095+
/// added to the set.
1096+
DenseSet<Operation *> *legalizableOps = nullptr;
1097+
};
1098+
10731099
//===----------------------------------------------------------------------===//
10741100
// Op Conversion Entry Points
10751101
//===----------------------------------------------------------------------===//
@@ -1083,51 +1109,44 @@ class PDLConversionConfig final {
10831109
/// Apply a partial conversion on the given operations and all nested
10841110
/// operations. This method converts as many operations to the target as
10851111
/// possible, ignoring operations that failed to legalize. This method only
1086-
/// returns failure if there ops explicitly marked as illegal. If an
1087-
/// `unconvertedOps` set is provided, all operations that are found not to be
1088-
/// legalizable to the given `target` are placed within that set. (Note that if
1089-
/// there is an op explicitly marked as illegal, the conversion terminates and
1090-
/// the `unconvertedOps` set will not necessarily be complete.)
1112+
/// returns failure if there ops explicitly marked as illegal.
10911113
LogicalResult
1092-
applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
1114+
applyPartialConversion(ArrayRef<Operation *> ops,
1115+
const ConversionTarget &target,
10931116
const FrozenRewritePatternSet &patterns,
1094-
DenseSet<Operation *> *unconvertedOps = nullptr);
1117+
ConversionConfig config = ConversionConfig());
10951118
LogicalResult
10961119
applyPartialConversion(Operation *op, const ConversionTarget &target,
10971120
const FrozenRewritePatternSet &patterns,
1098-
DenseSet<Operation *> *unconvertedOps = nullptr);
1121+
ConversionConfig config = ConversionConfig());
10991122

11001123
/// Apply a complete conversion on the given operations, and all nested
11011124
/// operations. This method returns failure if the conversion of any operation
11021125
/// fails, or if there are unreachable blocks in any of the regions nested
11031126
/// within 'ops'.
11041127
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
11051128
const ConversionTarget &target,
1106-
const FrozenRewritePatternSet &patterns);
1129+
const FrozenRewritePatternSet &patterns,
1130+
ConversionConfig config = ConversionConfig());
11071131
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1108-
const FrozenRewritePatternSet &patterns);
1132+
const FrozenRewritePatternSet &patterns,
1133+
ConversionConfig config = ConversionConfig());
11091134

11101135
/// Apply an analysis conversion on the given operations, and all nested
11111136
/// operations. This method analyzes which operations would be successfully
11121137
/// converted to the target if a conversion was applied. All operations that
11131138
/// were found to be legalizable to the given 'target' are placed within the
1114-
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
1115-
/// operations on success and only pre-existing operations are added to the set.
1116-
/// This method only returns failure if there are unreachable blocks in any of
1117-
/// the regions nested within 'ops'. There's an additional argument
1118-
/// `notifyCallback` which is used for collecting match failure diagnostics
1119-
/// generated during the conversion. Diagnostics are only reported to this
1120-
/// callback may only be available in debug mode.
1121-
LogicalResult applyAnalysisConversion(
1122-
ArrayRef<Operation *> ops, ConversionTarget &target,
1123-
const FrozenRewritePatternSet &patterns,
1124-
DenseSet<Operation *> &convertedOps,
1125-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1126-
LogicalResult applyAnalysisConversion(
1127-
Operation *op, ConversionTarget &target,
1128-
const FrozenRewritePatternSet &patterns,
1129-
DenseSet<Operation *> &convertedOps,
1130-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1139+
/// provided 'config.legalizableOps' set; note that no actual rewrites are
1140+
/// applied to the operations on success. This method only returns failure if
1141+
/// there are unreachable blocks in any of the regions nested within 'ops'.
1142+
LogicalResult
1143+
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1144+
const FrozenRewritePatternSet &patterns,
1145+
ConversionConfig config = ConversionConfig());
1146+
LogicalResult
1147+
applyAnalysisConversion(Operation *op, ConversionTarget &target,
1148+
const FrozenRewritePatternSet &patterns,
1149+
ConversionConfig config = ConversionConfig());
11311150
} // namespace mlir
11321151

11331152
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ class IRRewrite {
228228
/// Erase the given block (unless it was already erased).
229229
void eraseBlock(Block *block);
230230

231+
const ConversionConfig &getConfig() const;
232+
231233
const Kind kind;
232234
ConversionPatternRewriterImpl &rewriterImpl;
233235
};
@@ -754,9 +756,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
754756
namespace mlir {
755757
namespace detail {
756758
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
757-
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
759+
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
760+
const ConversionConfig &config)
758761
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
759-
notifyCallback(nullptr) {}
762+
config(config) {}
760763

761764
//===--------------------------------------------------------------------===//
762765
// State Management
@@ -962,14 +965,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
962965
/// converting the arguments of blocks within that region.
963966
DenseMap<Region *, const TypeConverter *> regionToConverter;
964967

965-
/// This allows the user to collect the match failure message.
966-
function_ref<void(Diagnostic &)> notifyCallback;
967-
968-
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
969-
/// this is populated with ops found to be legalizable to the target.
970-
/// When mode == OpConversionMode::Partial, this is populated with ops found
971-
/// *not* to be legalizable to the target.
972-
DenseSet<Operation *> *trackedOps = nullptr;
968+
/// Dialect conversion configuration.
969+
const ConversionConfig &config;
973970

974971
#ifndef NDEBUG
975972
/// A set of operations that have pending updates. This tracking isn't
@@ -992,6 +989,10 @@ void IRRewrite::eraseBlock(Block *block) {
992989
rewriterImpl.eraseRewriter.eraseBlock(block);
993990
}
994991

992+
const ConversionConfig &IRRewrite::getConfig() const {
993+
return rewriterImpl.config;
994+
}
995+
995996
void BlockTypeConversionRewrite::commit() {
996997
// Process the remapping for each of the original arguments.
997998
for (auto [origArg, info] :
@@ -1107,8 +1108,8 @@ void ReplaceOperationRewrite::commit() {
11071108
if (Value newValue =
11081109
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
11091110
result.replaceAllUsesWith(newValue);
1110-
if (rewriterImpl.trackedOps)
1111-
rewriterImpl.trackedOps->erase(op);
1111+
if (getConfig().unlegalizedOps)
1112+
getConfig().unlegalizedOps->erase(op);
11121113
// Do not erase the operation yet. It may still be referenced in `mapping`.
11131114
op->getBlock()->getOperations().remove(op);
11141115
}
@@ -1543,18 +1544,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15431544
Diagnostic diag(loc, DiagnosticSeverity::Remark);
15441545
reasonCallback(diag);
15451546
logger.startLine() << "** Failure : " << diag.str() << "\n";
1546-
if (notifyCallback)
1547-
notifyCallback(diag);
1547+
if (config.notifyCallback)
1548+
config.notifyCallback(diag);
15481549
});
15491550
}
15501551

15511552
//===----------------------------------------------------------------------===//
15521553
// ConversionPatternRewriter
15531554
//===----------------------------------------------------------------------===//
15541555

1555-
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1556+
ConversionPatternRewriter::ConversionPatternRewriter(
1557+
MLIRContext *ctx, const ConversionConfig &config)
15561558
: PatternRewriter(ctx),
1557-
impl(new detail::ConversionPatternRewriterImpl(*this)) {
1559+
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
15581560
setListener(impl.get());
15591561
}
15601562

@@ -2005,12 +2007,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
20052007
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
20062008
LLVM_DEBUG({
20072009
logFailure(rewriterImpl.logger, "pattern failed to match");
2008-
if (rewriterImpl.notifyCallback) {
2010+
if (rewriterImpl.config.notifyCallback) {
20092011
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
20102012
diag << "Failed to apply pattern \"" << pattern.getDebugName()
20112013
<< "\" on op:\n"
20122014
<< *op;
2013-
rewriterImpl.notifyCallback(diag);
2015+
rewriterImpl.config.notifyCallback(diag);
20142016
}
20152017
});
20162018
rewriterImpl.resetState(curState);
@@ -2398,14 +2400,12 @@ namespace mlir {
23982400
struct OperationConverter {
23992401
explicit OperationConverter(const ConversionTarget &target,
24002402
const FrozenRewritePatternSet &patterns,
2401-
OpConversionMode mode,
2402-
DenseSet<Operation *> *trackedOps = nullptr)
2403-
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2403+
const ConversionConfig &config,
2404+
OpConversionMode mode)
2405+
: opLegalizer(target, patterns), config(config), mode(mode) {}
24042406

24052407
/// Converts the given operations to the conversion target.
2406-
LogicalResult
2407-
convertOperations(ArrayRef<Operation *> ops,
2408-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2408+
LogicalResult convertOperations(ArrayRef<Operation *> ops);
24092409

24102410
private:
24112411
/// Converts an operation with the given rewriter.
@@ -2442,14 +2442,11 @@ struct OperationConverter {
24422442
/// The legalizer to use when converting operations.
24432443
OperationLegalizer opLegalizer;
24442444

2445+
/// Dialect conversion configuration.
2446+
ConversionConfig config;
2447+
24452448
/// The conversion mode to use when legalizing operations.
24462449
OpConversionMode mode;
2447-
2448-
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2449-
/// this is populated with ops found to be legalizable to the target.
2450-
/// When mode == OpConversionMode::Partial, this is populated with ops found
2451-
/// *not* to be legalizable to the target.
2452-
DenseSet<Operation *> *trackedOps;
24532450
};
24542451
} // namespace mlir
24552452

@@ -2463,28 +2460,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
24632460
return op->emitError()
24642461
<< "failed to legalize operation '" << op->getName() << "'";
24652462
// Partial conversions allow conversions to fail iff the operation was not
2466-
// explicitly marked as illegal. If the user provided a nonlegalizableOps
2467-
// set, non-legalizable ops are included.
2463+
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
2464+
// set, non-legalizable ops are added to that set.
24682465
if (mode == OpConversionMode::Partial) {
24692466
if (opLegalizer.isIllegal(op))
24702467
return op->emitError()
24712468
<< "failed to legalize operation '" << op->getName()
24722469
<< "' that was explicitly marked illegal";
2473-
if (trackedOps)
2474-
trackedOps->insert(op);
2470+
if (config.unlegalizedOps)
2471+
config.unlegalizedOps->insert(op);
24752472
}
24762473
} else if (mode == OpConversionMode::Analysis) {
24772474
// Analysis conversions don't fail if any operations fail to legalize,
24782475
// they are only interested in the operations that were successfully
24792476
// legalized.
2480-
trackedOps->insert(op);
2477+
if (config.legalizableOps)
2478+
config.legalizableOps->insert(op);
24812479
}
24822480
return success();
24832481
}
24842482

2485-
LogicalResult OperationConverter::convertOperations(
2486-
ArrayRef<Operation *> ops,
2487-
function_ref<void(Diagnostic &)> notifyCallback) {
2483+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24882484
if (ops.empty())
24892485
return success();
24902486
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2505,10 +2501,8 @@ LogicalResult OperationConverter::convertOperations(
25052501
}
25062502

25072503
// Convert each operation and discard rewrites on failure.
2508-
ConversionPatternRewriter rewriter(ops.front()->getContext());
2504+
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
25092505
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2510-
rewriterImpl.notifyCallback = notifyCallback;
2511-
rewriterImpl.trackedOps = trackedOps;
25122506

25132507
for (auto *op : toConvert)
25142508
if (failed(convert(rewriter, op)))
@@ -3495,57 +3489,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34953489
//===----------------------------------------------------------------------===//
34963490
// Partial Conversion
34973491

3498-
LogicalResult
3499-
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
3500-
const ConversionTarget &target,
3501-
const FrozenRewritePatternSet &patterns,
3502-
DenseSet<Operation *> *unconvertedOps) {
3503-
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3504-
unconvertedOps);
3492+
LogicalResult mlir::applyPartialConversion(
3493+
ArrayRef<Operation *> ops, const ConversionTarget &target,
3494+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3495+
OperationConverter opConverter(target, patterns, config,
3496+
OpConversionMode::Partial);
35053497
return opConverter.convertOperations(ops);
35063498
}
35073499
LogicalResult
35083500
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
35093501
const FrozenRewritePatternSet &patterns,
3510-
DenseSet<Operation *> *unconvertedOps) {
3511-
return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3512-
unconvertedOps);
3502+
ConversionConfig config) {
3503+
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
35133504
}
35143505

35153506
//===----------------------------------------------------------------------===//
35163507
// Full Conversion
35173508

3518-
LogicalResult
3519-
mlir::applyFullConversion(ArrayRef<Operation *> ops,
3520-
const ConversionTarget &target,
3521-
const FrozenRewritePatternSet &patterns) {
3522-
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3509+
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3510+
const ConversionTarget &target,
3511+
const FrozenRewritePatternSet &patterns,
3512+
ConversionConfig config) {
3513+
OperationConverter opConverter(target, patterns, config,
3514+
OpConversionMode::Full);
35233515
return opConverter.convertOperations(ops);
35243516
}
3525-
LogicalResult
3526-
mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
3527-
const FrozenRewritePatternSet &patterns) {
3528-
return applyFullConversion(llvm::ArrayRef(op), target, patterns);
3517+
LogicalResult mlir::applyFullConversion(Operation *op,
3518+
const ConversionTarget &target,
3519+
const FrozenRewritePatternSet &patterns,
3520+
ConversionConfig config) {
3521+
return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
35293522
}
35303523

35313524
//===----------------------------------------------------------------------===//
35323525
// Analysis Conversion
35333526

3534-
LogicalResult
3535-
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
3536-
ConversionTarget &target,
3537-
const FrozenRewritePatternSet &patterns,
3538-
DenseSet<Operation *> &convertedOps,
3539-
function_ref<void(Diagnostic &)> notifyCallback) {
3540-
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3541-
&convertedOps);
3542-
return opConverter.convertOperations(ops, notifyCallback);
3527+
LogicalResult mlir::applyAnalysisConversion(
3528+
ArrayRef<Operation *> ops, ConversionTarget &target,
3529+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3530+
OperationConverter opConverter(target, patterns, config,
3531+
OpConversionMode::Analysis);
3532+
return opConverter.convertOperations(ops);
35433533
}
35443534
LogicalResult
35453535
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
35463536
const FrozenRewritePatternSet &patterns,
3547-
DenseSet<Operation *> &convertedOps,
3548-
function_ref<void(Diagnostic &)> notifyCallback) {
3549-
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3550-
convertedOps, notifyCallback);
3537+
ConversionConfig config) {
3538+
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
35513539
}

0 commit comments

Comments
 (0)