diff --git a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp --- a/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandaloneDialect.cpp @@ -16,8 +16,7 @@ // Standalone dialect. //===----------------------------------------------------------------------===// -StandaloneDialect::StandaloneDialect(mlir::MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void StandaloneDialect::initialize() { addOperations< #define GET_OP_LIST #include "Standalone/StandaloneOps.cpp.inc" diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -26,7 +26,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -26,7 +26,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -75,7 +75,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -75,7 +75,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -75,7 +75,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -76,7 +76,8 @@ /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { addOperations< #define GET_OP_LIST #include "toy/Ops.cpp.inc" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -27,7 +27,9 @@ private: friend LLVMType; - std::unique_ptr impl; + // This can't be a unique_ptr because the ctor is generated inline + // in the class definition at the moment. + detail::LLVMDialectImpl *impl; }]; } diff --git a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h --- a/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h +++ b/mlir/include/mlir/Dialect/SDBM/SDBMDialect.h @@ -17,7 +17,8 @@ class SDBMDialect : public Dialect { public: - SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {} + SDBMDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) {} /// Since there are no other virtual methods in this derived class, override /// the destructor so that key methods get defined in the corresponding diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -14,6 +14,7 @@ #define MLIR_IR_DIALECT_H #include "mlir/IR/OperationSupport.h" +#include "mlir/Support/TypeID.h" namespace mlir { class DialectAsmParser; @@ -49,6 +50,9 @@ StringRef getNamespace() const { return name; } + /// Returns the unique identifier that corresponds to this dialect. + TypeID getTypeID() const { return dialectID; } + /// Returns true if this dialect allows for unregistered operations, i.e. /// operations prefixed with the dialect namespace but not registered with /// addOperation. @@ -177,7 +181,7 @@ /// with the namespace followed by '.'. /// Example: /// - "tf" for the TensorFlow ops like "tf.add". - Dialect(StringRef name, MLIRContext *context); + Dialect(StringRef name, MLIRContext *context, TypeID id); /// This method is used by derived classes to add their operations to the set. /// @@ -223,13 +227,13 @@ Dialect(const Dialect &) = delete; void operator=(Dialect &) = delete; - /// Register this dialect object with the specified context. The context - /// takes ownership of the heap allocated dialect. - void registerDialect(MLIRContext *context); - /// The namespace of this dialect. StringRef name; + /// The unique identifier of the derived Op class, this is used in the context + /// to allow registering multiple times the same dialect. + TypeID dialectID; + /// This is the context that owns this Dialect object. MLIRContext *context; @@ -255,7 +259,9 @@ const DialectAllocatorFunction &function); template friend void registerDialect(); + friend class MLIRContext; }; + /// Registers all dialects and hooks from the global registries with the /// specified MLIRContext. /// Note: This method is not thread-safe. @@ -265,12 +271,9 @@ /// global registry by calling registerDialect(); /// Note: This method is not thread-safe. template void registerDialect() { - Dialect::registerDialectAllocator(TypeID::get(), - [](MLIRContext *ctx) { - // Just allocate the dialect, the context - // takes ownership of it. - new ConcreteDialect(ctx); - }); + Dialect::registerDialectAllocator( + TypeID::get(), + [](MLIRContext *ctx) { ctx->getOrCreateDialect(); }); } /// DialectRegistration provides a global initializer that registers a Dialect @@ -291,7 +294,7 @@ template struct isa_impl { static inline bool doit(const ::mlir::Dialect &dialect) { - return T::getDialectNamespace() == dialect.getNamespace(); + return mlir::TypeID::get() == dialect.getTypeID(); } }; } // namespace llvm diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -10,6 +10,7 @@ #define MLIR_IR_MLIRCONTEXT_H #include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include #include #include @@ -49,6 +50,18 @@ return static_cast(getRegisteredDialect(T::getDialectNamespace())); } + /// Get (or create) a dialect for the given derived dialect type. The derived + /// type must provide a static 'getDialectNamespace' method. + template + T *getOrCreateDialect() { + return static_cast(getOrCreateDialect( + T::getDialectNamespace(), TypeID::get(), [this]() { + std::unique_ptr dialect(new T(this)); + dialect->dialectID = TypeID::get(); + return dialect; + })); + } + /// Return true if we allow to create operation for unregistered dialects. bool allowsUnregisteredDialects(); @@ -109,6 +122,12 @@ private: const std::unique_ptr impl; + /// Get a dialect for the provided namespace and TypeID: abort the program if + /// a dialect exist for this namespace with different TypeID. Returns a + /// pointer to the dialect owned by the context. + Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID, + function_ref()> ctor); + MLIRContext(const MLIRContext &) = delete; void operator=(const MLIRContext &) = delete; }; diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp --- a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp +++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp @@ -18,8 +18,7 @@ using namespace mlir; -avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void avx512::AVX512Dialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/AVX512/AVX512.cpp.inc" diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -68,8 +68,7 @@ // AffineDialect //===----------------------------------------------------------------------===// -AffineDialect::AffineDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void AffineDialect::initialize() { addOperations(isKernelAttr); } -GPUDialect::GPUDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void GPUDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/GPU/GPUOps.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp @@ -20,8 +20,7 @@ using namespace mlir; -LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void LLVM::LLVMAVX512Dialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1683,9 +1683,8 @@ } // end namespace LLVM } // end namespace mlir -LLVMDialect::LLVMDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context), - impl(new detail::LLVMDialectImpl()) { +void LLVMDialect::initialize() { + impl = new detail::LLVMDialectImpl(); // clang-format off addTypes(); addOperations< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -26,8 +26,7 @@ using namespace mlir; using namespace mlir::omp; -OpenMPDialect::OpenMPDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void OpenMPDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -23,8 +23,7 @@ using namespace mlir::quant; using namespace mlir::quant::detail; -QuantizationDialect::QuantizationDialect(MLIRContext *context) - : Dialect(/*name=*/"quant", context) { +void QuantizationDialect::initialize() { addTypes(); addOperations< diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -53,8 +53,7 @@ // SCFDialect //===----------------------------------------------------------------------===// -SCFDialect::SCFDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void SCFDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/SCF/SCFOps.cpp.inc" diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -112,8 +112,7 @@ // SPIR-V Dialect //===----------------------------------------------------------------------===// -SPIRVDialect::SPIRVDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void SPIRVDialect::initialize() { addTypes(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -59,8 +59,7 @@ return success(); } -ShapeDialect::ShapeDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void ShapeDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -145,8 +145,7 @@ return success(); } -StandardOpsDialect::StandardOpsDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void StandardOpsDialect::initialize() { addOperations()) { addAttributes(); } + static StringRef getDialectNamespace() { return ""; } }; } // end anonymous namespace. @@ -349,7 +351,7 @@ } // Register dialects with this context. - new BuiltinDialect(this); + getOrCreateDialect(); registerAllDialects(this); // Initialize several common attributes and types to avoid the need to lock @@ -446,25 +448,33 @@ : nullptr; } -/// Register this dialect object with the specified context. The context -/// takes ownership of the heap allocated dialect. -void Dialect::registerDialect(MLIRContext *context) { - auto &impl = context->getImpl(); - std::unique_ptr dialect(this); - +/// Get a dialect for the provided namespace and TypeID: abort the program if a +/// dialect exist for this namespace with different TypeID. Returns a pointer to +/// the dialect owned by the context. +Dialect * +MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID, + function_ref()> ctor) { + auto &impl = getImpl(); // Get the correct insertion position sorted by namespace. - auto insertPt = llvm::lower_bound( - impl.dialects, dialect, [](const auto &lhs, const auto &rhs) { - return lhs->getNamespace() < rhs->getNamespace(); - }); + auto insertPt = + llvm::lower_bound(impl.dialects, nullptr, + [&](const std::unique_ptr &lhs, + const std::unique_ptr &rhs) { + if (!lhs) + return dialectNamespace < rhs->getNamespace(); + return lhs->getNamespace() < dialectNamespace; + }); // Abort if dialect with namespace has already been registered. if (insertPt != impl.dialects.end() && - (*insertPt)->getNamespace() == getNamespace()) { - llvm::report_fatal_error("a dialect with namespace '" + getNamespace() + + (*insertPt)->getNamespace() == dialectNamespace) { + if ((*insertPt)->getTypeID() == dialectID) + return insertPt->get(); + llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + "' has already been registered"); } - impl.dialects.insert(insertPt, std::move(dialect)); + auto it = impl.dialects.insert(insertPt, ctor()); + return &**it; } bool MLIRContext::allowsUnregisteredDialects() { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -130,8 +130,7 @@ // TestDialect //===----------------------------------------------------------------------===// -TestDialect::TestDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void TestDialect::initialize() { addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -63,8 +63,14 @@ /// {1}: The dialect namespace. static const char *const dialectDeclBeginStr = R"( class {0} : public ::mlir::Dialect { + explicit {0}(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, + ::mlir::TypeID::get<{0}>()) {{ + initialize(); + } + void initialize(); + friend class ::mlir::MLIRContext; public: - explicit {0}(::mlir::MLIRContext *context); static ::llvm::StringRef getDialectNamespace() { return "{1}"; } )"; diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -14,7 +14,15 @@ namespace { struct TestDialect : public Dialect { - TestDialect(MLIRContext *context) : Dialect(/*name=*/"test", context) {} + static StringRef getDialectNamespace() { return "test"; }; + TestDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) {} +}; +struct AnotherTestDialect : public Dialect { + static StringRef getDialectNamespace() { return "test"; }; + AnotherTestDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, + TypeID::get()) {} }; TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { @@ -22,8 +30,8 @@ // Registering a dialect with the same namespace twice should result in a // failure. - new TestDialect(&context); - ASSERT_DEATH(new TestDialect(&context), ""); + context.getOrCreateDialect(); + ASSERT_DEATH(context.getOrCreateDialect(), ""); } } // end namespace