Skip to content

Commit 60fbd60

Browse files
authored
Revert "[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig (#83662)
This reverts commit 5f1319b. A FIR test is broken on Windows
1 parent 051e910 commit 60fbd60

File tree

3 files changed

+104
-116
lines changed

3 files changed

+104
-116
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ namespace mlir {
2424
// Forward declarations.
2525
class Attribute;
2626
class Block;
27-
struct ConversionConfig;
2827
class ConversionPatternRewriter;
2928
class MLIRContext;
3029
class Operation;
@@ -768,8 +767,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
768767
/// Conversion pattern rewriters must not be used outside of dialect
769768
/// conversions. They apply some IR rewrites in a delayed fashion and could
770769
/// bring the IR into an inconsistent state when used standalone.
771-
explicit ConversionPatternRewriter(MLIRContext *ctx,
772-
const ConversionConfig &config);
770+
explicit ConversionPatternRewriter(MLIRContext *ctx);
773771

774772
// Hide unsupported pattern rewriter API.
775773
using OpBuilder::setListener;
@@ -1069,30 +1067,6 @@ class PDLConversionConfig final {
10691067

10701068
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
10711069

1072-
//===----------------------------------------------------------------------===//
1073-
// ConversionConfig
1074-
//===----------------------------------------------------------------------===//
1075-
1076-
/// Dialect conversion configuration.
1077-
struct ConversionConfig {
1078-
/// An optional callback used to notify about match failure diagnostics during
1079-
/// the conversion. Diagnostics reported to this callback may only be
1080-
/// available in debug mode.
1081-
function_ref<void(Diagnostic &)> notifyCallback = nullptr;
1082-
1083-
/// Partial conversion only. All operations that are found not to be
1084-
/// legalizable are placed in this set. (Note that if there is an op
1085-
/// explicitly marked as illegal, the conversion terminates and the set will
1086-
/// not necessarily be complete.)
1087-
DenseSet<Operation *> *unlegalizedOps = nullptr;
1088-
1089-
/// Analysis conversion only. All operations that are found to be legalizable
1090-
/// are placed in this set. Note that no actual rewrites are applied to the
1091-
/// IR during an analysis conversion and only pre-existing operations are
1092-
/// added to the set.
1093-
DenseSet<Operation *> *legalizableOps = nullptr;
1094-
};
1095-
10961070
//===----------------------------------------------------------------------===//
10971071
// Op Conversion Entry Points
10981072
//===----------------------------------------------------------------------===//
@@ -1106,44 +1080,52 @@ struct ConversionConfig {
11061080
/// Apply a partial conversion on the given operations and all nested
11071081
/// operations. This method converts as many operations to the target as
11081082
/// possible, ignoring operations that failed to legalize. This method only
1109-
/// returns failure if there ops explicitly marked as illegal.
1083+
/// returns failure if there ops explicitly marked as illegal. If an
1084+
/// `unconvertedOps` set is provided, all operations that are found not to be
1085+
/// legalizable to the given `target` are placed within that set. (Note that if
1086+
/// there is an op explicitly marked as illegal, the conversion terminates and
1087+
/// the `unconvertedOps` set will not necessarily be complete.)
11101088
LogicalResult
11111089
applyPartialConversion(ArrayRef<Operation *> ops,
11121090
const ConversionTarget &target,
11131091
const FrozenRewritePatternSet &patterns,
1114-
ConversionConfig config = ConversionConfig());
1092+
DenseSet<Operation *> *unconvertedOps = nullptr);
11151093
LogicalResult
11161094
applyPartialConversion(Operation *op, const ConversionTarget &target,
11171095
const FrozenRewritePatternSet &patterns,
1118-
ConversionConfig config = ConversionConfig());
1096+
DenseSet<Operation *> *unconvertedOps = nullptr);
11191097

11201098
/// Apply a complete conversion on the given operations, and all nested
11211099
/// operations. This method returns failure if the conversion of any operation
11221100
/// fails, or if there are unreachable blocks in any of the regions nested
11231101
/// within 'ops'.
11241102
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
11251103
const ConversionTarget &target,
1126-
const FrozenRewritePatternSet &patterns,
1127-
ConversionConfig config = ConversionConfig());
1104+
const FrozenRewritePatternSet &patterns);
11281105
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1129-
const FrozenRewritePatternSet &patterns,
1130-
ConversionConfig config = ConversionConfig());
1106+
const FrozenRewritePatternSet &patterns);
11311107

11321108
/// Apply an analysis conversion on the given operations, and all nested
11331109
/// operations. This method analyzes which operations would be successfully
11341110
/// converted to the target if a conversion was applied. All operations that
11351111
/// were found to be legalizable to the given 'target' are placed within the
1136-
/// provided 'config.legalizableOps' set; note that no actual rewrites are
1137-
/// applied to the operations on success. This method only returns failure if
1138-
/// there are unreachable blocks in any of the regions nested within 'ops'.
1139-
LogicalResult
1140-
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1141-
const FrozenRewritePatternSet &patterns,
1142-
ConversionConfig config = ConversionConfig());
1143-
LogicalResult
1144-
applyAnalysisConversion(Operation *op, ConversionTarget &target,
1145-
const FrozenRewritePatternSet &patterns,
1146-
ConversionConfig config = ConversionConfig());
1112+
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
1113+
/// operations on success and only pre-existing operations are added to the set.
1114+
/// This method only returns failure if there are unreachable blocks in any of
1115+
/// the regions nested within 'ops'. There's an additional argument
1116+
/// `notifyCallback` which is used for collecting match failure diagnostics
1117+
/// generated during the conversion. Diagnostics are only reported to this
1118+
/// callback may only be available in debug mode.
1119+
LogicalResult applyAnalysisConversion(
1120+
ArrayRef<Operation *> ops, ConversionTarget &target,
1121+
const FrozenRewritePatternSet &patterns,
1122+
DenseSet<Operation *> &convertedOps,
1123+
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
1124+
LogicalResult applyAnalysisConversion(
1125+
Operation *op, ConversionTarget &target,
1126+
const FrozenRewritePatternSet &patterns,
1127+
DenseSet<Operation *> &convertedOps,
1128+
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
11471129
} // namespace mlir
11481130

11491131
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,6 @@ class IRRewrite {
230230
/// Erase the given block (unless it was already erased).
231231
void eraseBlock(Block *block);
232232

233-
const ConversionConfig &getConfig() const;
234-
235233
const Kind kind;
236234
ConversionPatternRewriterImpl &rewriterImpl;
237235
};
@@ -734,9 +732,8 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
734732
namespace mlir {
735733
namespace detail {
736734
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
737-
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
738-
const ConversionConfig &config)
739-
: eraseRewriter(ctx), config(config) {}
735+
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
736+
: eraseRewriter(rewriter.getContext()) {}
740737

741738
//===--------------------------------------------------------------------===//
742739
// State Management
@@ -936,8 +933,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
936933
/// converting the arguments of blocks within that region.
937934
DenseMap<Region *, const TypeConverter *> regionToConverter;
938935

939-
/// Dialect conversion configuration.
940-
const ConversionConfig &config;
936+
/// This allows the user to collect the match failure message.
937+
function_ref<void(Diagnostic &)> notifyCallback;
938+
939+
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
940+
/// this is populated with ops found to be legalizable to the target.
941+
/// When mode == OpConversionMode::Partial, this is populated with ops found
942+
/// *not* to be legalizable to the target.
943+
DenseSet<Operation *> *trackedOps = nullptr;
941944

942945
#ifndef NDEBUG
943946
/// A set of operations that have pending updates. This tracking isn't
@@ -960,10 +963,6 @@ void IRRewrite::eraseBlock(Block *block) {
960963
rewriterImpl.eraseRewriter.eraseBlock(block);
961964
}
962965

963-
const ConversionConfig &IRRewrite::getConfig() const {
964-
return rewriterImpl.config;
965-
}
966-
967966
void BlockTypeConversionRewrite::commit() {
968967
// Process the remapping for each of the original arguments.
969968
for (auto [origArg, info] :
@@ -1081,8 +1080,8 @@ void ReplaceOperationRewrite::commit() {
10811080
if (Value newValue =
10821081
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
10831082
result.replaceAllUsesWith(newValue);
1084-
if (getConfig().unlegalizedOps)
1085-
getConfig().unlegalizedOps->erase(op);
1083+
if (rewriterImpl.trackedOps)
1084+
rewriterImpl.trackedOps->erase(op);
10861085
// Do not erase the operation yet. It may still be referenced in `mapping`.
10871086
op->getBlock()->getOperations().remove(op);
10881087
}
@@ -1505,19 +1504,18 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
15051504
Diagnostic diag(loc, DiagnosticSeverity::Remark);
15061505
reasonCallback(diag);
15071506
logger.startLine() << "** Failure : " << diag.str() << "\n";
1508-
if (config.notifyCallback)
1509-
config.notifyCallback(diag);
1507+
if (notifyCallback)
1508+
notifyCallback(diag);
15101509
});
15111510
}
15121511

15131512
//===----------------------------------------------------------------------===//
15141513
// ConversionPatternRewriter
15151514
//===----------------------------------------------------------------------===//
15161515

1517-
ConversionPatternRewriter::ConversionPatternRewriter(
1518-
MLIRContext *ctx, const ConversionConfig &config)
1516+
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
15191517
: PatternRewriter(ctx),
1520-
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
1518+
impl(new detail::ConversionPatternRewriterImpl(*this)) {
15211519
setListener(impl.get());
15221520
}
15231521

@@ -1986,12 +1984,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
19861984
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
19871985
LLVM_DEBUG({
19881986
logFailure(rewriterImpl.logger, "pattern failed to match");
1989-
if (rewriterImpl.config.notifyCallback) {
1987+
if (rewriterImpl.notifyCallback) {
19901988
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
19911989
diag << "Failed to apply pattern \"" << pattern.getDebugName()
19921990
<< "\" on op:\n"
19931991
<< *op;
1994-
rewriterImpl.config.notifyCallback(diag);
1992+
rewriterImpl.notifyCallback(diag);
19951993
}
19961994
});
19971995
rewriterImpl.resetState(curState);
@@ -2379,12 +2377,14 @@ namespace mlir {
23792377
struct OperationConverter {
23802378
explicit OperationConverter(const ConversionTarget &target,
23812379
const FrozenRewritePatternSet &patterns,
2382-
const ConversionConfig &config,
2383-
OpConversionMode mode)
2384-
: opLegalizer(target, patterns), config(config), mode(mode) {}
2380+
OpConversionMode mode,
2381+
DenseSet<Operation *> *trackedOps = nullptr)
2382+
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
23852383

23862384
/// Converts the given operations to the conversion target.
2387-
LogicalResult convertOperations(ArrayRef<Operation *> ops);
2385+
LogicalResult
2386+
convertOperations(ArrayRef<Operation *> ops,
2387+
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
23882388

23892389
private:
23902390
/// Converts an operation with the given rewriter.
@@ -2421,11 +2421,14 @@ struct OperationConverter {
24212421
/// The legalizer to use when converting operations.
24222422
OperationLegalizer opLegalizer;
24232423

2424-
/// Dialect conversion configuration.
2425-
ConversionConfig config;
2426-
24272424
/// The conversion mode to use when legalizing operations.
24282425
OpConversionMode mode;
2426+
2427+
/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2428+
/// this is populated with ops found to be legalizable to the target.
2429+
/// When mode == OpConversionMode::Partial, this is populated with ops found
2430+
/// *not* to be legalizable to the target.
2431+
DenseSet<Operation *> *trackedOps;
24292432
};
24302433
} // namespace mlir
24312434

@@ -2439,27 +2442,28 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
24392442
return op->emitError()
24402443
<< "failed to legalize operation '" << op->getName() << "'";
24412444
// Partial conversions allow conversions to fail iff the operation was not
2442-
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
2443-
// set, non-legalizable ops are added to that set.
2445+
// explicitly marked as illegal. If the user provided a nonlegalizableOps
2446+
// set, non-legalizable ops are included.
24442447
if (mode == OpConversionMode::Partial) {
24452448
if (opLegalizer.isIllegal(op))
24462449
return op->emitError()
24472450
<< "failed to legalize operation '" << op->getName()
24482451
<< "' that was explicitly marked illegal";
2449-
if (config.unlegalizedOps)
2450-
config.unlegalizedOps->insert(op);
2452+
if (trackedOps)
2453+
trackedOps->insert(op);
24512454
}
24522455
} else if (mode == OpConversionMode::Analysis) {
24532456
// Analysis conversions don't fail if any operations fail to legalize,
24542457
// they are only interested in the operations that were successfully
24552458
// legalized.
2456-
if (config.legalizableOps)
2457-
config.legalizableOps->insert(op);
2459+
trackedOps->insert(op);
24582460
}
24592461
return success();
24602462
}
24612463

2462-
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2464+
LogicalResult OperationConverter::convertOperations(
2465+
ArrayRef<Operation *> ops,
2466+
function_ref<void(Diagnostic &)> notifyCallback) {
24632467
if (ops.empty())
24642468
return success();
24652469
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2480,8 +2484,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24802484
}
24812485

24822486
// Convert each operation and discard rewrites on failure.
2483-
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
2487+
ConversionPatternRewriter rewriter(ops.front()->getContext());
24842488
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2489+
rewriterImpl.notifyCallback = notifyCallback;
2490+
rewriterImpl.trackedOps = trackedOps;
24852491

24862492
for (auto *op : toConvert)
24872493
if (failed(convert(rewriter, op)))
@@ -3468,51 +3474,57 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
34683474
//===----------------------------------------------------------------------===//
34693475
// Partial Conversion
34703476

3471-
LogicalResult mlir::applyPartialConversion(
3472-
ArrayRef<Operation *> ops, const ConversionTarget &target,
3473-
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3474-
OperationConverter opConverter(target, patterns, config,
3475-
OpConversionMode::Partial);
3477+
LogicalResult
3478+
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
3479+
const ConversionTarget &target,
3480+
const FrozenRewritePatternSet &patterns,
3481+
DenseSet<Operation *> *unconvertedOps) {
3482+
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
3483+
unconvertedOps);
34763484
return opConverter.convertOperations(ops);
34773485
}
34783486
LogicalResult
34793487
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
34803488
const FrozenRewritePatternSet &patterns,
3481-
ConversionConfig config) {
3482-
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
3489+
DenseSet<Operation *> *unconvertedOps) {
3490+
return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
3491+
unconvertedOps);
34833492
}
34843493

34853494
//===----------------------------------------------------------------------===//
34863495
// Full Conversion
34873496

3488-
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
3489-
const ConversionTarget &target,
3490-
const FrozenRewritePatternSet &patterns,
3491-
ConversionConfig config) {
3492-
OperationConverter opConverter(target, patterns, config,
3493-
OpConversionMode::Full);
3497+
LogicalResult
3498+
mlir::applyFullConversion(ArrayRef<Operation *> ops,
3499+
const ConversionTarget &target,
3500+
const FrozenRewritePatternSet &patterns) {
3501+
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
34943502
return opConverter.convertOperations(ops);
34953503
}
3496-
LogicalResult mlir::applyFullConversion(Operation *op,
3497-
const ConversionTarget &target,
3498-
const FrozenRewritePatternSet &patterns,
3499-
ConversionConfig config) {
3500-
return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
3504+
LogicalResult
3505+
mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
3506+
const FrozenRewritePatternSet &patterns) {
3507+
return applyFullConversion(llvm::ArrayRef(op), target, patterns);
35013508
}
35023509

35033510
//===----------------------------------------------------------------------===//
35043511
// Analysis Conversion
35053512

3506-
LogicalResult mlir::applyAnalysisConversion(
3507-
ArrayRef<Operation *> ops, ConversionTarget &target,
3508-
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3509-
OperationConverter opConverter(target, patterns, config,
3510-
OpConversionMode::Analysis);
3511-
return opConverter.convertOperations(ops);
3513+
LogicalResult
3514+
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
3515+
ConversionTarget &target,
3516+
const FrozenRewritePatternSet &patterns,
3517+
DenseSet<Operation *> &convertedOps,
3518+
function_ref<void(Diagnostic &)> notifyCallback) {
3519+
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
3520+
&convertedOps);
3521+
return opConverter.convertOperations(ops, notifyCallback);
35123522
}
35133523
LogicalResult
35143524
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
35153525
const FrozenRewritePatternSet &patterns,
3516-
ConversionConfig config) {
3517-
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
3526+
DenseSet<Operation *> &convertedOps,
3527+
function_ref<void(Diagnostic &)> notifyCallback) {
3528+
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
3529+
convertedOps, notifyCallback);
35183530
}

0 commit comments

Comments
 (0)