Skip to content

Commit 822d127

Browse files
[mlir][transform] Handle multiple library preloading passes.
The transform dialect stores a "library module" that the preload pass can populate. Until now, each pass registered an additional module by simply pushing it to a vector; however, the interpreter only used the first of them. This commit turns the registration into "loading", i.e., each newly added module gets merged into the existing one. This allows the loading to be split into several passes, and using the library in the interpreter now takes all of them into account. While this design avoids repeated merging every time the library is accessed, it requires that the implementation of merging modules lives in the `TransformDialect` target (since it at the dialect depend on each other). This resolves #69111.
1 parent d2fa236 commit 822d127

File tree

6 files changed

+53
-26
lines changed

6 files changed

+53
-26
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,19 @@ def Transform_Dialect : Dialect {
7878
using ExtensionTypePrintingHook =
7979
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;
8080

81-
/// Appends the given module as a transform symbol library available to
82-
/// all dialect users.
83-
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
84-
library) {
85-
libraryModules.push_back(std::move(library));
86-
}
87-
88-
/// Returns a range of registered library modules.
89-
auto getLibraryModules() const {
90-
return ::llvm::map_range(
91-
libraryModules,
92-
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
93-
return library.get();
94-
});
81+
/// Loads the given module into the transform symbol library module.
82+
void initializeLibraryModule();
83+
84+
/// Loads the given module into the transform symbol library module.
85+
LogicalResult loadIntoLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
86+
library);
87+
88+
/// Returns the transform symbol library module available to all dialect
89+
/// users.
90+
ModuleOp getLibraryModule() const {
91+
if (libraryModule)
92+
return libraryModule.get();
93+
return ModuleOp();
9594
}
9695

9796
private:
@@ -153,10 +152,9 @@ def Transform_Dialect : Dialect {
153152
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
154153
typePrintingHooks;
155154

156-
/// Modules containing symbols, e.g. named sequences, that will be
157-
/// resolved by the interpreter when used.
158-
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
159-
libraryModules;
155+
/// Module containing symbols, e.g. named sequences, that will be resolved
156+
/// by the interpreter when used.
157+
::mlir::OwningOpRef<::mlir::ModuleOp> libraryModule;
160158
}];
161159
}
162160

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1212
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1313
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
14+
#include "mlir/Dialect/Transform/IR/Utils.h"
1415
#include "mlir/IR/DialectImplementation.h"
1516
#include "llvm/ADT/SCCIterator.h"
1617

@@ -65,6 +66,7 @@ void transform::TransformDialect::initialize() {
6566
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
6667
>();
6768
initializeTypes();
69+
initializeLibraryModule();
6870
}
6971

7072
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
@@ -89,6 +91,20 @@ void transform::TransformDialect::printType(Type type,
8991
it->getSecond()(type, printer);
9092
}
9193

94+
void transform::TransformDialect::initializeLibraryModule() {
95+
MLIRContext *context = getContext();
96+
auto loc =
97+
FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
98+
libraryModule = ModuleOp::create(loc, "__transform_library");
99+
libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
100+
UnitAttr::get(context));
101+
}
102+
103+
LogicalResult transform::TransformDialect::loadIntoLibraryModule(
104+
::mlir::OwningOpRef<::mlir::ModuleOp> &&library) {
105+
return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
106+
}
107+
92108
void transform::TransformDialect::reportDuplicateTypeRegistration(
93109
StringRef mnemonic) {
94110
std::string buffer;

mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ class PreloadLibraryPass
3333
// TODO: investigate using a resource blob if some ownership mode allows it.
3434
auto *dialect =
3535
getContext().getOrLoadDialect<transform::TransformDialect>();
36-
dialect->registerLibraryModule(std::move(mergedParsedLibraries));
36+
if (failed(
37+
dialect->loadIntoLibraryModule(std::move(mergedParsedLibraries))))
38+
signalPassFailure();
3739
}
3840
};
3941
} // namespace

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,8 @@ LogicalResult transform::detail::parseTransformModuleFromFile(
107107
}
108108

109109
ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
110-
auto preloadedLibraryRange =
111-
context->getOrLoadDialect<transform::TransformDialect>()
112-
->getLibraryModules();
113-
if (!preloadedLibraryRange.empty())
114-
return *preloadedLibraryRange.begin();
115-
return ModuleOp();
110+
return context->getOrLoadDialect<transform::TransformDialect>()
111+
->getLibraryModule();
116112
}
117113

118114
transform::TransformOpInterface

mlir/test/Dialect/Transform/preload-library.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33
// RUN: -transform-interpreter=entry-point=private_helper \
44
// RUN: -split-input-file -verify-diagnostics
55

6+
// RUN: mlir-opt %s \
7+
// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
8+
// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
9+
// RUN: -transform-interpreter=entry-point=private_helper \
10+
// RUN: -split-input-file -verify-diagnostics
11+
12+
// RUN: mlir-opt %s \
13+
// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir \
14+
// RUN: -transform-preload-library=transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir \
15+
// RUN: -transform-interpreter=entry-point=private_helper \
16+
// RUN: -split-input-file -verify-diagnostics
17+
618
// expected-remark @below {{message}}
719
module {}
820

mlir/unittests/Dialect/Transform/Preload.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
10+
#include "mlir/Dialect/Transform/IR/Utils.h"
1011
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
1112
#include "mlir/IR/AsmState.h"
1213
#include "mlir/IR/DialectRegistry.h"
@@ -67,7 +68,9 @@ TEST(Preload, ContextPreloadConstructedLibrary) {
6768
OwningOpRef<ModuleOp> transformLibrary =
6869
parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
6970
EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
70-
dialect->registerLibraryModule(std::move(transformLibrary));
71+
LogicalResult diag =
72+
dialect->loadIntoLibraryModule(std::move(transformLibrary));
73+
EXPECT_TRUE(succeeded(diag));
7174

7275
ModuleOp retrievedTransformLibrary =
7376
transform::detail::getPreloadedTransformModule(&context);

0 commit comments

Comments
 (0)