Skip to content

Commit 660a569

Browse files
committed
Emit strong definition for TypeID storage in Op/Type/Attributes definition
By making an explicit template specialization for the TypeID provided by these classes, the compiler will not emit an inline weak definition and rely on the linker to unique it. Instead a single definition will be emitted in the C++ file alongside the implementation for these classes. That will turn into a linker error what is now a hard-to-debug runtime behavior where instances of the same class may be using a different TypeID inside of different DSOs. Differential Revision: https://reviews.llvm.org/D105903
1 parent 6cba963 commit 660a569

File tree

4 files changed

+103
-47
lines changed

4 files changed

+103
-47
lines changed

mlir/include/mlir/Support/TypeID.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,25 @@ TypeID TypeID::get() {
137137

138138
} // end namespace mlir
139139

140+
// Declare/define an explicit specialization for TypeID: this forces the
141+
// compiler to emit a strong definition for a class and controls which
142+
// translation unit and shared object will actually have it.
143+
// This can be useful to turn to a link-time failure what would be in other
144+
// circumstances a hard-to-catch runtime bug when a TypeID is hidden in two
145+
// different shared libraries and instances of the same class only gets the same
146+
// TypeID inside a given DSO.
147+
#define DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME) \
148+
template <> \
149+
LLVM_EXTERNAL_VISIBILITY mlir::TypeID \
150+
mlir::detail::TypeIDExported::get<CLASS_NAME>();
151+
#define DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME) \
152+
template <> \
153+
LLVM_EXTERNAL_VISIBILITY mlir::TypeID \
154+
mlir::detail::TypeIDExported::get<CLASS_NAME>() { \
155+
static mlir::TypeID::Storage instance; \
156+
return mlir::TypeID(&instance); \
157+
}
158+
140159
namespace llvm {
141160
template <> struct DenseMapInfo<mlir::TypeID> {
142161
static mlir::TypeID getEmptyKey() {

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -440,16 +440,24 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) {
440440
collectAllDefs(selectedDialect, defRecords, defs);
441441
if (defs.empty())
442442
return false;
443+
{
444+
NamespaceEmitter nsEmitter(os, defs.front().getDialect());
443445

444-
NamespaceEmitter nsEmitter(os, defs.front().getDialect());
446+
// Declare all the def classes first (in case they reference each other).
447+
for (const AttrOrTypeDef &def : defs)
448+
os << " class " << def.getCppClassName() << ";\n";
445449

446-
// Declare all the def classes first (in case they reference each other).
450+
// Emit the declarations.
451+
for (const AttrOrTypeDef &def : defs)
452+
emitDefDecl(def);
453+
}
454+
// Emit the TypeID explicit specializations to have a single definition for
455+
// each of these.
447456
for (const AttrOrTypeDef &def : defs)
448-
os << " class " << def.getCppClassName() << ";\n";
457+
if (!def.getDialect().getCppNamespace().empty())
458+
os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
459+
<< "::" << def.getCppClassName() << ")\n";
449460

450-
// Emit the declarations.
451-
for (const AttrOrTypeDef &def : defs)
452-
emitDefDecl(def);
453461
return false;
454462
}
455463

@@ -934,8 +942,13 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
934942

935943
IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os);
936944
emitParsePrintDispatch(defs);
937-
for (const AttrOrTypeDef &def : defs)
945+
for (const AttrOrTypeDef &def : defs) {
938946
emitDefDef(def);
947+
// Emit the TypeID explicit specializations to have a single symbol def.
948+
if (!def.getDialect().getCppNamespace().empty())
949+
os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace()
950+
<< "::" << def.getCppClassName() << ")\n";
951+
}
939952

940953
return false;
941954
}

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -198,38 +198,44 @@ static void emitDialectDecl(Dialect &dialect,
198198
}
199199

200200
// Emit all nested namespaces.
201-
NamespaceEmitter nsEmitter(os, dialect);
202-
203-
// Emit the start of the decl.
204-
std::string cppName = dialect.getCppClassName();
205-
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
206-
dependentDialectRegistrations);
207-
208-
// Check for any attributes/types registered to this dialect. If there are,
209-
// add the hooks for parsing/printing.
210-
if (!dialectAttrs.empty())
211-
os << attrParserDecl;
212-
if (!dialectTypes.empty())
213-
os << typeParserDecl;
214-
215-
// Add the decls for the various features of the dialect.
216-
if (dialect.hasCanonicalizer())
217-
os << canonicalizerDecl;
218-
if (dialect.hasConstantMaterializer())
219-
os << constantMaterializerDecl;
220-
if (dialect.hasOperationAttrVerify())
221-
os << opAttrVerifierDecl;
222-
if (dialect.hasRegionArgAttrVerify())
223-
os << regionArgAttrVerifierDecl;
224-
if (dialect.hasRegionResultAttrVerify())
225-
os << regionResultAttrVerifierDecl;
226-
if (dialect.hasOperationInterfaceFallback())
227-
os << operationInterfaceFallbackDecl;
228-
if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
229-
os << *extraDecl;
230-
231-
// End the dialect decl.
232-
os << "};\n";
201+
{
202+
NamespaceEmitter nsEmitter(os, dialect);
203+
204+
// Emit the start of the decl.
205+
std::string cppName = dialect.getCppClassName();
206+
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
207+
dependentDialectRegistrations);
208+
209+
// Check for any attributes/types registered to this dialect. If there are,
210+
// add the hooks for parsing/printing.
211+
if (!dialectAttrs.empty())
212+
os << attrParserDecl;
213+
if (!dialectTypes.empty())
214+
os << typeParserDecl;
215+
216+
// Add the decls for the various features of the dialect.
217+
if (dialect.hasCanonicalizer())
218+
os << canonicalizerDecl;
219+
if (dialect.hasConstantMaterializer())
220+
os << constantMaterializerDecl;
221+
if (dialect.hasOperationAttrVerify())
222+
os << opAttrVerifierDecl;
223+
if (dialect.hasRegionArgAttrVerify())
224+
os << regionArgAttrVerifierDecl;
225+
if (dialect.hasRegionResultAttrVerify())
226+
os << regionResultAttrVerifierDecl;
227+
if (dialect.hasOperationInterfaceFallback())
228+
os << operationInterfaceFallbackDecl;
229+
if (llvm::Optional<StringRef> extraDecl =
230+
dialect.getExtraClassDeclaration())
231+
os << *extraDecl;
232+
233+
// End the dialect decl.
234+
os << "};\n";
235+
}
236+
if (!dialect.getCppNamespace().empty())
237+
os << "DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
238+
<< "::" << dialect.getCppClassName() << ")\n";
233239
}
234240

235241
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
@@ -263,6 +269,11 @@ static const char *const dialectDestructorStr = R"(
263269
)";
264270

265271
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
272+
// Emit the TypeID explicit specializations to have a single symbol def.
273+
if (!dialect.getCppNamespace().empty())
274+
os << "DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
275+
<< "::" << dialect.getCppClassName() << ")\n";
276+
266277
// Emit all nested namespaces.
267278
NamespaceEmitter nsEmitter(os, dialect);
268279

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,6 @@ OpEmitter::OpEmitter(const Operator &op,
650650
generateOpFormat(op, opClass);
651651
genSideEffectInterfaceMethods();
652652
}
653-
654653
void OpEmitter::emitDecl(
655654
const Operator &op, raw_ostream &os,
656655
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
@@ -2576,15 +2575,29 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
25762575
emitDecl);
25772576
for (auto *def : defs) {
25782577
Operator op(*def);
2579-
NamespaceEmitter emitter(os, op.getCppNamespace());
25802578
if (emitDecl) {
2581-
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2582-
OpOperandAdaptorEmitter::emitDecl(op, os);
2583-
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2579+
{
2580+
NamespaceEmitter emitter(os, op.getCppNamespace());
2581+
os << formatv(opCommentHeader, op.getQualCppClassName(),
2582+
"declarations");
2583+
OpOperandAdaptorEmitter::emitDecl(op, os);
2584+
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2585+
}
2586+
// Emit the TypeID explicit specialization to have a single definition.
2587+
if (!op.getCppNamespace().empty())
2588+
os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2589+
<< "::" << op.getCppClassName() << ")\n\n";
25842590
} else {
2585-
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2586-
OpOperandAdaptorEmitter::emitDef(op, os);
2587-
OpEmitter::emitDef(op, os, staticVerifierEmitter);
2591+
{
2592+
NamespaceEmitter emitter(os, op.getCppNamespace());
2593+
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2594+
OpOperandAdaptorEmitter::emitDef(op, os);
2595+
OpEmitter::emitDef(op, os, staticVerifierEmitter);
2596+
}
2597+
// Emit the TypeID explicit specialization to have a single definition.
2598+
if (!op.getCppNamespace().empty())
2599+
os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
2600+
<< "::" << op.getCppClassName() << ")\n\n";
25882601
}
25892602
}
25902603
}

0 commit comments

Comments
 (0)