Skip to content

Commit a282109

Browse files
[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig (#83754)
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API. This is a re-upload of #82250. The Windows build breakage was fixed in #83768. This reverts commit 60fbd60.
1 parent 354deba commit a282109

File tree

3 files changed

+116
-104
lines changed

3 files changed

+116
-104
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace mlir {
2424
// Forward declarations.
2525
class Attribute;
2626
class Block;
27+
struct ConversionConfig;
2728
class ConversionPatternRewriter;
2829
class MLIRContext;
2930
class Operation;
@@ -767,7 +768,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
767768
/// Conversion pattern rewriters must not be used outside of dialect
768769
/// conversions. They apply some IR rewrites in a delayed fashion and could
769770
/// bring the IR into an inconsistent state when used standalone.
770-
explicit ConversionPatternRewriter(MLIRContext *ctx);
771+
explicit ConversionPatternRewriter(MLIRContext *ctx,
772+
const ConversionConfig &config);
771773

772774
// Hide unsupported pattern rewriter API.
773775
using OpBuilder::setListener;
@@ -1067,6 +1069,30 @@ class PDLConversionConfig final {
10671069

10681070
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
10691071

1072+
//===----------------------------------------------------------------------===//
1073+
// ConversionConfig
1074+
//===----------------------------------------------------------------------===//
1075+
1076+
/// Dialect conversion configuration.
1077+
struct ConversionConfig {
1078+
/// An optional callback used to notify about match failure diagnostics during
1079+
/// the conversion. Diagnostics reported to this callback may only be
1080+
/// available in debug mode.
1081+
function_ref<void(Diagnostic &)> notifyCallback = nullptr;
1082+
1083+
/// Partial conversion only. All operations that are found not to be
1084+
/// legalizable are placed in this set. (Note that if there is an op
1085+
/// explicitly marked as illegal, the conversion terminates and the set will
1086+
/// not necessarily be complete.)
1087+
DenseSet<Operation *> *unlegalizedOps = nullptr;
1088+
1089+
/// Analysis conversion only. All operations that are found to be legalizable
1090+
/// are placed in this set. Note that no actual rewrites are applied to the
1091+
/// IR during an analysis conversion and only pre-existing operations are
1092+
/// added to the set.
1093+
DenseSet<Operation *> *legalizableOps = nullptr;
1094+
};
1095+
10701096
//===----------------------------------------------------------------------===//
10711097
// Op Conversion Entry Points
10721098
//===----------------------------------------------------------------------===//
@@ -1080,52 +1106,44 @@ class PDLConversionConfig final {
10801106
/// Apply a partial conversion on the given operations and all nested
10811107
/// operations. This method converts as many operations to the target as
10821108
/// possible, ignoring operations that failed to legalize. This method only
1083-
/// returns failure if there ops explicitly marked as illegal. If an
1084-
/// `unconvertedOps` set is provided, all operations that are found not to be
1085-
/// legalizable to the given `target` are placed within that set. (Note that if
1086-
/// there is an op explicitly marked as illegal, the conversion terminates and
1087-
/// the `unconvertedOps` set will not necessarily be complete.)
1109+
/// returns failure if there ops explicitly marked as illegal.
10881110
LogicalResult
10891111
applyPartialConversion(ArrayRef<Operation *> ops,
10901112
const ConversionTarget &target,
10911113
const FrozenRewritePatternSet &patterns,
1092-
DenseSet<Operation *> *unconvertedOps = nullptr);
1114+
ConversionConfig config = ConversionConfig());
10931115
LogicalResult
10941116
applyPartialConversion(Operation *op, const ConversionTarget &target,
10951117
const FrozenRewritePatternSet &patterns,
1096-
DenseSet<Operation *> *unconvertedOps = nullptr);
1118+
ConversionConfig config = ConversionConfig());
10971119

10981120
/// Apply a complete conversion on the given operations, and all nested
10991121
/// operations. This method returns failure if the conversion of any operation
11001122
/// fails, or if there are unreachable blocks in any of the regions nested
11011123
/// within 'ops'.
11021124
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
11031125
const ConversionTarget &target,
1104-
const FrozenRewritePatternSet &patterns);
1126+
const FrozenRewritePatternSet &patterns,
1127+
ConversionConfig config = ConversionConfig());
11051128
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1106-
const FrozenRewritePatternSet &patterns);
1129+
const FrozenRewritePatternSet &patterns,
1130+
ConversionConfig config = ConversionConfig());
11071131

11081132
/// Apply an analysis conversion on the given operations, and all nested
11091133
/// operations. This method analyzes which operations would be successfully
11101134
/// converted to the target if a conversion was applied. All operations that
11111135
/// were found to be legalizable to the given 'target' are placed within the
1112-
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
1113-
/// operations on success and only pre-existing operations are added to the set.
1114-
/// This method only returns failure if there are unreachable blocks in any of
1115-
/// the regions nested within 'ops'. There's an additional argument
1116-
/// `notifyCallback` which is used for collecting match failure diagnostics
1117-
/// generated during the conversion. Diagnostics are only reported to this
1118-
/// callback may only be available in debug mode.
1119-
LogicalResult applyAnalysisConversion(
1120-
ArrayRef<Operation *> ops, ConversionTarget &target,
1121-
const FrozenRewritePatternSet &patterns,
1122-
DenseSet<Operation *> &convertedOps,
1123-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1124-
LogicalResult applyAnalysisConversion(
1125-
Operation *op, ConversionTarget &target,
1126-
const FrozenRewritePatternSet &patterns,
1127-
DenseSet<Operation *> &convertedOps,
1128-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1136+
/// provided 'config.legalizableOps' set; note that no actual rewrites are
1137+
/// applied to the operations on success. This method only returns failure if
1138+
/// there are unreachable blocks in any of the regions nested within 'ops'.
1139+
LogicalResult
1140+
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1141+
const FrozenRewritePatternSet &patterns,
1142+
ConversionConfig config = ConversionConfig());
1143+
LogicalResult
1144+
applyAnalysisConversion(Operation *op, ConversionTarget &target,
1145+
const FrozenRewritePatternSet &patterns,
1146+
ConversionConfig config = ConversionConfig());
11291147
} // namespace mlir
11301148

11311149
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ class IRRewrite {
230230
/// Erase the given block (unless it was already erased).
231231
void eraseBlock(Block *block);
232232

233+
const ConversionConfig &getConfig() const;
234+
233235
const Kind kind;
234236
ConversionPatternRewriterImpl &rewriterImpl;
235237
};
@@ -735,8 +737,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
735737
namespace mlir {
736738
namespace detail {
737739
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
738-
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
739-
: eraseRewriter(rewriter.getContext()) {}
740+
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
741+
const ConversionConfig &config)
742+
: eraseRewriter(ctx), config(config) {}
740743

741744
//===--------------------------------------------------------------------===//
742745
// State Management
@@ -936,14 +939,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
936939
/// converting the arguments of blocks within that region.
937940
DenseMap<Region *, const TypeConverter *> regionToConverter;
938941

939-
/// This allows the user to collect the match failure message.
940-
function_ref<void(Diagnostic &)> notifyCallback;
941-
942-
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
943-
/// this is populated with ops found to be legalizable to the target.
944-
/// When mode == OpConversionMode::Partial, this is populated with ops found
945-
/// *not* to be legalizable to the target.
946-
DenseSet<Operation *> *trackedOps = nullptr;
942+
/// Dialect conversion configuration.
943+
const ConversionConfig &config;
947944

948945
#ifndef NDEBUG
949946
/// A set of operations that have pending updates. This tracking isn't
@@ -966,6 +963,10 @@ void IRRewrite::eraseBlock(Block *block) {
966963
rewriterImpl.eraseRewriter.eraseBlock(block);
967964
}
968965

966+
const ConversionConfig &IRRewrite::getConfig() const {
967+
return rewriterImpl.config;
968+
}
969+
969970
void BlockTypeConversionRewrite::commit() {
970971
// Process the remapping for each of the original arguments.
971972
for (auto [origArg, info] :
@@ -1085,8 +1086,8 @@ void ReplaceOperationRewrite::commit() {
10851086
if (Value newValue =
10861087
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
10871088
result.replaceAllUsesWith(newValue);
1088-
if (rewriterImpl.trackedOps)
1089-
rewriterImpl.trackedOps->erase(op);
1089+
if (getConfig().unlegalizedOps)
1090+
getConfig().unlegalizedOps->erase(op);
10901091
// Do not erase the operation yet. It may still be referenced in `mapping`.
10911092
op->getBlock()->getOperations().remove(op);
10921093
}
@@ -1514,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15141515
Diagnostic diag(loc, DiagnosticSeverity::Remark);
15151516
reasonCallback(diag);
15161517
logger.startLine() << "** Failure : " << diag.str() << "\n";
1517-
if (notifyCallback)
1518-
notifyCallback(diag);
1518+
if (config.notifyCallback)
1519+
config.notifyCallback(diag);
15191520
});
15201521
}
15211522

15221523
//===----------------------------------------------------------------------===//
15231524
// ConversionPatternRewriter
15241525
//===----------------------------------------------------------------------===//
15251526

1526-
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1527+
ConversionPatternRewriter::ConversionPatternRewriter(
1528+
MLIRContext *ctx, const ConversionConfig &config)
15271529
: PatternRewriter(ctx),
1528-
impl(new detail::ConversionPatternRewriterImpl(*this)) {
1530+
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
15291531
setListener(impl.get());
15301532
}
15311533

@@ -1994,12 +1996,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
19941996
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
19951997
LLVM_DEBUG({
19961998
logFailure(rewriterImpl.logger, "pattern failed to match");
1997-
if (rewriterImpl.notifyCallback) {
1999+
if (rewriterImpl.config.notifyCallback) {
19982000
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
19992001
diag << "Failed to apply pattern \"" << pattern.getDebugName()
20002002
<< "\" on op:\n"
20012003
<< *op;
2002-
rewriterImpl.notifyCallback(diag);
2004+
rewriterImpl.config.notifyCallback(diag);
20032005
}
20042006
});
20052007
rewriterImpl.resetState(curState);
@@ -2387,14 +2389,12 @@ namespace mlir {
23872389
struct OperationConverter {
23882390
explicit OperationConverter(const ConversionTarget &target,
23892391
const FrozenRewritePatternSet &patterns,
2390-
OpConversionMode mode,
2391-
DenseSet<Operation *> *trackedOps = nullptr)
2392-
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
2392+
const ConversionConfig &config,
2393+
OpConversionMode mode)
2394+
: opLegalizer(target, patterns), config(config), mode(mode) {}
23932395

23942396
/// Converts the given operations to the conversion target.
2395-
LogicalResult
2396-
convertOperations(ArrayRef<Operation *> ops,
2397-
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
2397+
LogicalResult convertOperations(ArrayRef<Operation *> ops);
23982398

23992399
private:
24002400
/// Converts an operation with the given rewriter.
@@ -2431,14 +2431,11 @@ struct OperationConverter {
24312431
/// The legalizer to use when converting operations.
24322432
OperationLegalizer opLegalizer;
24332433

2434+
/// Dialect conversion configuration.
2435+
ConversionConfig config;
2436+
24342437
/// The conversion mode to use when legalizing operations.
24352438
OpConversionMode mode;
2436-
2437-
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2438-
/// this is populated with ops found to be legalizable to the target.
2439-
/// When mode == OpConversionMode::Partial, this is populated with ops found
2440-
/// *not* to be legalizable to the target.
2441-
DenseSet<Operation *> *trackedOps;
24422439
};
24432440
} // namespace mlir
24442441

@@ -2452,28 +2449,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
24522449
return op->emitError()
24532450
<< "failed to legalize operation '" << op->getName() << "'";
24542451
// Partial conversions allow conversions to fail iff the operation was not
2455-
// explicitly marked as illegal. If the user provided a nonlegalizableOps
2456-
// set, non-legalizable ops are included.
2452+
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
2453+
// set, non-legalizable ops are added to that set.
24572454
if (mode == OpConversionMode::Partial) {
24582455
if (opLegalizer.isIllegal(op))
24592456
return op->emitError()
24602457
<< "failed to legalize operation '" << op->getName()
24612458
<< "' that was explicitly marked illegal";
2462-
if (trackedOps)
2463-
trackedOps->insert(op);
2459+
if (config.unlegalizedOps)
2460+
config.unlegalizedOps->insert(op);
24642461
}
24652462
} else if (mode == OpConversionMode::Analysis) {
24662463
// Analysis conversions don't fail if any operations fail to legalize,
24672464
// they are only interested in the operations that were successfully
24682465
// legalized.
2469-
trackedOps->insert(op);
2466+
if (config.legalizableOps)
2467+
config.legalizableOps->insert(op);
24702468
}
24712469
return success();
24722470
}
24732471

2474-
LogicalResult OperationConverter::convertOperations(
2475-
ArrayRef<Operation *> ops,
2476-
function_ref<void(Diagnostic &)> notifyCallback) {
2472+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24772473
if (ops.empty())
24782474
return success();
24792475
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2494,10 +2490,8 @@ LogicalResult OperationConverter::convertOperations(
24942490
}
24952491

24962492
// Convert each operation and discard rewrites on failure.
2497-
ConversionPatternRewriter rewriter(ops.front()->getContext());
2493+
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
24982494
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2499-
rewriterImpl.notifyCallback = notifyCallback;
2500-
rewriterImpl.trackedOps = trackedOps;
25012495

25022496
for (auto *op : toConvert)
25032497
if (failed(convert(rewriter, op)))
@@ -3484,57 +3478,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34843478
//===----------------------------------------------------------------------===//
34853479
// Partial Conversion
34863480

3487-
LogicalResult
3488-
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
3489-
const ConversionTarget &target,
3490-
const FrozenRewritePatternSet &patterns,
3491-
DenseSet<Operation *> *unconvertedOps) {
3492-
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3493-
unconvertedOps);
3481+
LogicalResult mlir::applyPartialConversion(
3482+
ArrayRef<Operation *> ops, const ConversionTarget &target,
3483+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3484+
OperationConverter opConverter(target, patterns, config,
3485+
OpConversionMode::Partial);
34943486
return opConverter.convertOperations(ops);
34953487
}
34963488
LogicalResult
34973489
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
34983490
const FrozenRewritePatternSet &patterns,
3499-
DenseSet<Operation *> *unconvertedOps) {
3500-
return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3501-
unconvertedOps);
3491+
ConversionConfig config) {
3492+
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
35023493
}
35033494

35043495
//===----------------------------------------------------------------------===//
35053496
// Full Conversion
35063497

3507-
LogicalResult
3508-
mlir::applyFullConversion(ArrayRef<Operation *> ops,
3509-
const ConversionTarget &target,
3510-
const FrozenRewritePatternSet &patterns) {
3511-
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
3498+
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3499+
const ConversionTarget &target,
3500+
const FrozenRewritePatternSet &patterns,
3501+
ConversionConfig config) {
3502+
OperationConverter opConverter(target, patterns, config,
3503+
OpConversionMode::Full);
35123504
return opConverter.convertOperations(ops);
35133505
}
3514-
LogicalResult
3515-
mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
3516-
const FrozenRewritePatternSet &patterns) {
3517-
return applyFullConversion(llvm::ArrayRef(op), target, patterns);
3506+
LogicalResult mlir::applyFullConversion(Operation *op,
3507+
const ConversionTarget &target,
3508+
const FrozenRewritePatternSet &patterns,
3509+
ConversionConfig config) {
3510+
return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
35183511
}
35193512

35203513
//===----------------------------------------------------------------------===//
35213514
// Analysis Conversion
35223515

3523-
LogicalResult
3524-
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
3525-
ConversionTarget &target,
3526-
const FrozenRewritePatternSet &patterns,
3527-
DenseSet<Operation *> &convertedOps,
3528-
function_ref<void(Diagnostic &)> notifyCallback) {
3529-
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3530-
&convertedOps);
3531-
return opConverter.convertOperations(ops, notifyCallback);
3516+
LogicalResult mlir::applyAnalysisConversion(
3517+
ArrayRef<Operation *> ops, ConversionTarget &target,
3518+
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3519+
OperationConverter opConverter(target, patterns, config,
3520+
OpConversionMode::Analysis);
3521+
return opConverter.convertOperations(ops);
35323522
}
35333523
LogicalResult
35343524
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
35353525
const FrozenRewritePatternSet &patterns,
3536-
DenseSet<Operation *> &convertedOps,
3537-
function_ref<void(Diagnostic &)> notifyCallback) {
3538-
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3539-
convertedOps, notifyCallback);
3526+
ConversionConfig config) {
3527+
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
35403528
}

0 commit comments

Comments
 (0)