@@ -224,6 +224,8 @@ class IRRewrite {
224
224
// / Erase the given block (unless it was already erased).
225
225
void eraseBlock (Block *block);
226
226
227
+ const ConversionConfig &getConfig () const ;
228
+
227
229
const Kind kind;
228
230
ConversionPatternRewriterImpl &rewriterImpl;
229
231
};
@@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
723
725
namespace mlir {
724
726
namespace detail {
725
727
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
726
- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter)
728
+ explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter,
729
+ const ConversionConfig &config)
727
730
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
728
- notifyCallback( nullptr ) {}
731
+ config(config ) {}
729
732
730
733
// ===--------------------------------------------------------------------===//
731
734
// State Management
@@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
931
934
// / converting the arguments of blocks within that region.
932
935
DenseMap<Region *, const TypeConverter *> regionToConverter;
933
936
934
- // / This allows the user to collect the match failure message.
935
- function_ref<void (Diagnostic &)> notifyCallback;
936
-
937
- DenseSet<Operation *> *trackedOps = nullptr ;
937
+ // / Dialect conversion configuration.
938
+ const ConversionConfig &config;
938
939
939
940
#ifndef NDEBUG
940
941
// / A set of operations that have pending updates. This tracking isn't
@@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) {
957
958
rewriterImpl.eraseRewriter .eraseBlock (block);
958
959
}
959
960
961
+ const ConversionConfig &IRRewrite::getConfig () const {
962
+ return rewriterImpl.config ;
963
+ }
964
+
960
965
void BlockTypeConversionRewrite::commit () {
961
966
// Process the remapping for each of the original arguments.
962
967
for (unsigned i = 0 , e = origBlock->getNumArguments (); i != e; ++i) {
@@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() {
1074
1079
if (Value newValue =
1075
1080
rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
1076
1081
result.replaceAllUsesWith (newValue);
1077
- if (rewriterImpl. trackedOps )
1078
- rewriterImpl. trackedOps ->erase (op);
1082
+ if (getConfig (). unlegalizedOps )
1083
+ getConfig (). unlegalizedOps ->erase (op);
1079
1084
// Do not erase the operation yet. It may still be referenced in `mapping`.
1080
1085
op->getBlock ()->getOperations ().remove (op);
1081
1086
}
@@ -1510,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1510
1515
Diagnostic diag (loc, DiagnosticSeverity::Remark);
1511
1516
reasonCallback (diag);
1512
1517
logger.startLine () << " ** Failure : " << diag.str () << " \n " ;
1513
- if (notifyCallback)
1514
- notifyCallback (diag);
1518
+ if (config. notifyCallback )
1519
+ config. notifyCallback (diag);
1515
1520
});
1516
1521
}
1517
1522
1518
1523
// ===----------------------------------------------------------------------===//
1519
1524
// ConversionPatternRewriter
1520
1525
// ===----------------------------------------------------------------------===//
1521
1526
1522
- ConversionPatternRewriter::ConversionPatternRewriter (MLIRContext *ctx)
1527
+ ConversionPatternRewriter::ConversionPatternRewriter (
1528
+ MLIRContext *ctx, const ConversionConfig &config)
1523
1529
: PatternRewriter(ctx),
1524
- impl(new detail::ConversionPatternRewriterImpl(*this )) {
1530
+ impl(new detail::ConversionPatternRewriterImpl(*this , config )) {
1525
1531
setListener (impl.get ());
1526
1532
}
1527
1533
@@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
1972
1978
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
1973
1979
LLVM_DEBUG ({
1974
1980
logFailure (rewriterImpl.logger , " pattern failed to match" );
1975
- if (rewriterImpl.notifyCallback ) {
1981
+ if (rewriterImpl.config . notifyCallback ) {
1976
1982
Diagnostic diag (op->getLoc (), DiagnosticSeverity::Remark);
1977
1983
diag << " Failed to apply pattern \" " << pattern.getDebugName ()
1978
1984
<< " \" on op:\n "
1979
1985
<< *op;
1980
- rewriterImpl.notifyCallback (diag);
1986
+ rewriterImpl.config . notifyCallback (diag);
1981
1987
}
1982
1988
});
1983
1989
rewriterImpl.resetState (curState);
@@ -2365,14 +2371,12 @@ namespace mlir {
2365
2371
struct OperationConverter {
2366
2372
explicit OperationConverter (const ConversionTarget &target,
2367
2373
const FrozenRewritePatternSet &patterns,
2368
- OpConversionMode mode ,
2369
- DenseSet<Operation *> *trackedOps = nullptr )
2370
- : opLegalizer(target, patterns), mode(mode ), trackedOps(trackedOps ) {}
2374
+ const ConversionConfig &config ,
2375
+ OpConversionMode mode )
2376
+ : opLegalizer(target, patterns), config(config ), mode(mode ) {}
2371
2377
2372
2378
// / Converts the given operations to the conversion target.
2373
- LogicalResult
2374
- convertOperations (ArrayRef<Operation *> ops,
2375
- function_ref<void (Diagnostic &)> notifyCallback = nullptr );
2379
+ LogicalResult convertOperations (ArrayRef<Operation *> ops);
2376
2380
2377
2381
private:
2378
2382
// / Converts an operation with the given rewriter.
@@ -2409,14 +2413,11 @@ struct OperationConverter {
2409
2413
// / The legalizer to use when converting operations.
2410
2414
OperationLegalizer opLegalizer;
2411
2415
2416
+ // / Dialect conversion configuration.
2417
+ ConversionConfig config;
2418
+
2412
2419
// / The conversion mode to use when legalizing operations.
2413
2420
OpConversionMode mode;
2414
-
2415
- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2416
- // / this is populated with ops found to be legalizable to the target.
2417
- // / When mode == OpConversionMode::Partial, this is populated with ops found
2418
- // / *not* to be legalizable to the target.
2419
- DenseSet<Operation *> *trackedOps;
2420
2421
};
2421
2422
} // namespace mlir
2422
2423
@@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2430
2431
return op->emitError ()
2431
2432
<< " failed to legalize operation '" << op->getName () << " '" ;
2432
2433
// Partial conversions allow conversions to fail iff the operation was not
2433
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
2434
- // set, non-legalizable ops are included .
2434
+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2435
+ // set, non-legalizable ops are added to that set .
2435
2436
if (mode == OpConversionMode::Partial) {
2436
2437
if (opLegalizer.isIllegal (op))
2437
2438
return op->emitError ()
2438
2439
<< " failed to legalize operation '" << op->getName ()
2439
2440
<< " ' that was explicitly marked illegal" ;
2440
- if (trackedOps )
2441
- trackedOps ->insert (op);
2441
+ if (config. unlegalizedOps )
2442
+ config. unlegalizedOps ->insert (op);
2442
2443
}
2443
2444
} else if (mode == OpConversionMode::Analysis) {
2444
2445
// Analysis conversions don't fail if any operations fail to legalize,
2445
2446
// they are only interested in the operations that were successfully
2446
2447
// legalized.
2447
- trackedOps->insert (op);
2448
+ if (config.legalizableOps )
2449
+ config.legalizableOps ->insert (op);
2448
2450
}
2449
2451
return success ();
2450
2452
}
2451
2453
2452
- LogicalResult OperationConverter::convertOperations (
2453
- ArrayRef<Operation *> ops,
2454
- function_ref<void (Diagnostic &)> notifyCallback) {
2454
+ LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2455
2455
if (ops.empty ())
2456
2456
return success ();
2457
2457
const ConversionTarget &target = opLegalizer.getTarget ();
@@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations(
2472
2472
}
2473
2473
2474
2474
// Convert each operation and discard rewrites on failure.
2475
- ConversionPatternRewriter rewriter (ops.front ()->getContext ());
2475
+ ConversionPatternRewriter rewriter (ops.front ()->getContext (), config );
2476
2476
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2477
- rewriterImpl.notifyCallback = notifyCallback;
2478
- rewriterImpl.trackedOps = trackedOps;
2479
2477
2480
2478
for (auto *op : toConvert)
2481
2479
if (failed (convert (rewriter, op)))
@@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3461
3459
// ===----------------------------------------------------------------------===//
3462
3460
// Partial Conversion
3463
3461
3464
- LogicalResult
3465
- mlir::applyPartialConversion (ArrayRef<Operation *> ops,
3466
- const ConversionTarget &target,
3467
- const FrozenRewritePatternSet &patterns,
3468
- DenseSet<Operation *> *unconvertedOps) {
3469
- OperationConverter opConverter (target, patterns, OpConversionMode::Partial,
3470
- unconvertedOps);
3462
+ LogicalResult mlir::applyPartialConversion (
3463
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
3464
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3465
+ OperationConverter opConverter (target, patterns, config,
3466
+ OpConversionMode::Partial);
3471
3467
return opConverter.convertOperations (ops);
3472
3468
}
3473
3469
LogicalResult
3474
3470
mlir::applyPartialConversion (Operation *op, const ConversionTarget &target,
3475
3471
const FrozenRewritePatternSet &patterns,
3476
- DenseSet<Operation *> *unconvertedOps) {
3477
- return applyPartialConversion (llvm::ArrayRef (op), target, patterns,
3478
- unconvertedOps);
3472
+ ConversionConfig config) {
3473
+ return applyPartialConversion (llvm::ArrayRef (op), target, patterns, config);
3479
3474
}
3480
3475
3481
3476
// ===----------------------------------------------------------------------===//
3482
3477
// Full Conversion
3483
3478
3484
- LogicalResult
3485
- mlir::applyFullConversion (ArrayRef<Operation *> ops, const ConversionTarget &target,
3486
- const FrozenRewritePatternSet &patterns) {
3487
- OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3479
+ LogicalResult mlir::applyFullConversion (ArrayRef<Operation *> ops,
3480
+ const ConversionTarget &target,
3481
+ const FrozenRewritePatternSet &patterns,
3482
+ ConversionConfig config) {
3483
+ OperationConverter opConverter (target, patterns, config,
3484
+ OpConversionMode::Full);
3488
3485
return opConverter.convertOperations (ops);
3489
3486
}
3490
- LogicalResult
3491
- mlir::applyFullConversion (Operation *op, const ConversionTarget &target,
3492
- const FrozenRewritePatternSet &patterns) {
3493
- return applyFullConversion (llvm::ArrayRef (op), target, patterns);
3487
+ LogicalResult mlir::applyFullConversion (Operation *op,
3488
+ const ConversionTarget &target,
3489
+ const FrozenRewritePatternSet &patterns,
3490
+ ConversionConfig config) {
3491
+ return applyFullConversion (llvm::ArrayRef (op), target, patterns, config);
3494
3492
}
3495
3493
3496
3494
// ===----------------------------------------------------------------------===//
3497
3495
// Analysis Conversion
3498
3496
3499
- LogicalResult
3500
- mlir::applyAnalysisConversion (ArrayRef<Operation *> ops,
3501
- ConversionTarget &target,
3502
- const FrozenRewritePatternSet &patterns,
3503
- DenseSet<Operation *> &convertedOps,
3504
- function_ref<void (Diagnostic &)> notifyCallback) {
3505
- OperationConverter opConverter (target, patterns, OpConversionMode::Analysis,
3506
- &convertedOps);
3507
- return opConverter.convertOperations (ops, notifyCallback);
3497
+ LogicalResult mlir::applyAnalysisConversion (
3498
+ ArrayRef<Operation *> ops, ConversionTarget &target,
3499
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3500
+ OperationConverter opConverter (target, patterns, config,
3501
+ OpConversionMode::Analysis);
3502
+ return opConverter.convertOperations (ops);
3508
3503
}
3509
3504
LogicalResult
3510
3505
mlir::applyAnalysisConversion (Operation *op, ConversionTarget &target,
3511
3506
const FrozenRewritePatternSet &patterns,
3512
- DenseSet<Operation *> &convertedOps,
3513
- function_ref<void (Diagnostic &)> notifyCallback) {
3514
- return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns,
3515
- convertedOps, notifyCallback);
3507
+ ConversionConfig config) {
3508
+ return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns, config);
3516
3509
}
0 commit comments