diff --git a/mlir/include/mlir/Support/TypeID.h b/mlir/include/mlir/Support/TypeID.h --- a/mlir/include/mlir/Support/TypeID.h +++ b/mlir/include/mlir/Support/TypeID.h @@ -137,6 +137,25 @@ } // end namespace mlir +// Declare/define an explicit specialization for TypeID: this forces the +// compiler to emit a strong definition for a class and controls which +// translation unit and shared object will actually have it. +// This is useful to turn into a link-time failure what would otherwise be a +// hard-to-catch runtime bug: when a TypeID is hidden in two different shared +// libraries, instances of the same class only get the same TypeID inside a +// given DSO. +#define DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME) \ + template <> \ + LLVM_EXTERNAL_VISIBILITY mlir::TypeID \ + mlir::detail::TypeIDExported::get<CLASS_NAME>(); +#define DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME) \ + template <> \ + LLVM_EXTERNAL_VISIBILITY mlir::TypeID \ + mlir::detail::TypeIDExported::get<CLASS_NAME>() { \ + static mlir::TypeID::Storage instance; \ + return mlir::TypeID(&instance); \ + } + namespace llvm { template <> struct DenseMapInfo<mlir::TypeID> { static mlir::TypeID getEmptyKey() { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -440,16 +440,24 @@ collectAllDefs(selectedDialect, defRecords, defs); if (defs.empty()) return false; + { + NamespaceEmitter nsEmitter(os, defs.front().getDialect()); - NamespaceEmitter nsEmitter(os, defs.front().getDialect()); + // Declare all the def classes first (in case they reference each other). + for (const AttrOrTypeDef &def : defs) + os << " class " << def.getCppClassName() << ";\n"; - // Declare all the def classes first (in case they reference each other). + // Emit the declarations.
+ for (const AttrOrTypeDef &def : defs) + emitDefDecl(def); + } + // Emit the TypeID explicit specializations to have a single definition for + // each of these. for (const AttrOrTypeDef &def : defs) - os << " class " << def.getCppClassName() << ";\n"; + if (!def.getDialect().getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() + << "::" << def.getCppClassName() << ")\n"; - // Emit the declarations. - for (const AttrOrTypeDef &def : defs) - emitDefDecl(def); return false; } @@ -934,8 +942,13 @@ IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); emitParsePrintDispatch(defs); - for (const AttrOrTypeDef &def : defs) + for (const AttrOrTypeDef &def : defs) { emitDefDef(def); + // Emit the TypeID explicit specializations to have a single symbol def. + if (!def.getDialect().getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() + << "::" << def.getCppClassName() << ")\n"; + } return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -198,38 +198,44 @@ } // Emit all nested namespaces. - NamespaceEmitter nsEmitter(os, dialect); - - // Emit the start of the decl. - std::string cppName = dialect.getCppClassName(); - os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), - dependentDialectRegistrations); - - // Check for any attributes/types registered to this dialect. If there are, - // add the hooks for parsing/printing. - if (!dialectAttrs.empty()) - os << attrParserDecl; - if (!dialectTypes.empty()) - os << typeParserDecl; - - // Add the decls for the various features of the dialect. 
- if (dialect.hasCanonicalizer()) - os << canonicalizerDecl; - if (dialect.hasConstantMaterializer()) - os << constantMaterializerDecl; - if (dialect.hasOperationAttrVerify()) - os << opAttrVerifierDecl; - if (dialect.hasRegionArgAttrVerify()) - os << regionArgAttrVerifierDecl; - if (dialect.hasRegionResultAttrVerify()) - os << regionResultAttrVerifierDecl; - if (dialect.hasOperationInterfaceFallback()) - os << operationInterfaceFallbackDecl; - if (llvm::Optional extraDecl = dialect.getExtraClassDeclaration()) - os << *extraDecl; - - // End the dialect decl. - os << "};\n"; + { + NamespaceEmitter nsEmitter(os, dialect); + + // Emit the start of the decl. + std::string cppName = dialect.getCppClassName(); + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), + dependentDialectRegistrations); + + // Check for any attributes/types registered to this dialect. If there are, + // add the hooks for parsing/printing. + if (!dialectAttrs.empty()) + os << attrParserDecl; + if (!dialectTypes.empty()) + os << typeParserDecl; + + // Add the decls for the various features of the dialect. + if (dialect.hasCanonicalizer()) + os << canonicalizerDecl; + if (dialect.hasConstantMaterializer()) + os << constantMaterializerDecl; + if (dialect.hasOperationAttrVerify()) + os << opAttrVerifierDecl; + if (dialect.hasRegionArgAttrVerify()) + os << regionArgAttrVerifierDecl; + if (dialect.hasRegionResultAttrVerify()) + os << regionResultAttrVerifierDecl; + if (dialect.hasOperationInterfaceFallback()) + os << operationInterfaceFallbackDecl; + if (llvm::Optional extraDecl = + dialect.getExtraClassDeclaration()) + os << *extraDecl; + + // End the dialect decl. 
+ os << "};\n"; + } + if (!dialect.getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() + << "::" << dialect.getCppClassName() << ")\n"; } static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, @@ -263,6 +269,11 @@ )"; static void emitDialectDef(Dialect &dialect, raw_ostream &os) { + // Emit the TypeID explicit specializations to have a single symbol def. + if (!dialect.getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() + << "::" << dialect.getCppClassName() << ")\n"; + // Emit all nested namespaces. NamespaceEmitter nsEmitter(os, dialect); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -650,7 +650,6 @@ generateOpFormat(op, opClass); genSideEffectInterfaceMethods(); } - void OpEmitter::emitDecl( const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { @@ -2576,15 +2575,29 @@ emitDecl); for (auto *def : defs) { Operator op(*def); - NamespaceEmitter emitter(os, op.getCppNamespace()); if (emitDecl) { - os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - OpOperandAdaptorEmitter::emitDecl(op, os); - OpEmitter::emitDecl(op, os, staticVerifierEmitter); + { + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << formatv(opCommentHeader, op.getQualCppClassName(), + "declarations"); + OpOperandAdaptorEmitter::emitDecl(op, os); + OpEmitter::emitDecl(op, os, staticVerifierEmitter); + } + // Emit the TypeID explicit specialization to have a single definition. 
+ if (!op.getCppNamespace().empty()) + os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() + << "::" << op.getCppClassName() << ")\n\n"; } else { - os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - OpOperandAdaptorEmitter::emitDef(op, os); - OpEmitter::emitDef(op, os, staticVerifierEmitter); + { + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); + OpOperandAdaptorEmitter::emitDef(op, os); + OpEmitter::emitDef(op, os, staticVerifierEmitter); + } + // Emit the TypeID explicit specialization to have a single definition. + if (!op.getCppNamespace().empty()) + os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() + << "::" << op.getCppClassName() << ")\n\n"; } } }