Skip to content

Commit 9748f98

Browse files
[mlir][transform] Make variable names in interpreter consistent. (NFC) (#67800)
This commit renames the arguments of several static implementation functions of the transform interpreter base class to match the names of the corresponding member variables in order to clarify their intent. Similarly, it renames some local variables to reflect their relationship with corresponding member variables. Finally, this commit also asserts in `interpreterBaseRunOnOperationImpl` that at most one of shared and library module are set (which the initialization function guarantees) and simplifies some related `if` conditions.
1 parent afe4006 commit 9748f98

File tree

1 file changed

+29
-22
lines changed

1 file changed

+29
-22
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -379,14 +379,20 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
379379
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
380380
Operation *target, StringRef passName,
381381
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
382-
const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
382+
const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
383383
const RaggedArray<MappedValue> &extraMappings,
384384
const TransformOptions &options,
385385
const Pass::Option<std::string> &transformFileName,
386386
const Pass::Option<std::string> &transformLibraryFileName,
387387
const Pass::Option<std::string> &debugPayloadRootTag,
388388
const Pass::Option<std::string> &debugTransformRootTag,
389389
StringRef binaryName) {
390+
bool hasSharedTransformModule =
391+
sharedTransformModule && *sharedTransformModule;
392+
bool hasTransformLibraryModule =
393+
transformLibraryModule && *transformLibraryModule;
394+
assert((!hasSharedTransformModule || !hasTransformLibraryModule) &&
395+
"at most one of shared or library transform module can be set");
390396

391397
// Step 1
392398
// ------
@@ -407,9 +413,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
407413
// transform is embedded in the payload IR. If debugTransformRootTag was
408414
// passed, then we are in user-specified selection of the transforming IR.
409415
// This corresponds to REPL debug mode.
410-
bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
411416
Operation *transformContainer =
412-
sharedTransform ? sharedTransformModule->get() : target;
417+
hasSharedTransformModule ? sharedTransformModule->get() : target;
413418
Operation *transformRoot =
414419
debugTransformRootTag.empty()
415420
? findTopLevelTransform(transformContainer,
@@ -430,7 +435,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
430435
// Copy external defintions for symbols if provided. Be aware of potential
431436
// concurrent execution (normally, the error shouldn't be triggered unless the
432437
// transform IR modifies itself in a pass, which is also forbidden elsewhere).
433-
if (!sharedTransform && libraryModule && *libraryModule) {
438+
if (hasTransformLibraryModule) {
434439
if (!target->isProperAncestor(transformRoot)) {
435440
InFlightDiagnostic diag =
436441
transformRoot->emitError()
@@ -439,7 +444,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
439444
return diag;
440445
}
441446
if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
442-
libraryModule->get())))
447+
transformLibraryModule->get())))
443448
return failure();
444449
}
445450

@@ -461,25 +466,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
461466
LogicalResult transform::detail::interpreterBaseInitializeImpl(
462467
MLIRContext *context, StringRef transformFileName,
463468
StringRef transformLibraryFileName,
464-
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
465-
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
469+
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
470+
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
466471
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
467472
moduleBuilder) {
468-
OwningOpRef<ModuleOp> parsed;
469-
if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
473+
OwningOpRef<ModuleOp> parsedTransformModule;
474+
if (failed(parseTransformModuleFromFile(context, transformFileName,
475+
parsedTransformModule)))
470476
return failure();
471-
if (parsed && failed(mlir::verify(*parsed)))
477+
if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
472478
return failure();
473479

474-
OwningOpRef<ModuleOp> parsedLibrary;
480+
OwningOpRef<ModuleOp> parsedLibraryModule;
475481
if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
476-
parsedLibrary)))
482+
parsedLibraryModule)))
477483
return failure();
478-
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
484+
if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
479485
return failure();
480486

481-
if (parsed) {
482-
module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
487+
if (parsedTransformModule) {
488+
sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
489+
std::move(parsedTransformModule));
483490
} else if (moduleBuilder) {
484491
// TODO: better location story.
485492
auto location = UnknownLoc::get(context);
@@ -491,20 +498,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
491498
if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
492499
if (failed(*result))
493500
return failure();
494-
module = std::move(localModule);
501+
sharedTransformModule = std::move(localModule);
495502
}
496503
}
497504

498-
if (!parsedLibrary || !*parsedLibrary)
505+
if (!parsedLibraryModule || !*parsedLibraryModule)
499506
return success();
500507

501-
if (module && *module) {
502-
if (failed(defineDeclaredSymbols(*module->get().getBody(),
503-
parsedLibrary.get())))
508+
if (sharedTransformModule && *sharedTransformModule) {
509+
if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
510+
parsedLibraryModule.get())))
504511
return failure();
505512
} else {
506-
libraryModule =
507-
std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
513+
transformLibraryModule =
514+
std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
508515
}
509516
return success();
510517
}

0 commit comments

Comments
 (0)