Skip to content

Commit 819e5f9

Browse files
[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API.
1 parent f3fe2f1 commit 819e5f9

File tree

3 files changed

+118
-100
lines changed

3 files changed

+118
-100
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace mlir {
2424
// Forward declarations.
2525
class Attribute;
2626
class Block;
27+
struct ConversionConfig;
2728
class ConversionPatternRewriter;
2829
class MLIRContext;
2930
class Operation;
@@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
770771
/// Conversion pattern rewriters must not be used outside of dialect
771772
/// conversions. They apply some IR rewrites in a delayed fashion and could
772773
/// bring the IR into an inconsistent state when used standalone.
773-
explicit ConversionPatternRewriter(MLIRContext *ctx);
774+
explicit ConversionPatternRewriter(MLIRContext *ctx,
775+
const ConversionConfig &config);
774776

775777
// Hide unsupported pattern rewriter API.
776778
using OpBuilder::setListener;
@@ -1070,6 +1072,30 @@ class PDLConversionConfig final {
10701072

10711073
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
10721074

1075+
//===----------------------------------------------------------------------===//
1076+
// ConversionConfig
1077+
//===----------------------------------------------------------------------===//
1078+
1079+
/// Dialect conversion configuration.
1080+
struct ConversionConfig {
1081+
/// An optional callback used to notify about match failure diagnostics during
1082+
/// the conversion. Diagnostics are only reported to this callback may only be
1083+
/// available in debug mode.
1084+
function_ref<void(Diagnostic &)> notifyCallback = nullptr;
1085+
1086+
/// Partial conversion only. All operations that are found not to be
1087+
/// legalizable are placed in this set. (Note that if there is an op
1088+
/// explicitly marked as illegal, the conversion terminates and the set will
1089+
/// not necessarily be complete.)
1090+
DenseSet<Operation *> *unlegalizedOps = nullptr;
1091+
1092+
/// Analysis conversion only. All operations that are found to be legalizable
1093+
/// are placed in this set. Note that no actual rewrites are applied to the
1094+
/// IR during an analysis conversion and only pre-existing operations are
1095+
/// added to the set.
1096+
DenseSet<Operation *> *legalizableOps = nullptr;
1097+
};
1098+
10731099
//===----------------------------------------------------------------------===//
10741100
// Op Conversion Entry Points
10751101
//===----------------------------------------------------------------------===//
@@ -1083,51 +1109,44 @@ class PDLConversionConfig final {
10831109
/// Apply a partial conversion on the given operations and all nested
10841110
/// operations. This method converts as many operations to the target as
10851111
/// possible, ignoring operations that failed to legalize. This method only
1086-
/// returns failure if there ops explicitly marked as illegal. If an
1087-
/// `unconvertedOps` set is provided, all operations that are found not to be
1088-
/// legalizable to the given `target` are placed within that set. (Note that if
1089-
/// there is an op explicitly marked as illegal, the conversion terminates and
1090-
/// the `unconvertedOps` set will not necessarily be complete.)
1112+
/// returns failure if there ops explicitly marked as illegal.
10911113
LogicalResult
1092-
applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
1114+
applyPartialConversion(ArrayRef<Operation *> ops,
1115+
const ConversionTarget &target,
10931116
const FrozenRewritePatternSet &patterns,
1094-
DenseSet<Operation *> *unconvertedOps = nullptr);
1117+
ConversionConfig config = ConversionConfig());
10951118
LogicalResult
10961119
applyPartialConversion(Operation *op, const ConversionTarget &target,
10971120
const FrozenRewritePatternSet &patterns,
1098-
DenseSet<Operation *> *unconvertedOps = nullptr);
1121+
ConversionConfig config = ConversionConfig());
10991122

11001123
/// Apply a complete conversion on the given operations, and all nested
11011124
/// operations. This method returns failure if the conversion of any operation
11021125
/// fails, or if there are unreachable blocks in any of the regions nested
11031126
/// within 'ops'.
11041127
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
11051128
const ConversionTarget &target,
1106-
const FrozenRewritePatternSet &patterns);
1129+
const FrozenRewritePatternSet &patterns,
1130+
ConversionConfig config = ConversionConfig());
11071131
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1108-
const FrozenRewritePatternSet &patterns);
1132+
const FrozenRewritePatternSet &patterns,
1133+
ConversionConfig config = ConversionConfig());
11091134

11101135
/// Apply an analysis conversion on the given operations, and all nested
11111136
/// operations. This method analyzes which operations would be successfully
11121137
/// converted to the target if a conversion was applied. All operations that
11131138
/// were found to be legalizable to the given 'target' are placed within the
1114-
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
1115-
/// operations on success and only pre-existing operations are added to the set.
1116-
/// This method only returns failure if there are unreachable blocks in any of
1117-
/// the regions nested within 'ops'. There's an additional argument
1118-
/// `notifyCallback` which is used for collecting match failure diagnostics
1119-
/// generated during the conversion. Diagnostics are only reported to this
1120-
/// callback may only be available in debug mode.
1121-
LogicalResult applyAnalysisConversion(
1122-
ArrayRef<Operation *> ops, ConversionTarget &target,
1123-
const FrozenRewritePatternSet &patterns,
1124-
DenseSet<Operation *> &convertedOps,
1125-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1126-
LogicalResult applyAnalysisConversion(
1127-
Operation *op, ConversionTarget &target,
1128-
const FrozenRewritePatternSet &patterns,
1129-
DenseSet<Operation *> &convertedOps,
1130-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1139+
/// provided 'config.legalizableOps' set; note that no actual rewrites are
1140+
/// applied to the operations on success. This method only returns failure if
1141+
/// there are unreachable blocks in any of the regions nested within 'ops'.
1142+
LogicalResult
1143+
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1144+
const FrozenRewritePatternSet &patterns,
1145+
ConversionConfig config = ConversionConfig());
1146+
LogicalResult
1147+
applyAnalysisConversion(Operation *op, ConversionTarget &target,
1148+
const FrozenRewritePatternSet &patterns,
1149+
ConversionConfig config = ConversionConfig());
11311150
} // namespace mlir
11321151

11331152
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 61 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ class IRRewrite {
224224
/// Erase the given block (unless it was already erased).
225225
void eraseBlock(Block *block);
226226

227+
const ConversionConfig &getConfig() const;
228+
227229
const Kind kind;
228230
ConversionPatternRewriterImpl &rewriterImpl;
229231
};
@@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
723725
namespace mlir {
724726
namespace detail {
725727
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
726-
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
728+
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
729+
const ConversionConfig &config)
727730
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
728-
notifyCallback(nullptr) {}
731+
config(config) {}
729732

730733
//===--------------------------------------------------------------------===//
731734
// State Management
@@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
931934
/// converting the arguments of blocks within that region.
932935
DenseMap<Region *, const TypeConverter *> regionToConverter;
933936

934-
/// This allows the user to collect the match failure message.
935-
function_ref<void(Diagnostic &)> notifyCallback;
936-
937-
DenseSet<Operation *> *trackedOps = nullptr;
937+
/// Dialect conversion configuration.
938+
const ConversionConfig &config;
938939

939940
#ifndef NDEBUG
940941
/// A set of operations that have pending updates. This tracking isn't
@@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) {
957958
rewriterImpl.eraseRewriter.eraseBlock(block);
958959
}
959960

961+
const ConversionConfig &IRRewrite::getConfig() const {
962+
return rewriterImpl.config;
963+
}
964+
960965
void BlockTypeConversionRewrite::commit() {
961966
// Process the remapping for each of the original arguments.
962967
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
@@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() {
10741079
if (Value newValue =
10751080
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
10761081
result.replaceAllUsesWith(newValue);
1077-
if (rewriterImpl.trackedOps)
1078-
rewriterImpl.trackedOps->erase(op);
1082+
if (getConfig().unlegalizedOps)
1083+
getConfig().unlegalizedOps->erase(op);
10791084
// Do not erase the operation yet. It may still be referenced in `mapping`.
10801085
op->getBlock()->getOperations().remove(op);
10811086
}
@@ -1510,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15101515
Diagnostic diag(loc, DiagnosticSeverity::Remark);
15111516
reasonCallback(diag);
15121517
logger.startLine() << "** Failure : " << diag.str() << "\n";
1513-
if (notifyCallback)
1514-
notifyCallback(diag);
1518+
if (config.notifyCallback)
1519+
config.notifyCallback(diag);
15151520
});
15161521
}
15171522

15181523
//===----------------------------------------------------------------------===//
15191524
// ConversionPatternRewriter
15201525
//===----------------------------------------------------------------------===//
15211526

1522-
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1527+
ConversionPatternRewriter::ConversionPatternRewriter(
1528+
MLIRContext *ctx, const ConversionConfig &config)
15231529
: PatternRewriter(ctx),
1524-
impl(new detail::ConversionPatternRewriterImpl(*this)) {
1530+
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
15251531
setListener(impl.get());
15261532
}
15271533

@@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
19721978
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
19731979
LLVM_DEBUG({
19741980
logFailure(rewriterImpl.logger, "pattern failed to match");
1975-
if (rewriterImpl.notifyCallback) {
1981+
if (rewriterImpl.config.notifyCallback) {
19761982
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
19771983
diag << "Failed to apply pattern \"" << pattern.getDebugName()
19781984
<< "\" on op:\n"
19791985
<< *op;
1980-
rewriterImpl.notifyCallback(diag);
1986+
rewriterImpl.config.notifyCallback(diag);
19811987
}
19821988
});
19831989
rewriterImpl.resetState(curState);
@@ -2365,14 +2371,12 @@ namespace mlir {
23652371
struct OperationConverter {
23662372
explicit OperationConverter(const ConversionTarget &target,
23672373
const FrozenRewritePatternSet &patterns,
2368-
OpConversionMode mode,
2369-
DenseSet<Operation *> *trackedOps = nullptr)
2370-
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2374+
const ConversionConfig &config,
2375+
OpConversionMode mode)
2376+
: opLegalizer(target, patterns), config(config), mode(mode) {}
23712377

23722378
/// Converts the given operations to the conversion target.
2373-
LogicalResult
2374-
convertOperations(ArrayRef<Operation *> ops,
2375-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2379+
LogicalResult convertOperations(ArrayRef<Operation *> ops);
23762380

23772381
private:
23782382
/// Converts an operation with the given rewriter.
@@ -2409,14 +2413,11 @@ struct OperationConverter {
24092413
/// The legalizer to use when converting operations.
24102414
OperationLegalizer opLegalizer;
24112415

2416+
/// Dialect conversion configuration.
2417+
ConversionConfig config;
2418+
24122419
/// The conversion mode to use when legalizing operations.
24132420
OpConversionMode mode;
2414-
2415-
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2416-
/// this is populated with ops found to be legalizable to the target.
2417-
/// When mode == OpConversionMode::Partial, this is populated with ops found
2418-
/// *not* to be legalizable to the target.
2419-
DenseSet<Operation *> *trackedOps;
24202421
};
24212422
} // namespace mlir
24222423

@@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
24302431
return op->emitError()
24312432
<< "failed to legalize operation '" << op->getName() << "'";
24322433
// Partial conversions allow conversions to fail iff the operation was not
2433-
// explicitly marked as illegal. If the user provided a nonlegalizableOps
2434-
// set, non-legalizable ops are included.
2434+
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
2435+
// set, non-legalizable ops are added to that set.
24352436
if (mode == OpConversionMode::Partial) {
24362437
if (opLegalizer.isIllegal(op))
24372438
return op->emitError()
24382439
<< "failed to legalize operation '" << op->getName()
24392440
<< "' that was explicitly marked illegal";
2440-
if (trackedOps)
2441-
trackedOps->insert(op);
2441+
if (config.unlegalizedOps)
2442+
config.unlegalizedOps->insert(op);
24422443
}
24432444
} else if (mode == OpConversionMode::Analysis) {
24442445
// Analysis conversions don't fail if any operations fail to legalize,
24452446
// they are only interested in the operations that were successfully
24462447
// legalized.
2447-
trackedOps->insert(op);
2448+
if (config.legalizableOps)
2449+
config.legalizableOps->insert(op);
24482450
}
24492451
return success();
24502452
}
24512453

2452-
LogicalResult OperationConverter::convertOperations(
2453-
ArrayRef<Operation *> ops,
2454-
function_ref<void(Diagnostic &)> notifyCallback) {
2454+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24552455
if (ops.empty())
24562456
return success();
24572457
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations(
24722472
}
24732473

24742474
// Convert each operation and discard rewrites on failure.
2475-
ConversionPatternRewriter rewriter(ops.front()->getContext());
2475+
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
24762476
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2477-
rewriterImpl.notifyCallback = notifyCallback;
2478-
rewriterImpl.trackedOps = trackedOps;
24792477

24802478
for (auto *op : toConvert)
24812479
if (failed(convert(rewriter, op)))
@@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34613459
//===----------------------------------------------------------------------===//
34623460
// Partial Conversion
34633461

3464-
LogicalResult
3465-
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
3466-
const ConversionTarget &target,
3467-
const FrozenRewritePatternSet &patterns,
3468-
DenseSet<Operation *> *unconvertedOps) {
3469-
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3470-
unconvertedOps);
3462+
LogicalResult mlir::applyPartialConversion(
3463+
ArrayRef<Operation *> ops, const ConversionTarget &target,
3464+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3465+
OperationConverter opConverter(target, patterns, config,
3466+
OpConversionMode::Partial);
34713467
return opConverter.convertOperations(ops);
34723468
}
34733469
LogicalResult
34743470
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
34753471
const FrozenRewritePatternSet &patterns,
3476-
DenseSet<Operation *> *unconvertedOps) {
3477-
return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3478-
unconvertedOps);
3472+
ConversionConfig config) {
3473+
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
34793474
}
34803475

34813476
//===----------------------------------------------------------------------===//
34823477
// Full Conversion
34833478

3484-
LogicalResult
3485-
mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
3486-
const FrozenRewritePatternSet &patterns) {
3487-
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3479+
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3480+
const ConversionTarget &target,
3481+
const FrozenRewritePatternSet &patterns,
3482+
ConversionConfig config) {
3483+
OperationConverter opConverter(target, patterns, config,
3484+
OpConversionMode::Full);
34883485
return opConverter.convertOperations(ops);
34893486
}
3490-
LogicalResult
3491-
mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
3492-
const FrozenRewritePatternSet &patterns) {
3493-
return applyFullConversion(llvm::ArrayRef(op), target, patterns);
3487+
LogicalResult mlir::applyFullConversion(Operation *op,
3488+
const ConversionTarget &target,
3489+
const FrozenRewritePatternSet &patterns,
3490+
ConversionConfig config) {
3491+
return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
34943492
}
34953493

34963494
//===----------------------------------------------------------------------===//
34973495
// Analysis Conversion
34983496

3499-
LogicalResult
3500-
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
3501-
ConversionTarget &target,
3502-
const FrozenRewritePatternSet &patterns,
3503-
DenseSet<Operation *> &convertedOps,
3504-
function_ref<void(Diagnostic &)> notifyCallback) {
3505-
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3506-
&convertedOps);
3507-
return opConverter.convertOperations(ops, notifyCallback);
3497+
LogicalResult mlir::applyAnalysisConversion(
3498+
ArrayRef<Operation *> ops, ConversionTarget &target,
3499+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3500+
OperationConverter opConverter(target, patterns, config,
3501+
OpConversionMode::Analysis);
3502+
return opConverter.convertOperations(ops);
35083503
}
35093504
LogicalResult
35103505
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
35113506
const FrozenRewritePatternSet &patterns,
3512-
DenseSet<Operation *> &convertedOps,
3513-
function_ref<void(Diagnostic &)> notifyCallback) {
3514-
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3515-
convertedOps, notifyCallback);
3507+
ConversionConfig config) {
3508+
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
35163509
}

0 commit comments

Comments
 (0)