Skip to content

Commit 27324f2

Browse files
ergawyantiagainst
authored andcommitted
[MLIR][SPIRV] Start module combiner.
This commit adds a new library that merges/combines a number of spv modules into a combined one. The library has a single entry point: combine(...). To combine a number of MLIR spv modules, we move all the module-level ops from all the input modules into one big combined module. To that end, the combination process can proceed in 2 phases: (1) resolving conflicts between pairs of ops from different modules (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) This patch implements only the first phase. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D90477
1 parent 13a56ca commit 27324f2

File tree

10 files changed

+1042
-0
lines changed

10 files changed

+1042
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the entry point to the SPIR-V module combiner library.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
14+
#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
15+
16+
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
17+
#include "llvm/ADT/ArrayRef.h"
18+
#include "llvm/ADT/SmallVector.h"
19+
20+
namespace mlir {
21+
class OpBuilder;
22+
23+
namespace spirv {
24+
class ModuleOp;
25+
26+
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
27+
/// from all the input modules into one big combined module. To that end, the
28+
/// combination process proceeds in 2 phases:
29+
///
30+
/// (1) resolve conflicts between pairs of ops from different modules
31+
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
32+
///
33+
/// For the conflict resolution phase, the following rules are employed to
34+
/// resolve such conflicts:
35+
///
36+
/// - If 2 spv.func's have the same symbol name, then rename one of the
37+
/// functions.
38+
/// - If an spv.func and another op have the same symbol name, then rename the
39+
/// other symbol.
40+
/// - If none of the 2 conflicting ops are spv.func, then rename either.
41+
///
42+
/// In all cases, the references to the updated symbol are also updated to
43+
/// reflect the change.
44+
///
45+
/// \param modules the list of modules to combine. Input modules are not
46+
/// modified.
47+
/// \param combinedMdouleBuilder an OpBuilder to be used for
48+
/// building up the combined module.
49+
/// \param symbRenameListener a listener that gets called everytime a symbol in
50+
/// one of the input modules is renamed. The arguments
51+
/// passed to the listener are: the input
52+
/// spirv::ModuleOp that contains the renamed symbol,
53+
/// a StringRef to the old symbol name, and a
54+
/// StringRef to the new symbol name. Note that it is
55+
/// the responsibility of the caller to properly
56+
/// retain the storage underlying the passed
57+
/// StringRefs if the listener callback outlives this
58+
/// function call.
59+
///
60+
/// \return the combined module.
61+
OwningSPIRVModuleRef
62+
combine(llvm::MutableArrayRef<ModuleOp> modules,
63+
OpBuilder &combinedModuleBuilder,
64+
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
65+
symbRenameListener);
66+
} // namespace spirv
67+
} // namespace mlir
68+
69+
#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_

mlir/lib/Dialect/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
3434
MLIRTransforms
3535
)
3636

37+
add_subdirectory(Linking)
3738
add_subdirectory(Serialization)
3839
add_subdirectory(Transforms)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(ModuleCombiner)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_mlir_dialect_library(MLIRSPIRVModuleCombiner
2+
ModuleCombiner.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
6+
)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the the SPIR-V module combiner library.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
14+
15+
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
16+
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/SymbolTable.h"
18+
#include "llvm/ADT/ArrayRef.h"
19+
#include "llvm/ADT/StringExtras.h"
20+
21+
using namespace mlir;
22+
23+
static constexpr unsigned maxFreeID = 1 << 20;
24+
25+
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
26+
spirv::ModuleOp combinedModule) {
27+
SmallString<64> newSymName(oldSymName);
28+
newSymName.push_back('_');
29+
30+
while (lastUsedID < maxFreeID) {
31+
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
32+
33+
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
34+
newSymName += llvm::utostr(lastUsedID);
35+
break;
36+
}
37+
}
38+
39+
return newSymName;
40+
}
41+
42+
/// Check if a symbol with the same name as op already exists in source. If so,
43+
/// rename op and update all its references in target.
44+
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
45+
spirv::ModuleOp target,
46+
spirv::ModuleOp source,
47+
unsigned &lastUsedID) {
48+
if (!SymbolTable::lookupSymbolIn(source, op.getName()))
49+
return success();
50+
51+
StringRef oldSymName = op.getName();
52+
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
53+
54+
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
55+
return op.emitError("unable to update all symbol uses for ")
56+
<< oldSymName << " to " << newSymName;
57+
58+
SymbolTable::setSymbolName(op, newSymName);
59+
return success();
60+
}
61+
62+
namespace mlir {
63+
namespace spirv {
64+
65+
// TODO Properly test symbol rename listener mechanism.
66+
67+
OwningSPIRVModuleRef
68+
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
69+
OpBuilder &combinedModuleBuilder,
70+
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
71+
symRenameListener) {
72+
unsigned lastUsedID = 0;
73+
74+
if (modules.empty())
75+
return nullptr;
76+
77+
auto addressingModel = modules[0].addressing_model();
78+
auto memoryModel = modules[0].memory_model();
79+
80+
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
81+
modules[0].getLoc(), addressingModel, memoryModel);
82+
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
83+
84+
// In some cases, a symbol in the (current state of the) combined module is
85+
// renamed in order to maintain the conflicting symbol in the input module
86+
// being merged. For example, if the conflict is between a global variable in
87+
// the current combined module and a function in the input module, the global
88+
// varaible is renamed. In order to notify listeners of the symbol updates in
89+
// such cases, we need to keep track of the module from which the renamed
90+
// symbol in the combined module originated. This map keeps such information.
91+
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
92+
93+
for (auto module : modules) {
94+
if (module.addressing_model() != addressingModel ||
95+
module.memory_model() != memoryModel) {
96+
module.emitError(
97+
"input modules differ in addressing model and/or memory model");
98+
return nullptr;
99+
}
100+
101+
spirv::ModuleOp moduleClone = module.clone();
102+
103+
// In the combined module, rename all symbols that conflict with symbols
104+
// from the current input module. This renmaing applies to all ops except
105+
// for spv.funcs. This way, if the conflicting op in the input module is
106+
// non-spv.func, we rename that symbol instead and maintain the spv.func in
107+
// the combined module name as it is.
108+
for (auto &op : combinedModule.getBlock().without_terminator()) {
109+
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
110+
StringRef oldSymName = symbolOp.getName();
111+
112+
if (!isa<FuncOp>(op) &&
113+
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
114+
lastUsedID)))
115+
return nullptr;
116+
117+
StringRef newSymName = symbolOp.getName();
118+
119+
if (symRenameListener && oldSymName != newSymName) {
120+
spirv::ModuleOp originalModule =
121+
symNameToModuleMap.lookup(oldSymName);
122+
123+
if (!originalModule) {
124+
module.emitError("unable to find original ModuleOp for symbol ")
125+
<< oldSymName;
126+
return nullptr;
127+
}
128+
129+
symRenameListener(originalModule, oldSymName, newSymName);
130+
131+
// Since the symbol name is updated, there is no need to maintain the
132+
// entry that assocaites the old symbol name with the original module.
133+
symNameToModuleMap.erase(oldSymName);
134+
// Instead, add a new entry to map the new symbol name to the original
135+
// module in case it gets renamed again later.
136+
symNameToModuleMap[newSymName] = originalModule;
137+
}
138+
}
139+
}
140+
141+
// In the current input module, rename all symbols that conflict with
142+
// symbols from the combined module. This includes renaming spv.funcs.
143+
for (auto &op : moduleClone.getBlock().without_terminator()) {
144+
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
145+
StringRef oldSymName = symbolOp.getName();
146+
147+
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
148+
lastUsedID)))
149+
return nullptr;
150+
151+
StringRef newSymName = symbolOp.getName();
152+
153+
if (symRenameListener && oldSymName != newSymName) {
154+
symRenameListener(module, oldSymName, newSymName);
155+
156+
// Insert the module associated with the symbol name.
157+
auto emplaceResult =
158+
symNameToModuleMap.try_emplace(symbolOp.getName(), module);
159+
160+
// If an entry with the same symbol name is already present, this must
161+
// be a problem with the implementation, specially clean-up of the map
162+
// while iterating over the combined module above.
163+
if (!emplaceResult.second) {
164+
module.emitError("did not expect to find an entry for symbol ")
165+
<< symbolOp.getName();
166+
return nullptr;
167+
}
168+
}
169+
}
170+
}
171+
172+
// Clone all the module's ops to the combined module.
173+
for (auto &op : moduleClone.getBlock().without_terminator())
174+
combinedModuleBuilder.insert(op.clone());
175+
}
176+
177+
return combinedModule;
178+
}
179+
180+
} // namespace spirv
181+
} // namespace mlir
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK: module {
4+
// CHECK-NEXT: spv.module Logical GLSL450 {
5+
// CHECK-NEXT: spv.specConstant @m1_sc
6+
// CHECK-NEXT: spv.specConstant @m2_sc
7+
// CHECK-NEXT: spv.func @variable_init_spec_constant
8+
// CHECK-NEXT: spv._reference_of @m2_sc
9+
// CHECK-NEXT: spv.Variable init
10+
// CHECK-NEXT: spv.Return
11+
// CHECK-NEXT: }
12+
// CHECK-NEXT: }
13+
// CHECK-NEXT: }
14+
15+
module {
16+
spv.module Logical GLSL450 {
17+
spv.specConstant @m1_sc = 42.42 : f32
18+
}
19+
20+
spv.module Logical GLSL450 {
21+
spv.specConstant @m2_sc = 42 : i32
22+
spv.func @variable_init_spec_constant() -> () "None" {
23+
%0 = spv._reference_of @m2_sc : i32
24+
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
25+
spv.Return
26+
}
27+
}
28+
}
29+
30+
// -----
31+
32+
module {
33+
spv.module Physical64 GLSL450 {
34+
}
35+
36+
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
37+
spv.module Logical GLSL450 {
38+
}
39+
}
40+
41+
// -----
42+
43+
module {
44+
spv.module Logical Simple {
45+
}
46+
47+
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
48+
spv.module Logical GLSL450 {
49+
}
50+
}

0 commit comments

Comments
 (0)