@@ -379,14 +379,20 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
379
379
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl (
380
380
Operation *target, StringRef passName,
381
381
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
382
- const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule ,
382
+ const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule ,
383
383
const RaggedArray<MappedValue> &extraMappings,
384
384
const TransformOptions &options,
385
385
const Pass::Option<std::string> &transformFileName,
386
386
const Pass::Option<std::string> &transformLibraryFileName,
387
387
const Pass::Option<std::string> &debugPayloadRootTag,
388
388
const Pass::Option<std::string> &debugTransformRootTag,
389
389
StringRef binaryName) {
390
+ bool hasSharedTransformModule =
391
+ sharedTransformModule && *sharedTransformModule;
392
+ bool hasTransformLibraryModule =
393
+ transformLibraryModule && *transformLibraryModule;
394
+ assert ((!hasSharedTransformModule || !hasTransformLibraryModule) &&
395
+ " at most one of shared or library transform module can be set" );
390
396
391
397
// Step 1
392
398
// ------
@@ -407,9 +413,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
407
413
// transform is embedded in the payload IR. If debugTransformRootTag was
408
414
// passed, then we are in user-specified selection of the transforming IR.
409
415
// This corresponds to REPL debug mode.
410
- bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
411
416
Operation *transformContainer =
412
- sharedTransform ? sharedTransformModule->get () : target;
417
+ hasSharedTransformModule ? sharedTransformModule->get () : target;
413
418
Operation *transformRoot =
414
419
debugTransformRootTag.empty ()
415
420
? findTopLevelTransform (transformContainer,
@@ -430,7 +435,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
430
435
// Copy external defintions for symbols if provided. Be aware of potential
431
436
// concurrent execution (normally, the error shouldn't be triggered unless the
432
437
// transform IR modifies itself in a pass, which is also forbidden elsewhere).
433
- if (!sharedTransform && libraryModule && *libraryModule ) {
438
+ if (hasTransformLibraryModule ) {
434
439
if (!target->isProperAncestor (transformRoot)) {
435
440
InFlightDiagnostic diag =
436
441
transformRoot->emitError ()
@@ -439,7 +444,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
439
444
return diag;
440
445
}
441
446
if (failed (defineDeclaredSymbols (*transformRoot->getBlock (),
442
- libraryModule ->get ())))
447
+ transformLibraryModule ->get ())))
443
448
return failure ();
444
449
}
445
450
@@ -461,25 +466,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
461
466
LogicalResult transform::detail::interpreterBaseInitializeImpl (
462
467
MLIRContext *context, StringRef transformFileName,
463
468
StringRef transformLibraryFileName,
464
- std::shared_ptr<OwningOpRef<ModuleOp>> &module ,
465
- std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule ,
469
+ std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule ,
470
+ std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule ,
466
471
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
467
472
moduleBuilder) {
468
- OwningOpRef<ModuleOp> parsed;
469
- if (failed (parseTransformModuleFromFile (context, transformFileName, parsed)))
473
+ OwningOpRef<ModuleOp> parsedTransformModule;
474
+ if (failed (parseTransformModuleFromFile (context, transformFileName,
475
+ parsedTransformModule)))
470
476
return failure ();
471
- if (parsed && failed (mlir::verify (*parsed )))
477
+ if (parsedTransformModule && failed (mlir::verify (*parsedTransformModule )))
472
478
return failure ();
473
479
474
- OwningOpRef<ModuleOp> parsedLibrary ;
480
+ OwningOpRef<ModuleOp> parsedLibraryModule ;
475
481
if (failed (parseTransformModuleFromFile (context, transformLibraryFileName,
476
- parsedLibrary )))
482
+ parsedLibraryModule )))
477
483
return failure ();
478
- if (parsedLibrary && failed (mlir::verify (*parsedLibrary )))
484
+ if (parsedLibraryModule && failed (mlir::verify (*parsedLibraryModule )))
479
485
return failure ();
480
486
481
- if (parsed) {
482
- module = std::make_shared<OwningOpRef<ModuleOp>>(std::move (parsed));
487
+ if (parsedTransformModule) {
488
+ sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
489
+ std::move (parsedTransformModule));
483
490
} else if (moduleBuilder) {
484
491
// TODO: better location story.
485
492
auto location = UnknownLoc::get (context);
@@ -491,20 +498,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
491
498
if (std::optional<LogicalResult> result = moduleBuilder (b, location)) {
492
499
if (failed (*result))
493
500
return failure ();
494
- module = std::move (localModule);
501
+ sharedTransformModule = std::move (localModule);
495
502
}
496
503
}
497
504
498
- if (!parsedLibrary || !*parsedLibrary )
505
+ if (!parsedLibraryModule || !*parsedLibraryModule )
499
506
return success ();
500
507
501
- if (module && *module ) {
502
- if (failed (defineDeclaredSymbols (*module ->get ().getBody (),
503
- parsedLibrary .get ())))
508
+ if (sharedTransformModule && *sharedTransformModule ) {
509
+ if (failed (defineDeclaredSymbols (*sharedTransformModule ->get ().getBody (),
510
+ parsedLibraryModule .get ())))
504
511
return failure ();
505
512
} else {
506
- libraryModule =
507
- std::make_shared<OwningOpRef<ModuleOp>>(std::move (parsedLibrary ));
513
+ transformLibraryModule =
514
+ std::make_shared<OwningOpRef<ModuleOp>>(std::move (parsedLibraryModule ));
508
515
}
509
516
return success ();
510
517
}
0 commit comments