@@ -80,8 +80,8 @@ void Pass::printAsTextualPipeline(raw_ostream &os) {
80
80
namespace mlir {
81
81
namespace detail {
82
82
struct OpPassManagerImpl {
83
- OpPassManagerImpl (StringAttr identifier , OpPassManager::Nesting nesting)
84
- : name(identifier.str ()), identifier(identifier ),
83
+ OpPassManagerImpl (OperationName opName , OpPassManager::Nesting nesting)
84
+ : name(opName.getStringRef ()), opName(opName ),
85
85
initializationGeneration (0 ), nesting(nesting) {}
86
86
OpPassManagerImpl (StringRef name, OpPassManager::Nesting nesting)
87
87
: name(name), initializationGeneration(0 ), nesting(nesting) {}
@@ -102,23 +102,24 @@ struct OpPassManagerImpl {
102
102
// / preserved.
103
103
void clear ();
104
104
105
- // / Coalesce adjacent AdaptorPasses into one large adaptor. This runs
106
- // / recursively through the pipeline graph.
107
- void coalesceAdjacentAdaptorPasses ();
105
+ // / Finalize the pass list in preparation for execution. This includes
106
+ // / coalescing adjacent pass managers when possible, verifying scheduled
107
+ // / passes, etc.
108
+ LogicalResult finalizePassList (MLIRContext *ctx);
108
109
109
- // / Return the operation name of this pass manager as an identifier .
110
- StringAttr getOpName (MLIRContext &context) {
111
- if (!identifier )
112
- identifier = StringAttr::get (&context, name );
113
- return *identifier ;
110
+ // / Return the operation name of this pass manager.
111
+ OperationName getOpName (MLIRContext &context) {
112
+ if (!opName )
113
+ opName = OperationName (name, &context );
114
+ return *opName ;
114
115
}
115
116
116
117
// / The name of the operation that passes of this pass manager operate on.
117
118
std::string name;
118
119
119
- // / The cached identifier (internalized in the context) for the name of the
120
+ // / The cached OperationName (internalized in the context) for the name of the
120
121
// / operation that passes of this pass manager operate on.
121
- Optional<StringAttr> identifier ;
122
+ Optional<OperationName> opName ;
122
123
123
124
// / The set of passes to run as part of this pass manager.
124
125
std::vector<std::unique_ptr<Pass>> passes;
@@ -173,18 +174,12 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
173
174
174
175
void OpPassManagerImpl::clear () { passes.clear (); }
175
176
176
- void OpPassManagerImpl::coalesceAdjacentAdaptorPasses () {
177
- // Bail out early if there are no adaptor passes.
178
- if (llvm::none_of (passes, [](std::unique_ptr<Pass> &pass) {
179
- return isa<OpToOpPassAdaptor>(pass.get ());
180
- }))
181
- return ;
182
-
177
+ LogicalResult OpPassManagerImpl::finalizePassList (MLIRContext *ctx) {
183
178
// Walk the pass list and merge adjacent adaptors.
184
179
OpToOpPassAdaptor *lastAdaptor = nullptr ;
185
- for (auto &passe : passes) {
180
+ for (auto &pass : passes) {
186
181
// Check to see if this pass is an adaptor.
187
- if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(passe .get ())) {
182
+ if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass .get ())) {
188
183
// If it is the first adaptor in a possible chain, remember it and
189
184
// continue.
190
185
if (!lastAdaptor) {
@@ -194,25 +189,39 @@ void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
194
189
195
190
// Otherwise, merge into the existing adaptor and delete the current one.
196
191
currentAdaptor->mergeInto (*lastAdaptor);
197
- passe .reset ();
192
+ pass .reset ();
198
193
} else if (lastAdaptor) {
199
- // If this pass is not an adaptor, then coalesce and forget any existing
194
+ // If this pass is not an adaptor, then finalize and forget any existing
200
195
// adaptor.
201
196
for (auto &pm : lastAdaptor->getPassManagers ())
202
- pm.getImpl ().coalesceAdjacentAdaptorPasses ();
197
+ if (failed (pm.getImpl ().finalizePassList (ctx)))
198
+ return failure ();
203
199
lastAdaptor = nullptr ;
204
200
}
205
201
}
206
202
207
- // If there was an adaptor at the end of the manager, coalesce it as well.
203
+ // If there was an adaptor at the end of the manager, finalize it as well.
208
204
if (lastAdaptor) {
209
205
for (auto &pm : lastAdaptor->getPassManagers ())
210
- pm.getImpl ().coalesceAdjacentAdaptorPasses ();
206
+ if (failed (pm.getImpl ().finalizePassList (ctx)))
207
+ return failure ();
211
208
}
212
209
213
- // Now that the adaptors have been merged, erase the empty slot corresponding
210
+ // Now that the adaptors have been merged, erase any empty slots corresponding
214
211
// to the merged adaptors that were nulled-out in the loop above.
212
+ Optional<RegisteredOperationName> opName =
213
+ getOpName (*ctx).getRegisteredInfo ();
215
214
llvm::erase_if (passes, std::logical_not<std::unique_ptr<Pass>>());
215
+
216
+ // Verify that all of the passes are valid for the operation.
217
+ for (std::unique_ptr<Pass> &pass : passes) {
218
+ if (opName && !pass->canScheduleOn (*opName)) {
219
+ return emitError (UnknownLoc::get (ctx))
220
+ << " unable to schedule pass '" << pass->getName ()
221
+ << " ' on a PassManager intended to run on '" << name << " '!" ;
222
+ }
223
+ }
224
+ return success ();
216
225
}
217
226
218
227
// ===----------------------------------------------------------------------===//
@@ -279,7 +288,7 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
279
288
StringRef OpPassManager::getOpName () const { return impl->name ; }
280
289
281
290
// / Return the operation name that this pass manager operates on.
282
- StringAttr OpPassManager::getOpName (MLIRContext &context) const {
291
+ OperationName OpPassManager::getOpName (MLIRContext &context) const {
283
292
return impl->getOpName (context);
284
293
}
285
294
@@ -367,9 +376,9 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
367
376
" nested under the current operation the pass is processing" ;
368
377
assert (pipeline.getOpName () == root->getName ().getStringRef ());
369
378
370
- // Before running, make sure to coalesce any adjacent pass adaptors in the
371
- // pipeline.
372
- pipeline. getImpl (). coalesceAdjacentAdaptorPasses ();
379
+ // Before running, finalize the passes held by the pipeline.
380
+ if ( failed ( pipeline.getImpl (). finalizePassList (root-> getContext ())))
381
+ return failure ();
373
382
374
383
// Initialize the user provided pipeline and execute the pipeline.
375
384
if (failed (pipeline.initialize (root->getContext (), parentInitGeneration)))
@@ -468,7 +477,7 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
468
477
// / Find an operation pass manager that can operate on an operation of the given
469
478
// / type, or nullptr if one does not exist.
470
479
static OpPassManager *findPassManagerFor (MutableArrayRef<OpPassManager> mgrs,
471
- StringAttr name,
480
+ OperationName name,
472
481
MLIRContext &context) {
473
482
auto *it = llvm::find_if (
474
483
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName (context) == name; });
@@ -538,8 +547,7 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
538
547
for (auto ®ion : getOperation ()->getRegions ()) {
539
548
for (auto &block : region) {
540
549
for (auto &op : block) {
541
- auto *mgr = findPassManagerFor (mgrs, op.getName ().getIdentifier (),
542
- *op.getContext ());
550
+ auto *mgr = findPassManagerFor (mgrs, op.getName (), *op.getContext ());
543
551
if (!mgr)
544
552
continue ;
545
553
@@ -581,7 +589,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
581
589
for (auto &block : region) {
582
590
for (auto &op : block) {
583
591
// Add this operation iff the name matches any of the pass managers.
584
- if (findPassManagerFor (mgrs, op.getName (). getIdentifier () , *context))
592
+ if (findPassManagerFor (mgrs, op.getName (), *context))
585
593
opAMPairs.emplace_back (&op, am.nest (&op));
586
594
}
587
595
}
@@ -604,9 +612,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
604
612
unsigned pmIndex = it - activePMs.begin ();
605
613
606
614
// Get the pass manager for this operation and execute it.
607
- auto *pm =
608
- findPassManagerFor (asyncExecutors[pmIndex],
609
- opPMPair.first ->getName ().getIdentifier (), *context);
615
+ auto *pm = findPassManagerFor (asyncExecutors[pmIndex],
616
+ opPMPair.first ->getName (), *context);
610
617
assert (pm && " expected valid pass manager for operation" );
611
618
612
619
unsigned initGeneration = pm->impl ->initializationGeneration ;
@@ -641,21 +648,21 @@ void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
641
648
// / Run the passes within this manager on the provided operation.
642
649
LogicalResult PassManager::run (Operation *op) {
643
650
MLIRContext *context = getContext ();
644
- assert (op->getName (). getIdentifier () == getOpName (*context) &&
651
+ assert (op->getName () == getOpName (*context) &&
645
652
" operation has a different name than the PassManager or is from a "
646
653
" different context" );
647
654
648
- // Before running, make sure to coalesce any adjacent pass adaptors in the
649
- // pipeline.
650
- getImpl ().coalesceAdjacentAdaptorPasses ();
651
-
652
655
// Register all dialects for the current pipeline.
653
656
DialectRegistry dependentDialects;
654
657
getDependentDialects (dependentDialects);
655
658
context->appendDialectRegistry (dependentDialects);
656
659
for (StringRef name : dependentDialects.getDialectNames ())
657
660
context->getOrLoadDialect (name);
658
661
662
+ // Before running, make sure to finalize the pipeline pass list.
663
+ if (failed (getImpl ().finalizePassList (context)))
664
+ return failure ();
665
+
659
666
// Initialize all of the passes within the pass manager with a new generation.
660
667
llvm::hash_code newInitKey = context->getRegistryHash ();
661
668
if (newInitKey != initializationKey) {
0 commit comments