@@ -230,8 +230,6 @@ class IRRewrite {
230
230
// / Erase the given block (unless it was already erased).
231
231
void eraseBlock (Block *block);
232
232
233
- const ConversionConfig &getConfig () const ;
234
-
235
233
const Kind kind;
236
234
ConversionPatternRewriterImpl &rewriterImpl;
237
235
};
@@ -734,9 +732,8 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
734
732
namespace mlir {
735
733
namespace detail {
736
734
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
737
- explicit ConversionPatternRewriterImpl (MLIRContext *ctx,
738
- const ConversionConfig &config)
739
- : eraseRewriter(ctx), config(config) {}
735
+ explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter)
736
+ : eraseRewriter(rewriter.getContext()) {}
740
737
741
738
// ===--------------------------------------------------------------------===//
742
739
// State Management
@@ -936,8 +933,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
936
933
// / converting the arguments of blocks within that region.
937
934
DenseMap<Region *, const TypeConverter *> regionToConverter;
938
935
939
- // / Dialect conversion configuration.
940
- const ConversionConfig &config;
936
+ // / This allows the user to collect the match failure message.
937
+ function_ref<void (Diagnostic &)> notifyCallback;
938
+
939
+ // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
940
+ // / this is populated with ops found to be legalizable to the target.
941
+ // / When mode == OpConversionMode::Partial, this is populated with ops found
942
+ // / *not* to be legalizable to the target.
943
+ DenseSet<Operation *> *trackedOps = nullptr ;
941
944
942
945
#ifndef NDEBUG
943
946
// / A set of operations that have pending updates. This tracking isn't
@@ -960,10 +963,6 @@ void IRRewrite::eraseBlock(Block *block) {
960
963
rewriterImpl.eraseRewriter .eraseBlock (block);
961
964
}
962
965
963
- const ConversionConfig &IRRewrite::getConfig () const {
964
- return rewriterImpl.config ;
965
- }
966
-
967
966
void BlockTypeConversionRewrite::commit () {
968
967
// Process the remapping for each of the original arguments.
969
968
for (auto [origArg, info] :
@@ -1081,8 +1080,8 @@ void ReplaceOperationRewrite::commit() {
1081
1080
if (Value newValue =
1082
1081
rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
1083
1082
result.replaceAllUsesWith (newValue);
1084
- if (getConfig (). unlegalizedOps )
1085
- getConfig (). unlegalizedOps ->erase (op);
1083
+ if (rewriterImpl. trackedOps )
1084
+ rewriterImpl. trackedOps ->erase (op);
1086
1085
// Do not erase the operation yet. It may still be referenced in `mapping`.
1087
1086
op->getBlock ()->getOperations ().remove (op);
1088
1087
}
@@ -1505,19 +1504,18 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1505
1504
Diagnostic diag (loc, DiagnosticSeverity::Remark);
1506
1505
reasonCallback (diag);
1507
1506
logger.startLine () << " ** Failure : " << diag.str () << " \n " ;
1508
- if (config. notifyCallback )
1509
- config. notifyCallback (diag);
1507
+ if (notifyCallback)
1508
+ notifyCallback (diag);
1510
1509
});
1511
1510
}
1512
1511
1513
1512
// ===----------------------------------------------------------------------===//
1514
1513
// ConversionPatternRewriter
1515
1514
// ===----------------------------------------------------------------------===//
1516
1515
1517
- ConversionPatternRewriter::ConversionPatternRewriter (
1518
- MLIRContext *ctx, const ConversionConfig &config)
1516
+ ConversionPatternRewriter::ConversionPatternRewriter (MLIRContext *ctx)
1519
1517
: PatternRewriter(ctx),
1520
- impl(new detail::ConversionPatternRewriterImpl(ctx, config )) {
1518
+ impl(new detail::ConversionPatternRewriterImpl(* this )) {
1521
1519
setListener (impl.get ());
1522
1520
}
1523
1521
@@ -1986,12 +1984,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
1986
1984
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
1987
1985
LLVM_DEBUG ({
1988
1986
logFailure (rewriterImpl.logger , " pattern failed to match" );
1989
- if (rewriterImpl.config . notifyCallback ) {
1987
+ if (rewriterImpl.notifyCallback ) {
1990
1988
Diagnostic diag (op->getLoc (), DiagnosticSeverity::Remark);
1991
1989
diag << " Failed to apply pattern \" " << pattern.getDebugName ()
1992
1990
<< " \" on op:\n "
1993
1991
<< *op;
1994
- rewriterImpl.config . notifyCallback (diag);
1992
+ rewriterImpl.notifyCallback (diag);
1995
1993
}
1996
1994
});
1997
1995
rewriterImpl.resetState (curState);
@@ -2379,12 +2377,14 @@ namespace mlir {
2379
2377
struct OperationConverter {
2380
2378
explicit OperationConverter (const ConversionTarget &target,
2381
2379
const FrozenRewritePatternSet &patterns,
2382
- const ConversionConfig &config ,
2383
- OpConversionMode mode )
2384
- : opLegalizer(target, patterns), config(config ), mode(mode ) {}
2380
+ OpConversionMode mode ,
2381
+ DenseSet<Operation *> *trackedOps = nullptr )
2382
+ : opLegalizer(target, patterns), mode(mode ), trackedOps(trackedOps ) {}
2385
2383
2386
2384
// / Converts the given operations to the conversion target.
2387
- LogicalResult convertOperations (ArrayRef<Operation *> ops);
2385
+ LogicalResult
2386
+ convertOperations (ArrayRef<Operation *> ops,
2387
+ function_ref<void (Diagnostic &)> notifyCallback = nullptr );
2388
2388
2389
2389
private:
2390
2390
// / Converts an operation with the given rewriter.
@@ -2421,11 +2421,14 @@ struct OperationConverter {
2421
2421
// / The legalizer to use when converting operations.
2422
2422
OperationLegalizer opLegalizer;
2423
2423
2424
- // / Dialect conversion configuration.
2425
- ConversionConfig config;
2426
-
2427
2424
// / The conversion mode to use when legalizing operations.
2428
2425
OpConversionMode mode;
2426
+
2427
+ // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2428
+ // / this is populated with ops found to be legalizable to the target.
2429
+ // / When mode == OpConversionMode::Partial, this is populated with ops found
2430
+ // / *not* to be legalizable to the target.
2431
+ DenseSet<Operation *> *trackedOps;
2429
2432
};
2430
2433
} // namespace mlir
2431
2434
@@ -2439,27 +2442,28 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2439
2442
return op->emitError ()
2440
2443
<< " failed to legalize operation '" << op->getName () << " '" ;
2441
2444
// Partial conversions allow conversions to fail iff the operation was not
2442
- // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2443
- // set, non-legalizable ops are added to that set .
2445
+ // explicitly marked as illegal. If the user provided a nonlegalizableOps
2446
+ // set, non-legalizable ops are included .
2444
2447
if (mode == OpConversionMode::Partial) {
2445
2448
if (opLegalizer.isIllegal (op))
2446
2449
return op->emitError ()
2447
2450
<< " failed to legalize operation '" << op->getName ()
2448
2451
<< " ' that was explicitly marked illegal" ;
2449
- if (config. unlegalizedOps )
2450
- config. unlegalizedOps ->insert (op);
2452
+ if (trackedOps )
2453
+ trackedOps ->insert (op);
2451
2454
}
2452
2455
} else if (mode == OpConversionMode::Analysis) {
2453
2456
// Analysis conversions don't fail if any operations fail to legalize,
2454
2457
// they are only interested in the operations that were successfully
2455
2458
// legalized.
2456
- if (config.legalizableOps )
2457
- config.legalizableOps ->insert (op);
2459
+ trackedOps->insert (op);
2458
2460
}
2459
2461
return success ();
2460
2462
}
2461
2463
2462
- LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2464
+ LogicalResult OperationConverter::convertOperations (
2465
+ ArrayRef<Operation *> ops,
2466
+ function_ref<void (Diagnostic &)> notifyCallback) {
2463
2467
if (ops.empty ())
2464
2468
return success ();
2465
2469
const ConversionTarget &target = opLegalizer.getTarget ();
@@ -2480,8 +2484,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2480
2484
}
2481
2485
2482
2486
// Convert each operation and discard rewrites on failure.
2483
- ConversionPatternRewriter rewriter (ops.front ()->getContext (), config );
2487
+ ConversionPatternRewriter rewriter (ops.front ()->getContext ());
2484
2488
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2489
+ rewriterImpl.notifyCallback = notifyCallback;
2490
+ rewriterImpl.trackedOps = trackedOps;
2485
2491
2486
2492
for (auto *op : toConvert)
2487
2493
if (failed (convert (rewriter, op)))
@@ -3468,51 +3474,57 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3468
3474
// ===----------------------------------------------------------------------===//
3469
3475
// Partial Conversion
3470
3476
3471
- LogicalResult mlir::applyPartialConversion (
3472
- ArrayRef<Operation *> ops, const ConversionTarget &target,
3473
- const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3474
- OperationConverter opConverter (target, patterns, config,
3475
- OpConversionMode::Partial);
3477
+ LogicalResult
3478
+ mlir::applyPartialConversion (ArrayRef<Operation *> ops,
3479
+ const ConversionTarget &target,
3480
+ const FrozenRewritePatternSet &patterns,
3481
+ DenseSet<Operation *> *unconvertedOps) {
3482
+ OperationConverter opConverter (target, patterns, OpConversionMode::Partial,
3483
+ unconvertedOps);
3476
3484
return opConverter.convertOperations (ops);
3477
3485
}
3478
3486
LogicalResult
3479
3487
mlir::applyPartialConversion (Operation *op, const ConversionTarget &target,
3480
3488
const FrozenRewritePatternSet &patterns,
3481
- ConversionConfig config) {
3482
- return applyPartialConversion (llvm::ArrayRef (op), target, patterns, config);
3489
+ DenseSet<Operation *> *unconvertedOps) {
3490
+ return applyPartialConversion (llvm::ArrayRef (op), target, patterns,
3491
+ unconvertedOps);
3483
3492
}
3484
3493
3485
3494
// ===----------------------------------------------------------------------===//
3486
3495
// Full Conversion
3487
3496
3488
- LogicalResult mlir::applyFullConversion (ArrayRef<Operation *> ops,
3489
- const ConversionTarget &target,
3490
- const FrozenRewritePatternSet &patterns,
3491
- ConversionConfig config) {
3492
- OperationConverter opConverter (target, patterns, config,
3493
- OpConversionMode::Full);
3497
+ LogicalResult
3498
+ mlir::applyFullConversion (ArrayRef<Operation *> ops,
3499
+ const ConversionTarget &target,
3500
+ const FrozenRewritePatternSet &patterns) {
3501
+ OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3494
3502
return opConverter.convertOperations (ops);
3495
3503
}
3496
- LogicalResult mlir::applyFullConversion (Operation *op,
3497
- const ConversionTarget &target,
3498
- const FrozenRewritePatternSet &patterns,
3499
- ConversionConfig config) {
3500
- return applyFullConversion (llvm::ArrayRef (op), target, patterns, config);
3504
+ LogicalResult
3505
+ mlir::applyFullConversion (Operation *op, const ConversionTarget &target,
3506
+ const FrozenRewritePatternSet &patterns) {
3507
+ return applyFullConversion (llvm::ArrayRef (op), target, patterns);
3501
3508
}
3502
3509
3503
3510
// ===----------------------------------------------------------------------===//
3504
3511
// Analysis Conversion
3505
3512
3506
- LogicalResult mlir::applyAnalysisConversion (
3507
- ArrayRef<Operation *> ops, ConversionTarget &target,
3508
- const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3509
- OperationConverter opConverter (target, patterns, config,
3510
- OpConversionMode::Analysis);
3511
- return opConverter.convertOperations (ops);
3513
+ LogicalResult
3514
+ mlir::applyAnalysisConversion (ArrayRef<Operation *> ops,
3515
+ ConversionTarget &target,
3516
+ const FrozenRewritePatternSet &patterns,
3517
+ DenseSet<Operation *> &convertedOps,
3518
+ function_ref<void (Diagnostic &)> notifyCallback) {
3519
+ OperationConverter opConverter (target, patterns, OpConversionMode::Analysis,
3520
+ &convertedOps);
3521
+ return opConverter.convertOperations (ops, notifyCallback);
3512
3522
}
3513
3523
LogicalResult
3514
3524
mlir::applyAnalysisConversion (Operation *op, ConversionTarget &target,
3515
3525
const FrozenRewritePatternSet &patterns,
3516
- ConversionConfig config) {
3517
- return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns, config);
3526
+ DenseSet<Operation *> &convertedOps,
3527
+ function_ref<void (Diagnostic &)> notifyCallback) {
3528
+ return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns,
3529
+ convertedOps, notifyCallback);
3518
3530
}
0 commit comments