Skip to content

Commit 23326b9

Browse files
committed
[mlir][spirv] Fix a few issues in ModuleCombiner
- Fixed symbol insertion into `symNameToModuleMap`. Insertion needs to happen whether symbols are renamed or not. - Added check for the VCE triple and avoid dropping it. - Disabled function deduplication. It requires more careful rules. Right now it can remove different functions. - Added tests for symbol rename listener. - And some other code/comment cleanups. Reviewed By: ergawy Differential Revision: https://reviews.llvm.org/D106886
1 parent aa6340c commit 23326b9

File tree

11 files changed

+306
-207
lines changed

11 files changed

+306
-207
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
1414
#define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
1515

16+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1617
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
1718
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1819
#include "mlir/IR/BuiltinOps.h"

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,9 @@ def SPV_ModuleOp : SPV_Op<"module",
467467
let builders = [
468468
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
469469
OpBuilder<(ins "spirv::AddressingModel":$addressing_model,
470-
"spirv::MemoryModel":$memory_model,
471-
CArg<"Optional<StringRef>", "llvm::None">:$name)>
470+
"spirv::MemoryModel":$memory_model,
471+
CArg<"Optional<spirv::VerCapExtAttr>", "llvm::None">:$vce_triple,
472+
CArg<"Optional<StringRef>", "llvm::None">:$name)>
472473
];
473474

474475
// We need to ensure the block inside the region is properly terminated;

mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,53 +22,54 @@ class OpBuilder;
2222
namespace spirv {
2323
class ModuleOp;
2424

25-
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
25+
/// The listener function to receive symbol renaming events.
26+
///
27+
/// `originalModule` is the input spirv::ModuleOp that contains the renamed
28+
/// symbol. `oldSymbol` and `newSymbol` are the original and renamed symbol.
29+
/// Note that it's the responsibility of the caller to properly retain the
30+
/// storage underlying the passed StringRefs if the listener callback outlives
31+
/// this function call.
32+
using SymbolRenameListener = function_ref<void(
33+
spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)>;
34+
35+
/// Combines a list of SPIR-V `inputModules` into one. Returns the combined
36+
/// module on success; returns a null module otherwise.
37+
//
38+
/// \param inputModules the list of modules to combine. They won't be modified.
39+
/// \param combinedMdouleBuilder an OpBuilder for building the combined module.
40+
/// \param symbRenameListener a listener that gets called everytime a symbol in
41+
/// one of the input modules is renamed.
42+
///
43+
/// To combine multiple SPIR-V modules, we move all the module-level ops
2644
/// from all the input modules into one big combined module. To that end, the
2745
/// combination process proceeds in 2 phases:
2846
///
29-
/// (1) resolve conflicts between pairs of ops from different modules
30-
/// (2) deduplicate equivalent ops/sub-ops in the merged module.
47+
/// 1. resolve conflicts between pairs of ops from different modules,
48+
/// 2. deduplicate equivalent ops/sub-ops in the merged module.
3149
///
3250
/// For the conflict resolution phase, the following rules are employed to
3351
/// resolve such conflicts:
3452
///
35-
/// - If 2 spv.func's have the same symbol name, then rename one of the
53+
/// - If 2 spv.func's have the same symbol name, then rename one of the
3654
/// functions.
37-
/// - If an spv.func and another op have the same symbol name, then rename the
55+
/// - If an spv.func and another op have the same symbol name, then rename the
3856
/// other symbol.
39-
/// - If none of the 2 conflicting ops are spv.func, then rename either.
57+
/// - If none of the 2 conflicting ops are spv.func, then rename either.
4058
///
4159
/// For deduplication, the following 3 cases are taken into consideration:
4260
///
43-
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
61+
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
4462
/// or the same build_in attribute value, then replace one of them using the
4563
/// other.
46-
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
64+
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
4765
/// replace one of them using the other.
48-
/// - If 2 spv.func's are identical replace one of them using the other.
66+
/// - Deduplicating functions are not supported right now.
4967
///
5068
/// In all cases, the references to the updated symbol (whether renamed or
5169
/// deduplicated) are also updated to reflect the change.
52-
///
53-
/// \param modules the list of modules to combine. Input modules are not
54-
/// modified.
55-
/// \param combinedMdouleBuilder an OpBuilder to be used for
56-
// building up the combined module.
57-
/// \param symbRenameListener a listener that gets called everytime a symbol in
58-
/// one of the input modules is renamed. The arguments
59-
/// passed to the listener are: the input
60-
/// spirv::ModuleOp that contains the renamed symbol,
61-
/// a StringRef to the old symbol name, and a
62-
/// StringRef to the new symbol name. Note that it is
63-
/// the responsibility of the caller to properly
64-
/// retain the storage underlying the passed
65-
/// StringRefs if the listener callback outlives this
66-
/// function call.
67-
///
68-
/// \return the combined module.
69-
OwningOpRef<spirv::ModuleOp>
70-
combine(MutableArrayRef<ModuleOp> modules, OpBuilder &combinedModuleBuilder,
71-
function_ref<void(ModuleOp, StringRef, StringRef)> symbRenameListener);
70+
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
71+
OpBuilder &combinedModuleBuilder,
72+
SymbolRenameListener symRenameListener);
7273
} // namespace spirv
7374
} // namespace mlir
7475

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
310310
// Add a keyword to the module name to avoid symbolic conflict.
311311
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
312312
auto spvModule = rewriter.create<spirv::ModuleOp>(
313-
moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
313+
moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
314314
StringRef(spvModuleName));
315315

316316
// Move the region from the module op into the SPIR-V module.

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
25402540
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
25412541
spirv::AddressingModel addressingModel,
25422542
spirv::MemoryModel memoryModel,
2543+
Optional<VerCapExtAttr> vceTriple,
25432544
Optional<StringRef> name) {
25442545
state.addAttribute(
25452546
"addressing_model",
@@ -2548,10 +2549,11 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
25482549
static_cast<int32_t>(memoryModel)));
25492550
OpBuilder::InsertionGuard guard(builder);
25502551
builder.createBlock(state.addRegion());
2551-
if (name) {
2552-
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2553-
builder.getStringAttr(*name));
2554-
}
2552+
if (vceTriple)
2553+
state.addAttribute(getVCETripleAttrName(), *vceTriple);
2554+
if (name)
2555+
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
2556+
builder.getStringAttr(*name));
25552557
}
25562558

25572559
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {

0 commit comments

Comments
 (0)