@@ -228,6 +228,8 @@ class IRRewrite {
228
228
// / Erase the given block (unless it was already erased).
229
229
void eraseBlock (Block *block);
230
230
231
+ const ConversionConfig &getConfig () const ;
232
+
231
233
const Kind kind;
232
234
ConversionPatternRewriterImpl &rewriterImpl;
233
235
};
@@ -754,9 +756,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
754
756
namespace mlir {
755
757
namespace detail {
756
758
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
757
- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter)
759
+ explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter,
760
+ const ConversionConfig &config)
758
761
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
759
- notifyCallback( nullptr ) {}
762
+ config(config ) {}
760
763
761
764
// ===--------------------------------------------------------------------===//
762
765
// State Management
@@ -962,14 +965,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
962
965
// / converting the arguments of blocks within that region.
963
966
DenseMap<Region *, const TypeConverter *> regionToConverter;
964
967
965
- // / This allows the user to collect the match failure message.
966
- function_ref<void (Diagnostic &)> notifyCallback;
967
-
968
- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
969
- // / this is populated with ops found to be legalizable to the target.
970
- // / When mode == OpConversionMode::Partial, this is populated with ops found
971
- // / *not* to be legalizable to the target.
972
- DenseSet<Operation *> *trackedOps = nullptr ;
968
+ // / Dialect conversion configuration.
969
+ const ConversionConfig &config;
973
970
974
971
#ifndef NDEBUG
975
972
// / A set of operations that have pending updates. This tracking isn't
@@ -992,6 +989,10 @@ void IRRewrite::eraseBlock(Block *block) {
992
989
rewriterImpl.eraseRewriter .eraseBlock (block);
993
990
}
994
991
992
+ const ConversionConfig &IRRewrite::getConfig () const {
993
+ return rewriterImpl.config ;
994
+ }
995
+
995
996
void BlockTypeConversionRewrite::commit () {
996
997
// Process the remapping for each of the original arguments.
997
998
for (auto [origArg, info] :
@@ -1107,8 +1108,8 @@ void ReplaceOperationRewrite::commit() {
1107
1108
if (Value newValue =
1108
1109
rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
1109
1110
result.replaceAllUsesWith (newValue);
1110
- if (rewriterImpl. trackedOps )
1111
- rewriterImpl. trackedOps ->erase (op);
1111
+ if (getConfig (). unlegalizedOps )
1112
+ getConfig (). unlegalizedOps ->erase (op);
1112
1113
// Do not erase the operation yet. It may still be referenced in `mapping`.
1113
1114
op->getBlock ()->getOperations ().remove (op);
1114
1115
}
@@ -1543,18 +1544,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1543
1544
Diagnostic diag (loc, DiagnosticSeverity::Remark);
1544
1545
reasonCallback (diag);
1545
1546
logger.startLine () << " ** Failure : " << diag.str () << " \n " ;
1546
- if (notifyCallback)
1547
- notifyCallback (diag);
1547
+ if (config. notifyCallback )
1548
+ config. notifyCallback (diag);
1548
1549
});
1549
1550
}
1550
1551
1551
1552
// ===----------------------------------------------------------------------===//
1552
1553
// ConversionPatternRewriter
1553
1554
// ===----------------------------------------------------------------------===//
1554
1555
1555
- ConversionPatternRewriter::ConversionPatternRewriter (MLIRContext *ctx)
1556
+ ConversionPatternRewriter::ConversionPatternRewriter (
1557
+ MLIRContext *ctx, const ConversionConfig &config)
1556
1558
: PatternRewriter(ctx),
1557
- impl(new detail::ConversionPatternRewriterImpl(*this )) {
1559
+ impl(new detail::ConversionPatternRewriterImpl(*this , config )) {
1558
1560
setListener (impl.get ());
1559
1561
}
1560
1562
@@ -2005,12 +2007,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2005
2007
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2006
2008
LLVM_DEBUG ({
2007
2009
logFailure (rewriterImpl.logger , " pattern failed to match" );
2008
- if (rewriterImpl.notifyCallback ) {
2010
+ if (rewriterImpl.config . notifyCallback ) {
2009
2011
Diagnostic diag (op->getLoc (), DiagnosticSeverity::Remark);
2010
2012
diag << " Failed to apply pattern \" " << pattern.getDebugName ()
2011
2013
<< " \" on op:\n "
2012
2014
<< *op;
2013
- rewriterImpl.notifyCallback (diag);
2015
+ rewriterImpl.config . notifyCallback (diag);
2014
2016
}
2015
2017
});
2016
2018
rewriterImpl.resetState (curState);
@@ -2398,14 +2400,12 @@ namespace mlir {
2398
2400
struct OperationConverter {
2399
2401
explicit OperationConverter (const ConversionTarget &target,
2400
2402
const FrozenRewritePatternSet &patterns,
2401
- OpConversionMode mode ,
2402
- DenseSet<Operation *> *trackedOps = nullptr )
2403
- : opLegalizer(target, patterns), mode(mode ), trackedOps(trackedOps ) {}
2403
+ const ConversionConfig &config ,
2404
+ OpConversionMode mode )
2405
+ : opLegalizer(target, patterns), config(config ), mode(mode ) {}
2404
2406
2405
2407
// / Converts the given operations to the conversion target.
2406
- LogicalResult
2407
- convertOperations (ArrayRef<Operation *> ops,
2408
- function_ref<void (Diagnostic &)> notifyCallback = nullptr );
2408
+ LogicalResult convertOperations (ArrayRef<Operation *> ops);
2409
2409
2410
2410
private:
2411
2411
// / Converts an operation with the given rewriter.
@@ -2442,14 +2442,11 @@ struct OperationConverter {
2442
2442
// / The legalizer to use when converting operations.
2443
2443
OperationLegalizer opLegalizer;
2444
2444
2445
+ // / Dialect conversion configuration.
2446
+ ConversionConfig config;
2447
+
2445
2448
// / The conversion mode to use when legalizing operations.
2446
2449
OpConversionMode mode;
2447
-
2448
- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2449
- // / this is populated with ops found to be legalizable to the target.
2450
- // / When mode == OpConversionMode::Partial, this is populated with ops found
2451
- // / *not* to be legalizable to the target.
2452
- DenseSet<Operation *> *trackedOps;
2453
2450
};
2454
2451
} // namespace mlir
2455
2452
@@ -2463,28 +2460,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2463
2460
return op->emitError ()
2464
2461
<< " failed to legalize operation '" << op->getName () << " '" ;
2465
2462
// Partial conversions allow conversions to fail iff the operation was not
2466
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
2467
- // set, non-legalizable ops are included .
2463
+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2464
+ // set, non-legalizable ops are added to that set .
2468
2465
if (mode == OpConversionMode::Partial) {
2469
2466
if (opLegalizer.isIllegal (op))
2470
2467
return op->emitError ()
2471
2468
<< " failed to legalize operation '" << op->getName ()
2472
2469
<< " ' that was explicitly marked illegal" ;
2473
- if (trackedOps )
2474
- trackedOps ->insert (op);
2470
+ if (config. unlegalizedOps )
2471
+ config. unlegalizedOps ->insert (op);
2475
2472
}
2476
2473
} else if (mode == OpConversionMode::Analysis) {
2477
2474
// Analysis conversions don't fail if any operations fail to legalize,
2478
2475
// they are only interested in the operations that were successfully
2479
2476
// legalized.
2480
- trackedOps->insert (op);
2477
+ if (config.legalizableOps )
2478
+ config.legalizableOps ->insert (op);
2481
2479
}
2482
2480
return success ();
2483
2481
}
2484
2482
2485
- LogicalResult OperationConverter::convertOperations (
2486
- ArrayRef<Operation *> ops,
2487
- function_ref<void (Diagnostic &)> notifyCallback) {
2483
+ LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2488
2484
if (ops.empty ())
2489
2485
return success ();
2490
2486
const ConversionTarget &target = opLegalizer.getTarget ();
@@ -2505,10 +2501,8 @@ LogicalResult OperationConverter::convertOperations(
2505
2501
}
2506
2502
2507
2503
// Convert each operation and discard rewrites on failure.
2508
- ConversionPatternRewriter rewriter (ops.front ()->getContext ());
2504
+ ConversionPatternRewriter rewriter (ops.front ()->getContext (), config );
2509
2505
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2510
- rewriterImpl.notifyCallback = notifyCallback;
2511
- rewriterImpl.trackedOps = trackedOps;
2512
2506
2513
2507
for (auto *op : toConvert)
2514
2508
if (failed (convert (rewriter, op)))
@@ -3495,57 +3489,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3495
3489
// ===----------------------------------------------------------------------===//
3496
3490
// Partial Conversion
3497
3491
3498
- LogicalResult
3499
- mlir::applyPartialConversion (ArrayRef<Operation *> ops,
3500
- const ConversionTarget &target,
3501
- const FrozenRewritePatternSet &patterns,
3502
- DenseSet<Operation *> *unconvertedOps) {
3503
- OperationConverter opConverter (target, patterns, OpConversionMode::Partial,
3504
- unconvertedOps);
3492
+ LogicalResult mlir::applyPartialConversion (
3493
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
3494
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3495
+ OperationConverter opConverter (target, patterns, config,
3496
+ OpConversionMode::Partial);
3505
3497
return opConverter.convertOperations (ops);
3506
3498
}
3507
3499
LogicalResult
3508
3500
mlir::applyPartialConversion (Operation *op, const ConversionTarget &target,
3509
3501
const FrozenRewritePatternSet &patterns,
3510
- DenseSet<Operation *> *unconvertedOps) {
3511
- return applyPartialConversion (llvm::ArrayRef (op), target, patterns,
3512
- unconvertedOps);
3502
+ ConversionConfig config) {
3503
+ return applyPartialConversion (llvm::ArrayRef (op), target, patterns, config);
3513
3504
}
3514
3505
3515
3506
// ===----------------------------------------------------------------------===//
3516
3507
// Full Conversion
3517
3508
3518
- LogicalResult
3519
- mlir::applyFullConversion (ArrayRef<Operation *> ops,
3520
- const ConversionTarget &target,
3521
- const FrozenRewritePatternSet &patterns) {
3522
- OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3509
+ LogicalResult mlir::applyFullConversion (ArrayRef<Operation *> ops,
3510
+ const ConversionTarget &target,
3511
+ const FrozenRewritePatternSet &patterns,
3512
+ ConversionConfig config) {
3513
+ OperationConverter opConverter (target, patterns, config,
3514
+ OpConversionMode::Full);
3523
3515
return opConverter.convertOperations (ops);
3524
3516
}
3525
- LogicalResult
3526
- mlir::applyFullConversion (Operation *op, const ConversionTarget &target,
3527
- const FrozenRewritePatternSet &patterns) {
3528
- return applyFullConversion (llvm::ArrayRef (op), target, patterns);
3517
+ LogicalResult mlir::applyFullConversion (Operation *op,
3518
+ const ConversionTarget &target,
3519
+ const FrozenRewritePatternSet &patterns,
3520
+ ConversionConfig config) {
3521
+ return applyFullConversion (llvm::ArrayRef (op), target, patterns, config);
3529
3522
}
3530
3523
3531
3524
// ===----------------------------------------------------------------------===//
3532
3525
// Analysis Conversion
3533
3526
3534
- LogicalResult
3535
- mlir::applyAnalysisConversion (ArrayRef<Operation *> ops,
3536
- ConversionTarget &target,
3537
- const FrozenRewritePatternSet &patterns,
3538
- DenseSet<Operation *> &convertedOps,
3539
- function_ref<void (Diagnostic &)> notifyCallback) {
3540
- OperationConverter opConverter (target, patterns, OpConversionMode::Analysis,
3541
- &convertedOps);
3542
- return opConverter.convertOperations (ops, notifyCallback);
3527
+ LogicalResult mlir::applyAnalysisConversion (
3528
+ ArrayRef<Operation *> ops, ConversionTarget &target,
3529
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3530
+ OperationConverter opConverter (target, patterns, config,
3531
+ OpConversionMode::Analysis);
3532
+ return opConverter.convertOperations (ops);
3543
3533
}
3544
3534
LogicalResult
3545
3535
mlir::applyAnalysisConversion (Operation *op, ConversionTarget &target,
3546
3536
const FrozenRewritePatternSet &patterns,
3547
- DenseSet<Operation *> &convertedOps,
3548
- function_ref<void (Diagnostic &)> notifyCallback) {
3549
- return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns,
3550
- convertedOps, notifyCallback);
3537
+ ConversionConfig config) {
3538
+ return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns, config);
3551
3539
}
0 commit comments