Skip to content

[mlir][transform] Handle multiple library preloading passes #69320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,62 @@ class BuildOnly : public DerivedTy {
BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};

namespace detail {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this go to some new Utils header? This doesn't look like stuff that belongs to the dialect definition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, makes sense. I realized that I only have to move the mergeSymbolsInto function (and related static functions). The remaining ones (for expanding paths etc) are arguably more related to the interpreter pass than this one, so I am leaving them where they are.

For now, I am calling the files Utils.h/Utils.cpp. All other files are called TransformSomething.xxx -- should I rename the new files to TransformUtils.xxx as well to follow that scheme?


/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why take a reference to ArrayRef, which itself is a reference as the name indicates?

(I see the code is moved from elsewhere, but IIRC it's yours anyway).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No good reason. I probably changed the type from something else and overlooked the & when I changed it to ArrayRef. Fixed in new PR.

MLIRContext *context,
SmallVectorImpl<std::string> &fileNames);

/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to parse, verify, aggregate and link the content of all mlir files
/// nested under `transformLibraryPaths` and containing transform dialect
/// specifications.
LogicalResult
assembleTransformLibraryFromPaths(MLIRContext *context,
ArrayRef<std::string> transformLibraryPaths,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
/// from file is expected to be prohibitively expensive.
/// In such cases, the transform module is expected to be found in the preloaded
/// library modules of the transform dialect.
/// Returns null if the module is not found.
ModuleOp getPreloadedTransformModule(MLIRContext *context);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);

/// Merge all symbols from `others` into `target`. See overload of
/// `mergeSymbolsInto` on one `other` op for details.
LogicalResult
mergeSymbolsInto(Operation *target,
MutableArrayRef<OwningOpRef<Operation *>> others);
} // namespace detail
} // namespace transform
} // namespace mlir

Expand Down
34 changes: 16 additions & 18 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,19 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

/// Appends the given module as a transform symbol library available to
/// all dialect users.
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library) {
libraryModules.push_back(std::move(library));
}

/// Returns a range of registered library modules.
auto getLibraryModules() const {
return ::llvm::map_range(
libraryModules,
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
return library.get();
});
/// Loads the given module into the transform symbol library module.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation looks copy-pasted from below.

void initializeLibraryModule();

/// Loads the given module into the transform symbol library module.
LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library);

/// Returns the transform symbol library module available to all dialect
/// users.
ModuleOp getLibraryModule() const {
if (libraryModule)
return libraryModule.get();
return ModuleOp();
}

private:
Expand Down Expand Up @@ -153,10 +152,9 @@ def Transform_Dialect : Dialect {
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;

/// Modules containing symbols, e.g. named sequences, that will be
/// resolved by the interpreter when used.
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
libraryModules;
/// Module containing symbols, e.g. named sequences, that will be resolved
/// by the interpreter when used.
::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
}];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,6 @@ class Region;

namespace transform {
namespace detail {

/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
MLIRContext *context,
SmallVectorImpl<std::string> &fileNames);

/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to parse, verify, aggregate and link the content of all mlir files
/// nested under `transformLibraryPaths` and containing transform dialect
/// specifications.
LogicalResult
assembleTransformLibraryFromPaths(MLIRContext *context,
ArrayRef<std::string> transformLibraryPaths,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
/// from file is expected to be prohibitively expensive.
/// In such cases, the transform module is expected to be found in the preloaded
/// library modules of the transform dialect.
/// Returns null if the module is not found.
ModuleOp getPreloadedTransformModule(MLIRContext *context);

/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
/// that is either:
/// 1. nested under `root` (takes precedence).
Expand All @@ -68,21 +34,6 @@ ModuleOp getPreloadedTransformModule(MLIRContext *context);
TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `entryPoint` to the payload.
Expand Down
Loading