Skip to content

Commit 99c15eb

Browse files
[mlir][transform] Handle multiple library preloading passes. (#69705)
This is a new attempt at #69320. The transform dialect stores a "library module" that the preload pass can populate. Until now, each pass registered an additional module by simply pushing it to a vector; however, the interpreter only used the first of them. This commit turns the registration into "loading", i.e., each newly added module gets merged into the existing one. This allows the loading to be split into several passes, and using the library in the interpreter now takes all of them into account. While this design avoids repeated merging every time the library is accessed, it requires that the implementation of merging modules lives in the TransformDialect target (since it at the dialect depend on each other). This resolves #69111.
1 parent d25e0aa commit 99c15eb

File tree

11 files changed

+343
-269
lines changed

11 files changed

+343
-269
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,23 @@ def Transform_Dialect : Dialect {
7878
using ExtensionTypePrintingHook =
7979
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
8080

81-
/// Appends the given module as a transform symbol library available to
82-
/// all dialect users.
83-
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
84-
library) {
85-
libraryModules.push_back(std::move(library));
86-
}
87-
88-
/// Returns a range of registered library modules.
89-
auto getLibraryModules() const {
90-
return ::llvm::map_range(
91-
libraryModules,
92-
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
93-
return library.get();
94-
});
81+
/// Loads the given module into the transform symbol library module.
82+
LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
83+
library);
84+
85+
/// Returns the transform symbol library module available to all dialect
86+
/// users.
87+
ModuleOp getLibraryModule() const {
88+
if (libraryModule)
89+
return libraryModule.get();
90+
return ModuleOp();
9591
}
9692

9793
private:
94+
/// Initializes the transform symbol library module. Must be called from
95+
/// `TransformDialect::initialize` for the library module to work.
96+
void initializeLibraryModule();
97+
9898
/// Registers operations specified as template parameters with this
9999
/// dialect. Checks that they implement the required interfaces.
100100
template <typename... OpTys>
@@ -153,10 +153,9 @@ def Transform_Dialect : Dialect {
153153
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
154154
typePrintingHooks;
155155

156-
/// Modules containing symbols, e.g. named sequences, that will be
157-
/// resolved by the interpreter when used.
158-
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
159-
libraryModules;
156+
/// Module containing symbols, e.g. named sequences, that will be resolved
157+
/// by the interpreter when used.
158+
::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
160159
}];
161160
}
162161

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- Utils.h - Utils related to the transform dialect ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IR_UTILS_H
10+
#define MLIR_DIALECT_TRANSFORM_IR_UTILS_H
11+
12+
namespace mlir {
13+
class InFlightDiagnostic;
14+
class Operation;
15+
template <typename>
16+
class OwningOpRef;
17+
18+
namespace transform {
19+
namespace detail {
20+
21+
/// Merge all symbols from `other` into `target`. Both ops need to implement the
22+
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
23+
/// modified by this function and might not verify after the function returns.
24+
/// Upon merging, private symbols may be renamed in order to avoid collisions in
25+
/// the result. Public symbols may not collide, with the exception of
26+
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
27+
/// one of the two is external, in which case the other op preserved (or any one
28+
/// of the two if both are external).
29+
// TODO: Reconsider cloning individual ops rather than forcing users of the
30+
// function to clone (or move) `other` in order to improve efficiency.
31+
// This might primarily make sense if we can also prune the symbols that
32+
// are merged to a subset (such as those that are actually used).
33+
InFlightDiagnostic mergeSymbolsInto(Operation *target,
34+
OwningOpRef<Operation *> other);
35+
36+
} // namespace detail
37+
} // namespace transform
38+
} // namespace mlir
39+
40+
#endif // MLIR_DIALECT_TRANSFORM_IR_UTILS_H

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace detail {
3232
/// Each entry in `paths` may either be a regular file, in which case it ends up
3333
/// in the result list, or a directory, in which case all (regular) `.mlir`
3434
/// files in that directory are added. Any other file types lead to a failure.
35-
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
35+
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> paths,
3636
MLIRContext *context,
3737
SmallVectorImpl<std::string> &fileNames);
3838

@@ -69,20 +69,6 @@ TransformOpInterface findTransformEntryPoint(
6969
Operation *root, ModuleOp module,
7070
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);
7171

72-
/// Merge all symbols from `other` into `target`. Both ops need to implement the
73-
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
74-
/// modified by this function and might not verify after the function returns.
75-
/// Upon merging, private symbols may be renamed in order to avoid collisions in
76-
/// the result. Public symbols may not collide, with the exception of
77-
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
78-
/// one of the two is external, in which case the other op preserved (or any one
79-
/// of the two if both are external).
80-
// TODO: Reconsider cloning individual ops rather than forcing users of the
81-
// function to clone (or move) `other` in order to improve efficiency.
82-
// This might primarily make sense if we can also prune the symbols that
83-
// are merged to a subset (such as those that are actually used).
84-
InFlightDiagnostic mergeSymbolsInto(Operation *target,
85-
OwningOpRef<Operation *> other);
8672
} // namespace detail
8773

8874
/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.

mlir/lib/Dialect/Transform/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRTransformDialect
55
TransformInterfaces.cpp
66
TransformOps.cpp
77
TransformTypes.cpp
8+
Utils.cpp
89

910
DEPENDS
1011
MLIRMatchInterfacesIncGen

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1212
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1313
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
14+
#include "mlir/Dialect/Transform/IR/Utils.h"
1415
#include "mlir/IR/DialectImplementation.h"
1516
#include "llvm/ADT/SCCIterator.h"
1617

@@ -65,6 +66,7 @@ void transform::TransformDialect::initialize() {
6566
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
6667
>();
6768
initializeTypes();
69+
initializeLibraryModule();
6870
}
6971

7072
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
@@ -89,6 +91,20 @@ void transform::TransformDialect::printType(Type type,
8991
it->getSecond()(type, printer);
9092
}
9193

94+
LogicalResult transform::TransformDialect::loadIntoLibraryModule(
95+
::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
96+
return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
97+
}
98+
99+
void transform::TransformDialect::initializeLibraryModule() {
100+
MLIRContext *context = getContext();
101+
auto loc =
102+
FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
103+
libraryModule = ModuleOp::create(loc, "__transform_library");
104+
libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
105+
UnitAttr::get(context));
106+
}
107+
92108
void transform::TransformDialect::reportDuplicateTypeRegistration(
93109
StringRef mnemonic) {
94110
std::string buffer;

0 commit comments

Comments
 (0)