Skip to content

Commit 2dd9e43

Browse files
committed
[spirv] Use owning module ref to avoid leaks and fix ASAN tests
Differential Revision: https://reviews.llvm.org/D83982
1 parent cc1b9b6 commit 2dd9e43

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,15 @@ class Deserializer {
103103
LogicalResult deserialize();
104104

105105
/// Collects the final SPIR-V ModuleOp.
106-
Optional<spirv::ModuleOp> collect();
106+
spirv::OwningSPIRVModuleRef collect();
107107

108108
private:
109109
//===--------------------------------------------------------------------===//
110110
// Module structure
111111
//===--------------------------------------------------------------------===//
112112

113113
/// Initializes the `module` ModuleOp in this deserializer instance.
114-
spirv::ModuleOp createModuleOp();
114+
spirv::OwningSPIRVModuleRef createModuleOp();
115115

116116
/// Processes SPIR-V module header in `binary`.
117117
LogicalResult processHeader();
@@ -425,7 +425,7 @@ class Deserializer {
425425
Location unknownLoc;
426426

427427
/// The SPIR-V ModuleOp.
428-
Optional<spirv::ModuleOp> module;
428+
spirv::OwningSPIRVModuleRef module;
429429

430430
/// The current function under construction.
431431
Optional<spirv::FuncOp> curFunction;
@@ -556,13 +556,15 @@ LogicalResult Deserializer::deserialize() {
556556
return success();
557557
}
558558

559-
Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
559+
spirv::OwningSPIRVModuleRef Deserializer::collect() {
560+
return std::move(module);
561+
}
560562

561563
//===----------------------------------------------------------------------===//
562564
// Module structure
563565
//===----------------------------------------------------------------------===//
564566

565-
spirv::ModuleOp Deserializer::createModuleOp() {
567+
spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
566568
OpBuilder builder(context);
567569
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
568570
spirv::ModuleOp::build(builder, state);
@@ -1912,10 +1914,10 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
19121914
// Go through all ops and remap the operands.
19131915
auto remapOperands = [&](Operation *op) {
19141916
for (auto &operand : op->getOpOperands())
1915-
if (auto mappedOp = mapper.lookupOrNull(operand.get()))
1917+
if (Value mappedOp = mapper.lookupOrNull(operand.get()))
19161918
operand.set(mappedOp);
19171919
for (auto &succOp : op->getBlockOperands())
1918-
if (auto mappedOp = mapper.lookupOrNull(succOp.get()))
1920+
if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
19191921
succOp.set(mappedOp);
19201922
};
19211923
for (auto &block : body) {
@@ -2354,7 +2356,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
23542356
return emitError(unknownLoc,
23552357
"missing Execution Model specification in OpEntryPoint");
23562358
}
2357-
auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
2359+
auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
23582360
if (wordIndex >= words.size()) {
23592361
return emitError(unknownLoc, "missing <id> in OpEntryPoint");
23602362
}
@@ -2382,7 +2384,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
23822384
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
23832385
wordIndex++;
23842386
}
2385-
opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
2387+
opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
23862388
opBuilder.getSymbolRefAttr(fnName),
23872389
opBuilder.getArrayAttr(interface));
23882390
return success();
@@ -2594,5 +2596,5 @@ spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
25942596
if (failed(deserializer.deserialize()))
25952597
return nullptr;
25962598

2597-
return deserializer.collect().getValueOr(nullptr);
2599+
return deserializer.collect();
25982600
}

mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
1616
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
1717
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18+
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
1819
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
1920
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
2021
#include "mlir/IR/Builders.h"
@@ -56,15 +57,15 @@ class SerializationTest : public ::testing::Test {
5657
}
5758

5859
Type getFloatStructType() {
59-
OpBuilder opBuilder(module.body());
60+
OpBuilder opBuilder(module->body());
6061
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
6162
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
6263
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
6364
return structType;
6465
}
6566

6667
void addGlobalVar(Type type, llvm::StringRef name) {
67-
OpBuilder opBuilder(module.body());
68+
OpBuilder opBuilder(module->body());
6869
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
6970
opBuilder.create<spirv::GlobalVariableOp>(
7071
UnknownLoc::get(&context), TypeAttr::get(ptrType),
@@ -98,7 +99,7 @@ class SerializationTest : public ::testing::Test {
9899

99100
protected:
100101
MLIRContext context;
101-
spirv::ModuleOp module;
102+
spirv::OwningSPIRVModuleRef module;
102103
SmallVector<uint32_t, 0> binary;
103104
};
104105

@@ -109,7 +110,7 @@ class SerializationTest : public ::testing::Test {
109110
TEST_F(SerializationTest, BlockDecorationTest) {
110111
auto structType = getFloatStructType();
111112
addGlobalVar(structType, "var0");
112-
ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
113+
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
113114
auto hasBlockDecoration = [](spirv::Opcode opcode,
114115
ArrayRef<uint32_t> operands) -> bool {
115116
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)

0 commit comments

Comments
 (0)