@@ -230,6 +230,8 @@ class IRRewrite {
230
230
// / Erase the given block (unless it was already erased).
231
231
void eraseBlock (Block *block);
232
232
233
+ const ConversionConfig &getConfig () const ;
234
+
233
235
const Kind kind;
234
236
ConversionPatternRewriterImpl &rewriterImpl;
235
237
};
@@ -735,8 +737,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
735
737
namespace mlir {
736
738
namespace detail {
737
739
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
738
- explicit ConversionPatternRewriterImpl (PatternRewriter &rewriter)
739
- : eraseRewriter(rewriter.getContext()) {}
740
+ explicit ConversionPatternRewriterImpl (MLIRContext *ctx,
741
+ const ConversionConfig &config)
742
+ : eraseRewriter(ctx), config(config) {}
740
743
741
744
// ===--------------------------------------------------------------------===//
742
745
// State Management
@@ -936,14 +939,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
936
939
// / converting the arguments of blocks within that region.
937
940
DenseMap<Region *, const TypeConverter *> regionToConverter;
938
941
939
- // / This allows the user to collect the match failure message.
940
- function_ref<void (Diagnostic &)> notifyCallback;
941
-
942
- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
943
- // / this is populated with ops found to be legalizable to the target.
944
- // / When mode == OpConversionMode::Partial, this is populated with ops found
945
- // / *not* to be legalizable to the target.
946
- DenseSet<Operation *> *trackedOps = nullptr ;
942
+ // / Dialect conversion configuration.
943
+ const ConversionConfig &config;
947
944
948
945
#ifndef NDEBUG
949
946
// / A set of operations that have pending updates. This tracking isn't
@@ -966,6 +963,10 @@ void IRRewrite::eraseBlock(Block *block) {
966
963
rewriterImpl.eraseRewriter .eraseBlock (block);
967
964
}
968
965
966
+ const ConversionConfig &IRRewrite::getConfig () const {
967
+ return rewriterImpl.config ;
968
+ }
969
+
969
970
void BlockTypeConversionRewrite::commit () {
970
971
// Process the remapping for each of the original arguments.
971
972
for (auto [origArg, info] :
@@ -1085,8 +1086,8 @@ void ReplaceOperationRewrite::commit() {
1085
1086
if (Value newValue =
1086
1087
rewriterImpl.mapping .lookupOrNull (result, result.getType ()))
1087
1088
result.replaceAllUsesWith (newValue);
1088
- if (rewriterImpl. trackedOps )
1089
- rewriterImpl. trackedOps ->erase (op);
1089
+ if (getConfig (). unlegalizedOps )
1090
+ getConfig (). unlegalizedOps ->erase (op);
1090
1091
// Do not erase the operation yet. It may still be referenced in `mapping`.
1091
1092
op->getBlock ()->getOperations ().remove (op);
1092
1093
}
@@ -1514,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
1514
1515
Diagnostic diag (loc, DiagnosticSeverity::Remark);
1515
1516
reasonCallback (diag);
1516
1517
logger.startLine () << " ** Failure : " << diag.str () << " \n " ;
1517
- if (notifyCallback)
1518
- notifyCallback (diag);
1518
+ if (config. notifyCallback )
1519
+ config. notifyCallback (diag);
1519
1520
});
1520
1521
}
1521
1522
1522
1523
// ===----------------------------------------------------------------------===//
1523
1524
// ConversionPatternRewriter
1524
1525
// ===----------------------------------------------------------------------===//
1525
1526
1526
- ConversionPatternRewriter::ConversionPatternRewriter (MLIRContext *ctx)
1527
+ ConversionPatternRewriter::ConversionPatternRewriter (
1528
+ MLIRContext *ctx, const ConversionConfig &config)
1527
1529
: PatternRewriter(ctx),
1528
- impl(new detail::ConversionPatternRewriterImpl(* this )) {
1530
+ impl(new detail::ConversionPatternRewriterImpl(ctx, config )) {
1529
1531
setListener (impl.get ());
1530
1532
}
1531
1533
@@ -1994,12 +1996,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
1994
1996
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
1995
1997
LLVM_DEBUG ({
1996
1998
logFailure (rewriterImpl.logger , " pattern failed to match" );
1997
- if (rewriterImpl.notifyCallback ) {
1999
+ if (rewriterImpl.config . notifyCallback ) {
1998
2000
Diagnostic diag (op->getLoc (), DiagnosticSeverity::Remark);
1999
2001
diag << " Failed to apply pattern \" " << pattern.getDebugName ()
2000
2002
<< " \" on op:\n "
2001
2003
<< *op;
2002
- rewriterImpl.notifyCallback (diag);
2004
+ rewriterImpl.config . notifyCallback (diag);
2003
2005
}
2004
2006
});
2005
2007
rewriterImpl.resetState (curState);
@@ -2387,14 +2389,12 @@ namespace mlir {
2387
2389
struct OperationConverter {
2388
2390
explicit OperationConverter (const ConversionTarget &target,
2389
2391
const FrozenRewritePatternSet &patterns,
2390
- OpConversionMode mode ,
2391
- DenseSet<Operation *> *trackedOps = nullptr )
2392
- : opLegalizer(target, patterns), mode(mode ), trackedOps(trackedOps ) {}
2392
+ const ConversionConfig &config ,
2393
+ OpConversionMode mode )
2394
+ : opLegalizer(target, patterns), config(config ), mode(mode ) {}
2393
2395
2394
2396
// / Converts the given operations to the conversion target.
2395
- LogicalResult
2396
- convertOperations (ArrayRef<Operation *> ops,
2397
- function_ref<void (Diagnostic &)> notifyCallback = nullptr );
2397
+ LogicalResult convertOperations (ArrayRef<Operation *> ops);
2398
2398
2399
2399
private:
2400
2400
// / Converts an operation with the given rewriter.
@@ -2431,14 +2431,11 @@ struct OperationConverter {
2431
2431
// / The legalizer to use when converting operations.
2432
2432
OperationLegalizer opLegalizer;
2433
2433
2434
+ // / Dialect conversion configuration.
2435
+ ConversionConfig config;
2436
+
2434
2437
// / The conversion mode to use when legalizing operations.
2435
2438
OpConversionMode mode;
2436
-
2437
- // / A set of pre-existing operations. When mode == OpConversionMode::Analysis,
2438
- // / this is populated with ops found to be legalizable to the target.
2439
- // / When mode == OpConversionMode::Partial, this is populated with ops found
2440
- // / *not* to be legalizable to the target.
2441
- DenseSet<Operation *> *trackedOps;
2442
2439
};
2443
2440
} // namespace mlir
2444
2441
@@ -2452,28 +2449,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
2452
2449
return op->emitError ()
2453
2450
<< " failed to legalize operation '" << op->getName () << " '" ;
2454
2451
// Partial conversions allow conversions to fail iff the operation was not
2455
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
2456
- // set, non-legalizable ops are included .
2452
+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
2453
+ // set, non-legalizable ops are added to that set .
2457
2454
if (mode == OpConversionMode::Partial) {
2458
2455
if (opLegalizer.isIllegal (op))
2459
2456
return op->emitError ()
2460
2457
<< " failed to legalize operation '" << op->getName ()
2461
2458
<< " ' that was explicitly marked illegal" ;
2462
- if (trackedOps )
2463
- trackedOps ->insert (op);
2459
+ if (config. unlegalizedOps )
2460
+ config. unlegalizedOps ->insert (op);
2464
2461
}
2465
2462
} else if (mode == OpConversionMode::Analysis) {
2466
2463
// Analysis conversions don't fail if any operations fail to legalize,
2467
2464
// they are only interested in the operations that were successfully
2468
2465
// legalized.
2469
- trackedOps->insert (op);
2466
+ if (config.legalizableOps )
2467
+ config.legalizableOps ->insert (op);
2470
2468
}
2471
2469
return success ();
2472
2470
}
2473
2471
2474
- LogicalResult OperationConverter::convertOperations (
2475
- ArrayRef<Operation *> ops,
2476
- function_ref<void (Diagnostic &)> notifyCallback) {
2472
+ LogicalResult OperationConverter::convertOperations (ArrayRef<Operation *> ops) {
2477
2473
if (ops.empty ())
2478
2474
return success ();
2479
2475
const ConversionTarget &target = opLegalizer.getTarget ();
@@ -2494,10 +2490,8 @@ LogicalResult OperationConverter::convertOperations(
2494
2490
}
2495
2491
2496
2492
// Convert each operation and discard rewrites on failure.
2497
- ConversionPatternRewriter rewriter (ops.front ()->getContext ());
2493
+ ConversionPatternRewriter rewriter (ops.front ()->getContext (), config );
2498
2494
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl ();
2499
- rewriterImpl.notifyCallback = notifyCallback;
2500
- rewriterImpl.trackedOps = trackedOps;
2501
2495
2502
2496
for (auto *op : toConvert)
2503
2497
if (failed (convert (rewriter, op)))
@@ -3484,57 +3478,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
3484
3478
// ===----------------------------------------------------------------------===//
3485
3479
// Partial Conversion
3486
3480
3487
- LogicalResult
3488
- mlir::applyPartialConversion (ArrayRef<Operation *> ops,
3489
- const ConversionTarget &target,
3490
- const FrozenRewritePatternSet &patterns,
3491
- DenseSet<Operation *> *unconvertedOps) {
3492
- OperationConverter opConverter (target, patterns, OpConversionMode::Partial,
3493
- unconvertedOps);
3481
+ LogicalResult mlir::applyPartialConversion (
3482
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
3483
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3484
+ OperationConverter opConverter (target, patterns, config,
3485
+ OpConversionMode::Partial);
3494
3486
return opConverter.convertOperations (ops);
3495
3487
}
3496
3488
LogicalResult
3497
3489
mlir::applyPartialConversion (Operation *op, const ConversionTarget &target,
3498
3490
const FrozenRewritePatternSet &patterns,
3499
- DenseSet<Operation *> *unconvertedOps) {
3500
- return applyPartialConversion (llvm::ArrayRef (op), target, patterns,
3501
- unconvertedOps);
3491
+ ConversionConfig config) {
3492
+ return applyPartialConversion (llvm::ArrayRef (op), target, patterns, config);
3502
3493
}
3503
3494
3504
3495
// ===----------------------------------------------------------------------===//
3505
3496
// Full Conversion
3506
3497
3507
- LogicalResult
3508
- mlir::applyFullConversion (ArrayRef<Operation *> ops,
3509
- const ConversionTarget &target,
3510
- const FrozenRewritePatternSet &patterns) {
3511
- OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3498
+ LogicalResult mlir::applyFullConversion (ArrayRef<Operation *> ops,
3499
+ const ConversionTarget &target,
3500
+ const FrozenRewritePatternSet &patterns,
3501
+ ConversionConfig config) {
3502
+ OperationConverter opConverter (target, patterns, config,
3503
+ OpConversionMode::Full);
3512
3504
return opConverter.convertOperations (ops);
3513
3505
}
3514
- LogicalResult
3515
- mlir::applyFullConversion (Operation *op, const ConversionTarget &target,
3516
- const FrozenRewritePatternSet &patterns) {
3517
- return applyFullConversion (llvm::ArrayRef (op), target, patterns);
3506
+ LogicalResult mlir::applyFullConversion (Operation *op,
3507
+ const ConversionTarget &target,
3508
+ const FrozenRewritePatternSet &patterns,
3509
+ ConversionConfig config) {
3510
+ return applyFullConversion (llvm::ArrayRef (op), target, patterns, config);
3518
3511
}
3519
3512
3520
3513
// ===----------------------------------------------------------------------===//
3521
3514
// Analysis Conversion
3522
3515
3523
- LogicalResult
3524
- mlir::applyAnalysisConversion (ArrayRef<Operation *> ops,
3525
- ConversionTarget &target,
3526
- const FrozenRewritePatternSet &patterns,
3527
- DenseSet<Operation *> &convertedOps,
3528
- function_ref<void (Diagnostic &)> notifyCallback) {
3529
- OperationConverter opConverter (target, patterns, OpConversionMode::Analysis,
3530
- &convertedOps);
3531
- return opConverter.convertOperations (ops, notifyCallback);
3516
+ LogicalResult mlir::applyAnalysisConversion (
3517
+ ArrayRef<Operation *> ops, ConversionTarget &target,
3518
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3519
+ OperationConverter opConverter (target, patterns, config,
3520
+ OpConversionMode::Analysis);
3521
+ return opConverter.convertOperations (ops);
3532
3522
}
3533
3523
LogicalResult
3534
3524
mlir::applyAnalysisConversion (Operation *op, ConversionTarget &target,
3535
3525
const FrozenRewritePatternSet &patterns,
3536
- DenseSet<Operation *> &convertedOps,
3537
- function_ref<void (Diagnostic &)> notifyCallback) {
3538
- return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns,
3539
- convertedOps, notifyCallback);
3526
+ ConversionConfig config) {
3527
+ return applyAnalysisConversion (llvm::ArrayRef (op), target, patterns, config);
3540
3528
}
0 commit comments