Skip to content

[mlir][transform] Handle multiple library preloading passes. #69705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,23 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

/// Appends the given module as a transform symbol library available to
/// all dialect users.
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library) {
libraryModules.push_back(std::move(library));
}

/// Returns a range of registered library modules.
auto getLibraryModules() const {
return ::llvm::map_range(
libraryModules,
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
return library.get();
});
/// Loads the given module into the transform symbol library module.
LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library);

/// Returns the transform symbol library module available to all dialect
/// users.
ModuleOp getLibraryModule() const {
if (libraryModule)
return libraryModule.get();
return ModuleOp();
}

private:
/// Initializes the transform symbol library module. Must be called from
/// `TransformDialect::initialize` for the library module to work.
void initializeLibraryModule();

/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
template <typename... OpTys>
Expand Down Expand Up @@ -153,10 +153,9 @@ def Transform_Dialect : Dialect {
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;

/// Modules containing symbols, e.g. named sequences, that will be
/// resolved by the interpreter when used.
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
libraryModules;
/// Module containing symbols, e.g. named sequences, that will be resolved
/// by the interpreter when used.
::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
}];
}

Expand Down
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===- Utils.h - Utils related to the transform dialect ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IR_UTILS_H
#define MLIR_DIALECT_TRANSFORM_IR_UTILS_H

namespace mlir {
class InFlightDiagnostic;
class Operation;
template <typename>
class OwningOpRef;

namespace transform {
namespace detail {

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
InFlightDiagnostic mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);

} // namespace detail
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_IR_UTILS_H
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace detail {
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> paths,
MLIRContext *context,
SmallVectorImpl<std::string> &fileNames);

Expand Down Expand Up @@ -69,20 +69,6 @@ TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
InFlightDiagnostic mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTransformDialect
TransformInterfaces.cpp
TransformOps.cpp
TransformTypes.cpp
Utils.cpp

DEPENDS
MLIRMatchInterfacesIncGen
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SCCIterator.h"

Expand Down Expand Up @@ -65,6 +66,7 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
initializeLibraryModule();
}

Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
Expand All @@ -89,6 +91,20 @@ void transform::TransformDialect::printType(Type type,
it->getSecond()(type, printer);
}

LogicalResult transform::TransformDialect::loadIntoLibraryModule(
::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
}

void transform::TransformDialect::initializeLibraryModule() {
MLIRContext *context = getContext();
auto loc =
FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
libraryModule = ModuleOp::create(loc, "__transform_library");
libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
UnitAttr::get(context));
}

void transform::TransformDialect::reportDuplicateTypeRegistration(
StringRef mnemonic) {
std::string buffer;
Expand Down
Loading