diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -76,7 +76,7 @@ if (showDialects) { mlir::MLIRContext context; llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { + for (mlir::Dialect *dialect : context.getLoadedDialects()) { llvm::outs() << dialect->getNamespace() << "\n"; } return 0; diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp --- a/mlir/examples/toy/Ch2/toyc.cpp +++ b/mlir/examples/toy/Ch2/toyc.cpp @@ -68,10 +68,9 @@ } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect(); - mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -102,10 +102,10 @@ } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect(); - mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp --- a/mlir/examples/toy/Ch4/toyc.cpp +++ b/mlir/examples/toy/Ch4/toyc.cpp @@ -103,10 +103,10 @@ } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect(); - mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. 
+ context.getOrLoadDialect(); + mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -256,6 +256,10 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } void runOnFunction() final; }; } // end anonymous namespace. diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -106,10 +106,10 @@ } int dumpMLIR() { - // Register our Dialect with MLIR. - mlir::registerDialect(); - mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + mlir::OwningModuleRef module; llvm::SourceMgr sourceMgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -255,6 +255,10 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } void runOnFunction() final; }; } // end anonymous namespace. 
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -159,6 +159,10 @@ namespace { struct ToyToLLVMLoweringPass : public PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -255,10 +255,10 @@ // If we aren't dumping the AST, then we are compiling with/to MLIR. - // Register our Dialect with MLIR. - mlir::registerDialect(); - mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + mlir::OwningModuleRef module; if (int error = loadAndProcessMLIR(context, module)) return error; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -256,6 +256,10 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } void runOnFunction() final; }; } // end anonymous namespace. 
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -159,6 +159,10 @@ namespace { struct ToyToLLVMLoweringPass : public PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } void runOnOperation() final; }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -256,10 +256,11 @@ // If we aren't dumping the AST, then we are compiling with/to MLIR. - // Register our Dialect with MLIR. - mlir::registerDialect(); mlir::MLIRContext context; + // Load our Dialect in this MLIR Context. + context.getOrLoadDialect(); + mlir::OwningModuleRef module; if (int error = loadAndProcessMLIR(context, module)) return error; diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -90,6 +90,12 @@ /** Takes an MLIR context owned by the caller and destroys it. */ void mlirContextDestroy(MlirContext context); +/** Load all the globally registered dialects in the provided context. + * TODO: remove the concept of globally registered dialect by exposing the + * DialectRegistry. + */ +void mlirContextLoadAllDialects(MlirContext context); + /*============================================================================*/ /* Location API. */ /*============================================================================*/ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -66,6 +66,11 @@ `affine.apply`. 
}]; let constructor = "mlir::createLowerAffinePass()"; + let dependentDialects = [ + "scf::SCFDialect", + "StandardOpsDialect", + "vector::VectorDialect" + ]; } //===----------------------------------------------------------------------===// @@ -76,6 +81,7 @@ let summary = "Convert the operations from the avx512 dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertAVX512ToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; } //===----------------------------------------------------------------------===// @@ -98,6 +104,7 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; + let dependentDialects = ["NVVM::NVVMDialect"]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -112,6 +119,7 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> { let summary = "Generate ROCDL operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()"; + let dependentDialects = ["ROCDL::ROCDLDialect"]; let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", @@ -126,6 +134,7 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> { let summary = "Convert GPU dialect to SPIR-V dialect"; let constructor = "mlir::createConvertGPUToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -136,6 +145,7 @@ : Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> { let summary = "Convert gpu.launch_func to vulkanLaunch external call"; let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } def 
ConvertVulkanLaunchFuncToVulkanCalls @@ -143,6 +153,7 @@ let summary = "Convert vulkanLaunch external call to Vulkan runtime external " "calls"; let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -153,6 +164,7 @@ let summary = "Convert the operations from the linalg dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertLinalgToLLVMPass()"; + let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -163,6 +175,7 @@ let summary = "Convert the operations from the linalg dialect into the " "Standard dialect"; let constructor = "mlir::createConvertLinalgToStandardPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -172,6 +185,7 @@ def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> { let summary = "Convert Linalg ops to SPIR-V ops"; let constructor = "mlir::createLinalgToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -182,6 +196,7 @@ let summary = "Convert SCF dialect to Standard dialect, replacing structured" " control flow with a CFG"; let constructor = "mlir::createLowerToCFGPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -191,6 +206,7 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> { let summary = "Convert top-level AffineFor Ops to GPU kernels"; let constructor = "mlir::createAffineForToGPUPass()"; + let dependentDialects = ["gpu::GPUDialect"]; let options = [ Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u", "Number of GPU block dimensions for 
mapping">, @@ -202,6 +218,7 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> { let summary = "Convert mapped scf.parallel ops to gpu launch operations"; let constructor = "mlir::createParallelLoopToGpuPass()"; + let dependentDialects = ["AffineDialect", "gpu::GPUDialect"]; } //===----------------------------------------------------------------------===// @@ -212,6 +229,7 @@ let summary = "Convert operations from the shape dialect into the standard " "dialect"; let constructor = "mlir::createConvertShapeToStandardPass()"; + let dependentDialects = ["StandardOpsDialect"]; } //===----------------------------------------------------------------------===// @@ -221,6 +239,7 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> { let summary = "Convert operations from the shape dialect to the SCF dialect"; let constructor = "mlir::createConvertShapeToSCFPass()"; + let dependentDialects = ["scf::SCFDialect"]; } //===----------------------------------------------------------------------===// @@ -230,6 +249,7 @@ def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> { let summary = "Convert SPIR-V dialect to LLVM dialect"; let constructor = "mlir::createConvertSPIRVToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// @@ -264,6 +284,7 @@ LLVM IR types. 
}]; let constructor = "mlir::createLowerToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false", "Use aligned_alloc in place of malloc for heap allocations">, @@ -287,11 +308,13 @@ def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> { let summary = "Legalize standard ops for SPIR-V lowering"; let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> { let summary = "Convert Standard Ops to SPIR-V dialect"; let constructor = "mlir::createConvertStandardToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; } //===----------------------------------------------------------------------===// @@ -302,6 +325,7 @@ let summary = "Lower the operations from the vector dialect into the SCF " "dialect"; let constructor = "mlir::createConvertVectorToSCFPass()"; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; let options = [ Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false", "Perform full unrolling when converting vector transfers to SCF">, @@ -316,6 +340,7 @@ let summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertVectorToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", @@ -331,6 +356,7 @@ let summary = "Lower the operations from the vector dialect into the ROCDL " "dialect"; let constructor = "mlir::createConvertVectorToROCDLPass()"; + let dependentDialects = ["ROCDL::ROCDLDialect"]; } #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ 
b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -94,6 +94,7 @@ def AffineVectorize : FunctionPass<"affine-super-vectorize"> { let summary = "Vectorize to a target independent n-D vector abstraction"; let constructor = "mlir::createSuperVectorizePass()"; + let dependentDialects = ["vector::VectorDialect"]; let options = [ ListOption<"vectorSizes", "virtual-vector-size", "int64_t", "Specify an n-D virtual vector size for vectorization", diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -19,6 +19,11 @@ def LLVM_Dialect : Dialect { let name = "llvm"; let cppNamespace = "LLVM"; + + /// FIXME: at the moment this is a dependency of the translation to LLVM IR, + /// not really one of this dialect per-se. 
+ let dependentDialects = [ "omp::OpenMPDialect" ]; + let hasRegionArgAttrVerify = 1; let extraClassDeclaration = [{ ~LLVMDialect(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -23,6 +23,7 @@ def NVVM_Dialect : Dialect { let name = "nvvm"; let cppNamespace = "NVVM"; + let dependentDialects = [ "LLVM::LLVMDialect" ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -22,6 +22,7 @@ #ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ #define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -23,6 +23,7 @@ def ROCDL_Dialect : Dialect { let name = "rocdl"; let cppNamespace = "ROCDL"; + let dependentDialects = [ "LLVM::LLVMDialect" ]; } //===----------------------------------------------------------------------===// diff --git 
a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -30,17 +30,20 @@ def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> { let summary = "Fuse operations on RankedTensorType in linalg dialect"; let constructor = "mlir::createLinalgFusionOfTensorOpsPass()"; + let dependentDialects = ["AffineDialect"]; } def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; + let dependentDialects = ["AffineDialect"]; } def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; + let dependentDialects = ["scf::SCFDialect", "AffineDialect"]; } def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> { @@ -54,6 +57,7 @@ let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; } def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> { @@ -70,6 +74,9 @@ def LinalgTiling : FunctionPass<"linalg-tile"> { let summary = "Tile operations in the linalg dialect"; let constructor = "mlir::createLinalgTilingPass()"; + let dependentDialects = [ + "AffineDialect", "scf::SCFDialect" + ]; let options = [ ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Test generation of dynamic promoted buffers", @@ -86,6 +93,7 @@ "Test generation of dynamic promoted buffers", "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> ]; + let dependentDialects = ["AffineDialect", "scf::SCFDialect"]; } #endif // MLIR_DIALECT_LINALG_PASSES diff --git 
a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -36,6 +36,7 @@ "Factors to tile parallel loops by", "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> ]; + let dependentDialects = ["AffineDialect"]; } #endif // MLIR_DIALECT_SCF_PASSES diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -16,6 +16,8 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/Support/TypeID.h" +#include + namespace mlir { class DialectAsmParser; class DialectAsmPrinter; @@ -23,7 +25,7 @@ class OpBuilder; class Type; -using DialectAllocatorFunction = std::function; +using DialectAllocatorFunction = std::function; /// Dialects are groups of MLIR operations and behavior associated with the /// entire group. For example, hooks into other systems for constant folding, @@ -212,30 +214,82 @@ /// A collection of registered dialect interfaces. DenseMap> registeredInterfaces; - /// Registers a specific dialect creation function with the global registry. - /// Used through the registerDialect template. - /// Registrations are deduplicated by dialect TypeID and only the first - /// registration will be used. - static void - registerDialectAllocator(TypeID typeID, - const DialectAllocatorFunction &function); - template friend void registerDialect(); friend class MLIRContext; }; -/// Registers all dialects and hooks from the global registries with the -/// specified MLIRContext. +/// The DialectRegistry maps a dialect namespace to a constructor for the +/// matching dialect. +/// This allows for decoupling the list of dialects "available" from the +/// dialects loaded in the Context. The parser in particular will lazily load +/// dialects in the Context as operations are encountered. 
+class DialectRegistry { + using MapTy = + std::map>; + +public: + template + void insert() { + insert(TypeID::get(), + ConcreteDialect::getDialectNamespace(), + static_cast(([](MLIRContext *ctx) { + // Just allocate the dialect, the context + // takes ownership of it. + return ctx->getOrLoadDialect(); + }))); + } + + /// Add a new dialect constructor to the registry. + void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor); + + /// Load a dialect for this namespace in the provided context. + Dialect *loadByName(StringRef name, MLIRContext *context); + + // Register all dialects available in the current registry with the + // provided destination registry. + void appendTo(DialectRegistry &destination) { + for (const auto &name_and_registration_it : registry) + destination.insert(name_and_registration_it.second.first, + name_and_registration_it.first, + name_and_registration_it.second.second); + } + // Load all dialects available in the registry in the provided context. + void loadAll(MLIRContext *context) { + for (const auto &name_and_registration_it : registry) + name_and_registration_it.second.second(context); + } + + MapTy::const_iterator begin() const { return registry.begin(); } + MapTy::const_iterator end() const { return registry.end(); } + +private: + MapTy registry; +}; + +/// Deprecated: this provides a global registry for convenience, while we're +/// transitioning the registration mechanism to a stateless approach. +DialectRegistry &getGlobalDialectRegistry(); + +/// Registers all dialects from the global registries with the +/// specified MLIRContext. This won't load the dialects in the context, +/// but only makes them available for lazy loading by name. /// Note: This method is not thread-safe. 
-void registerAllDialects(MLIRContext *context); +inline void registerAllDialects(MLIRContext *context) { + getGlobalDialectRegistry().appendTo(context->getDialectRegistry()); +} + +/// Load and return the dialect with the given namespace in the provided +/// context. Returns nullptr if there is no constructor registered for this +/// dialect. +inline Dialect *registerDialect(StringRef name, MLIRContext *context) { + return getGlobalDialectRegistry().loadByName(name, context); +} /// Utility to register a dialect. Client can register their dialect with the /// global registry by calling registerDialect(); /// Note: This method is not thread-safe. template void registerDialect() { - Dialect::registerDialectAllocator( - TypeID::get(), - [](MLIRContext *ctx) { ctx->getOrCreateDialect(); }); + getGlobalDialectRegistry().insert(); } /// DialectRegistration provides a global initializer that registers a Dialect diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -428,7 +428,7 @@ if (!attr.first.strref().contains('.')) return funcOp.emitOpError("arguments may only have dialect attributes"); auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, /*argIndex=*/i, attr))) return failure(); @@ -444,7 +444,7 @@ if (!attr.first.strref().contains('.')) return funcOp.emitOpError("results may only have dialect attributes"); auto dialectNamePair = attr.first.strref().split('.'); - if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) { + if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) { if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, /*resultIndex=*/i, attr))) diff --git 
a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -19,10 +19,12 @@ class AbstractOperation; class DiagnosticEngine; class Dialect; +class DialectRegistry; class InFlightDiagnostic; class Location; class MLIRContextImpl; class StorageUniquer; +DialectRegistry &getGlobalDialectRegistry(); /// MLIRContext is the top-level object for a collection of MLIR modules. It /// holds immortal uniqued objects like types, and the tables used to unique @@ -34,34 +36,52 @@ /// class MLIRContext { public: - explicit MLIRContext(); + explicit MLIRContext(bool loadAllDialects = true); ~MLIRContext(); - /// Return information about all registered IR dialects. - std::vector getRegisteredDialects(); + /// Return information about all IR dialects loaded in the context. + std::vector getLoadedDialects(); + + /// Return the dialect registry associated with this context. + DialectRegistry &getDialectRegistry(); + + /// Return information about all available dialects in the registry in this + /// context. + std::vector getAvailableDialects(); /// Get a registered IR dialect with the given namespace. If an exact match is /// not found, then return nullptr. - Dialect *getRegisteredDialect(StringRef name); + Dialect *getLoadedDialect(StringRef name); /// Get a registered IR dialect for the given derived dialect type. The /// derived type must provide a static 'getDialectNamespace' method. - template T *getRegisteredDialect() { - return static_cast(getRegisteredDialect(T::getDialectNamespace())); + template + T *getLoadedDialect() { + return static_cast(getLoadedDialect(T::getDialectNamespace())); } /// Get (or create) a dialect for the given derived dialect type. The derived /// type must provide a static 'getDialectNamespace' method. 
template - T *getOrCreateDialect() { - return static_cast(getOrCreateDialect( - T::getDialectNamespace(), TypeID::get(), [this]() { + T *getOrLoadDialect() { + return static_cast( + getOrLoadDialect(T::getDialectNamespace(), TypeID::get(), [this]() { std::unique_ptr dialect(new T(this)); - dialect->dialectID = TypeID::get(); return dialect; })); } + /// Deprecated: load all globally registered dialects into this context. + /// This method will be removed soon, it can be used temporarily as we're + /// phasing out the global registry. + void loadAllGloballyRegisteredDialects(); + + /// Get (or create) a dialect for the given derived dialect name. + /// The dialect will be loaded from the registry if no dialect is found. + /// If no dialect is loaded for this name and none is available in the + /// registry, returns nullptr. + Dialect *getOrLoadDialect(StringRef name); + /// Return true if we allow to create operation for unregistered dialects. bool allowsUnregisteredDialects(); @@ -123,10 +143,12 @@ const std::unique_ptr impl; /// Get a dialect for the provided namespace and TypeID: abort the program if - /// a dialect exist for this namespace with different TypeID. Returns a - /// pointer to the dialect owned by the context. - Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID, - function_ref()> ctor); + /// a dialect exist for this namespace with different TypeID. If a dialect has + /// not been loaded for this namespace/TypeID yet, use the provided ctor to + /// create one on the fly and load it. Returns a pointer to the dialect owned + /// by the context. 
+ Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, + function_ref()> ctor); MLIRContext(const MLIRContext &) = delete; void operator=(const MLIRContext &) = delete; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -244,6 +244,11 @@ // The description of the dialect. string description = ?; + // A list of dialects this dialect will load on construction as dependencies. + // These are dialects that this dialect may involve in canonicalization + // patterns or interfaces. + list dependentDialects = []; + // The C++ namespace that ops of this dialect should be placed into. // // By default, uses the name of the dialect as the only namespace. To avoid diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -35,29 +35,32 @@ namespace mlir { +// Add all the MLIR dialects to the provided registry. +inline void registerAllDialects(DialectRegistry ®istry) { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); +} + // This function should be called before creating any MLIRContext if one expect // all the possible dialects to be made available to the context automatically. 
inline void registerAllDialects() { - static bool init_once = []() { - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - registerDialect(); - return true; - }(); + static bool init_once = + ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true); (void)init_once; } } // namespace mlir diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -9,6 +9,7 @@ #ifndef MLIR_PASS_PASS_H #define MLIR_PASS_PASS_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/Pass/AnalysisManager.h" #include "mlir/Pass/PassRegistry.h" @@ -57,6 +58,13 @@ /// Returns the derived pass name. virtual StringRef getName() const = 0; + /// Register dependent dialects for the current pass. + /// A pass is expected to register the dialects it will create operations for, + /// other than dialects that exist in the input. For example, a pass that + /// converts from Linalg to Affine would register the Affine dialect but does + /// not need to register Linalg. + virtual void getDependentDialects(DialectRegistry ®istry) const {} + /// Returns the command line argument used when registering this pass. Return /// an empty string if one does not exist. virtual StringRef getArgument() const { diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -78,6 +78,9 @@ // A C++ constructor call to create an instance of this pass. code constructor = [{}]; + // A list of dialects this pass may produce operations in. + list dependentDialects = []; + // A set of options provided by this pass. list