diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h --- a/flang/include/flang/Optimizer/Support/InitFIR.h +++ b/flang/include/flang/Optimizer/Support/InitFIR.h @@ -18,6 +18,7 @@ #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/InitAllDialects.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" @@ -34,20 +35,28 @@ mlir::vector::VectorDialect, mlir::math::MathDialect, \ mlir::complex::ComplexDialect, mlir::DLTIDialect +#define FLANG_CODEGEN_DIALECT_LIST FIRCodeGenDialect, mlir::LLVM::LLVMDialect + // The definitive list of dialects used by flang. #define FLANG_DIALECT_LIST \ - FLANG_NONCODEGEN_DIALECT_LIST, FIRCodeGenDialect, mlir::LLVM::LLVMDialect + FLANG_NONCODEGEN_DIALECT_LIST, FLANG_CODEGEN_DIALECT_LIST inline void registerNonCodegenDialects(mlir::DialectRegistry &registry) { registry.insert(); + mlir::func::registerInlinerExtension(registry); } /// Register all the dialects used by flang. inline void registerDialects(mlir::DialectRegistry &registry) { - registry.insert(); + registerNonCodegenDialects(registry); + registry.insert(); } inline void loadNonCodegenDialects(mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registerNonCodegenDialects(registry); + context.appendDialectRegistry(registry); + context.loadDialect(); } @@ -55,6 +64,10 @@ /// pass, but a producer of FIR and MLIR. It is therefore a requirement that the /// dialects be preloaded to be able to build the IR.
inline void loadDialects(mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registerDialects(registry); + context.appendDialectRegistry(registry); + context.loadDialect(); } diff --git a/flang/lib/Frontend/CMakeLists.txt b/flang/lib/Frontend/CMakeLists.txt --- a/flang/lib/Frontend/CMakeLists.txt +++ b/flang/lib/Frontend/CMakeLists.txt @@ -1,4 +1,5 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_flang_library(flangFrontend CompilerInstance.cpp @@ -18,6 +19,7 @@ HLFIRDialect MLIRIR ${dialect_libs} + ${extension_libs} LINK_LIBS FortranParser @@ -39,6 +41,7 @@ MLIRSCFToControlFlow MLIRTargetLLVMIRImport ${dialect_libs} + ${extension_libs} LINK_COMPONENTS Passes diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -1,4 +1,5 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_flang_library(FortranLower Allocatable.cpp @@ -33,6 +34,7 @@ FIRTransforms HLFIRDialect ${dialect_libs} + ${extension_libs} LINK_LIBS FIRDialect @@ -42,6 +44,7 @@ FIRTransforms HLFIRDialect ${dialect_libs} + ${extension_libs} FortranCommon FortranParser FortranEvaluate diff --git a/flang/lib/Optimizer/Builder/CMakeLists.txt b/flang/lib/Optimizer/Builder/CMakeLists.txt --- a/flang/lib/Optimizer/Builder/CMakeLists.txt +++ b/flang/lib/Optimizer/Builder/CMakeLists.txt @@ -1,4 +1,5 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_flang_library(FIRBuilder BoxValue.cpp @@ -31,6 +32,7 @@ FIRDialect HLFIRDialect ${dialect_libs} + ${extension_libs} LINK_LIBS FIRDialect @@ -38,4 +40,5 @@ FIRSupport HLFIRDialect ${dialect_libs} + ${extension_libs} ) diff --git a/flang/lib/Optimizer/Support/CMakeLists.txt b/flang/lib/Optimizer/Support/CMakeLists.txt 
--- a/flang/lib/Optimizer/Support/CMakeLists.txt +++ b/flang/lib/Optimizer/Support/CMakeLists.txt @@ -1,4 +1,5 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_flang_library(FIRSupport InitFIR.cpp @@ -9,9 +10,11 @@ HLFIROpsIncGen MLIRIR ${dialect_libs} + ${extension_libs} LINK_LIBS ${dialect_libs} + ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIROpenACCToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation diff --git a/flang/tools/bbc/CMakeLists.txt b/flang/tools/bbc/CMakeLists.txt --- a/flang/tools/bbc/CMakeLists.txt +++ b/flang/tools/bbc/CMakeLists.txt @@ -10,6 +10,7 @@ llvm_update_compile_flags(bbc) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(bbc PRIVATE FIRDialect FIRDialectSupport @@ -19,6 +20,7 @@ HLFIRDialect HLFIRTransforms ${dialect_libs} +${extension_libs} MLIRAffineToStandard MLIRSCFToControlFlow FortranCommon diff --git a/flang/tools/fir-opt/CMakeLists.txt b/flang/tools/fir-opt/CMakeLists.txt --- a/flang/tools/fir-opt/CMakeLists.txt +++ b/flang/tools/fir-opt/CMakeLists.txt @@ -1,6 +1,7 @@ add_flang_tool(fir-opt fir-opt.cpp) llvm_update_compile_flags(fir-opt) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) if(FLANG_INCLUDE_TESTS) set(test_libs @@ -18,6 +19,7 @@ FIRAnalysis ${test_libs} ${dialect_libs} + ${extension_libs} # TODO: these should be transitive dependencies from a target providing # "registerFIRPasses()" diff --git a/flang/tools/tco/CMakeLists.txt b/flang/tools/tco/CMakeLists.txt --- a/flang/tools/tco/CMakeLists.txt +++ b/flang/tools/tco/CMakeLists.txt @@ -5,6 +5,7 @@ add_flang_tool(tco tco.cpp) llvm_update_compile_flags(tco) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(tco 
PRIVATE FIRCodeGen FIRDialect @@ -15,13 +16,13 @@ HLFIRDialect HLFIRTransforms ${dialect_libs} + ${extension_libs} MLIRIR MLIRLLVMDialect MLIRBuiltinToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRTargetLLVMIRExport MLIRPass - MLIRFuncToLLVM MLIRTransforms MLIRAffineToStandard MLIRAnalysis diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt --- a/flang/unittests/Optimizer/CMakeLists.txt +++ b/flang/unittests/Optimizer/CMakeLists.txt @@ -1,4 +1,5 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(LIBS FIRBuilder @@ -8,6 +9,7 @@ FIRSupport HLFIRDialect ${dialect_libs} + ${extension_libs} LLVMTargetParser ) diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -588,6 +588,12 @@ add_mlir_library(${ARGV} DEPENDS mlir-headers) endfunction(add_mlir_conversion_library) +# Declare the library associated with an extension. +function(add_mlir_extension_library name) + set_property(GLOBAL APPEND PROPERTY MLIR_EXTENSION_LIBS ${name}) + add_mlir_library(${ARGV} DEPENDS mlir-headers) +endfunction(add_mlir_extension_library) + # Declare the library associated with a translation. 
function(add_mlir_translation_library name) set_property(GLOBAL APPEND PROPERTY MLIR_TRANSLATION_LIBS ${name}) diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -24,6 +24,7 @@ get_property(MLIR_ALL_LIBS GLOBAL PROPERTY MLIR_ALL_LIBS) get_property(MLIR_DIALECT_LIBS GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(MLIR_CONVERSION_LIBS GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(MLIR_EXTENSION_LIBS GLOBAL PROPERTY MLIR_EXTENSION_LIBS) get_property(MLIR_TRANSLATION_LIBS GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) # Generate MlirConfig.cmake for the build tree. diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in --- a/mlir/cmake/modules/MLIRConfig.cmake.in +++ b/mlir/cmake/modules/MLIRConfig.cmake.in @@ -21,6 +21,7 @@ set_property(GLOBAL PROPERTY MLIR_ALL_LIBS "@MLIR_ALL_LIBS@") set_property(GLOBAL PROPERTY MLIR_DIALECT_LIBS "@MLIR_DIALECT_LIBS@") set_property(GLOBAL PROPERTY MLIR_CONVERSION_LIBS "@MLIR_CONVERSION_LIBS@") +set_property(GLOBAL PROPERTY MLIR_EXTENSION_LIBS "@MLIR_EXTENSION_LIBS@") set_property(GLOBAL PROPERTY MLIR_TRANSLATION_LIBS "@MLIR_TRANSLATION_LIBS@") # Provide all our library targets to users. 
diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -28,9 +28,11 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch5 PRIVATE ${dialect_libs} + ${extension_libs} MLIRAnalysis MLIRCallInterfaces MLIRCastInterfaces diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "toy/Dialect.h" #include "toy/MLIRGen.h" #include "toy/Parser.h" @@ -107,7 +108,10 @@ } int dumpMLIR() { - mlir::MLIRContext context; + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); + + mlir::MLIRContext context(registry); // Load our Dialect in this MLIR Context. 
context.getOrLoadDialect(); diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt --- a/mlir/examples/toy/Ch6/CMakeLists.txt +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -39,10 +39,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch6 PRIVATE ${dialect_libs} ${conversion_libs} + ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "toy/Dialect.h" #include "toy/MLIRGen.h" #include "toy/Parser.h" @@ -289,8 +290,10 @@ return dumpAST(); // If we aren't dumping the AST, then we are compiling with/to MLIR. + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); - mlir::MLIRContext context; + mlir::MLIRContext context(registry); // Load our Dialect in this MLIR Context. 
context.getOrLoadDialect(); diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt --- a/mlir/examples/toy/Ch7/CMakeLists.txt +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -38,10 +38,12 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch7 PRIVATE ${dialect_libs} ${conversion_libs} + ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "toy/Dialect.h" #include "toy/MLIRGen.h" #include "toy/Parser.h" @@ -290,8 +291,10 @@ return dumpAST(); // If we aren't dumping the AST, then we are compiling with/to MLIR. + mlir::DialectRegistry registry; + mlir::func::registerAllExtensions(registry); - mlir::MLIRContext context; + mlir::MLIRContext context(registry); // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); diff --git a/mlir/include/mlir/Dialect/Func/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/AllExtensions.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/Extensions/AllExtensions.h @@ -0,0 +1,30 @@ +//===- AllExtensions.h - All Func Extensions --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a common entry point for registering all extensions to the +// func dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H +#define MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H + +namespace mlir { +class DialectRegistry; + +namespace func { +/// Register all extensions of the func dialect. This should generally only be +/// used by tools, or other use cases that really do want *all* extensions of +/// the dialect. All other cases should prefer to instead register the specific +/// extensions they intend to take advantage of. +void registerAllExtensions(DialectRegistry &registry); +} // namespace func + +} // namespace mlir + +#endif // MLIR_DIALECT_FUNC_EXTENSIONS_ALLEXTENSIONS_H diff --git a/mlir/include/mlir/Dialect/Func/Extensions/InlinerExtension.h b/mlir/include/mlir/Dialect/Func/Extensions/InlinerExtension.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/Extensions/InlinerExtension.h @@ -0,0 +1,27 @@ +//===- InlinerExtension.h - Func Inliner Extension --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an extension for the func dialect that implements the +// interfaces necessary to support inlining.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H +#define MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H + +namespace mlir { +class DialectRegistry; + +namespace func { +/// Register the extension used to support inlining the func dialect. +void registerInlinerExtension(DialectRegistry &registry); +} // namespace func + +} // namespace mlir + +#endif // MLIR_DIALECT_FUNC_EXTENSIONS_INLINEREXTENSION_H diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -21,7 +21,6 @@ def Func_Dialect : Dialect { let name = "func"; let cppNamespace = "::mlir::func"; - let dependentDialects = ["cf::ControlFlowDialect"]; let hasConstantMaterializer = 1; let usePropertiesForAttributes = 1; } diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -285,6 +285,14 @@ private: /// Returns the impl interface instance for the given type. static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) { +#ifndef NDEBUG + // Check that the current interface isn't an unresolved promise for the + // given attribute. + dialect_extension_detail::handleUseOfUndefinedPromisedInterface( + attr.getDialect(), ConcreteType::getInterfaceID(), + llvm::getTypeName()); +#endif + return attr.getAbstractAttribute().getInterface(); } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -159,11 +159,20 @@ /// Lookup an interface for the given ID if one is registered, otherwise /// nullptr.
DialectInterface *getRegisteredInterface(TypeID interfaceID) { +#ifndef NDEBUG + handleUseOfUndefinedPromisedInterface(interfaceID); +#endif + auto it = registeredInterfaces.find(interfaceID); return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; } template InterfaceT *getRegisteredInterface() { +#ifndef NDEBUG + handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(), + llvm::getTypeName()); +#endif + return static_cast( getRegisteredInterface(InterfaceT::getInterfaceID())); } @@ -196,6 +205,37 @@ return *interface; } + /// Declare that the given interface will be implemented, but has a delayed + /// registration. The promised interface type can be an interface of any type + /// not just a dialect interface, i.e. it may also be an + /// AttributeInterface/OpInterface/TypeInterface/etc. + template + void declarePromisedInterface() { + unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID()); + } + + /// Checks if the given interface, which is attempting to be used, is a + /// promised interface of this dialect that has yet to be implemented. If so, + /// emits a fatal error. `interfaceName` is an optional string that contains a + /// more user readable name for the interface (such as the class name). + void handleUseOfUndefinedPromisedInterface(TypeID interfaceID, + StringRef interfaceName = "") { + if (unresolvedPromisedInterfaces.count(interfaceID)) { + llvm::report_fatal_error( + "checking for an interface (`" + interfaceName + + "`) that was promised by dialect '" + getNamespace() + + "' but never implemented. This is generally an indication " + "that the dialect extension implementing the interface was never " + "registered."); + } + } + /// Checks if the given interface, which is attempting to be attached to a + /// construct owned by this dialect, is a promised interface of this dialect + /// that has yet to be implemented. If so, it resolves the interface promise. 
+ void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) { + unresolvedPromisedInterfaces.erase(interfaceID); + } + protected: /// The constructor takes a unique namespace for this dialect as well as the /// context to bind to. @@ -289,6 +329,11 @@ /// A collection of registered dialect interfaces. DenseMap> registeredInterfaces; + /// A set of interfaces that the dialect (or its constructs, i.e. + /// Attributes/Operations/Types/etc.) has promised to implement, but has yet + /// to provide an implementation for. + DenseSet unresolvedPromisedInterfaces; + friend class DialectRegistry; friend void registerDialect(); friend class MLIRContext; diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -97,6 +97,22 @@ } }; +namespace dialect_extension_detail { + +/// Checks if the given interface, which is attempting to be used, is a +/// promised interface of this dialect that has yet to be implemented. If so, +/// emits a fatal error. +void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID, + StringRef interfaceName); + +/// Checks if the given interface, which is attempting to be attached, is a +/// promised interface of this dialect that has yet to be implemented. If so, +/// the promised interface is marked as resolved. 
+void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, + TypeID interfaceID); + +} // namespace dialect_extension_detail + //===----------------------------------------------------------------------===// // DialectRegistry //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -2083,6 +2083,16 @@ static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { OperationName name = op->getName(); +#ifndef NDEBUG + // Check that the current interface isn't an unresolved promise for the + // given operation. + if (Dialect *dialect = name.getDialect()) { + dialect_extension_detail::handleUseOfUndefinedPromisedInterface( + *dialect, ConcreteType::getInterfaceID(), + llvm::getTypeName()); + } +#endif + // Access the raw interface from the operation info. if (std::optional rInfo = name.getRegisteredInfo()) { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -18,6 +18,7 @@ #include "mlir/IR/BlockSupport.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Location.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" @@ -348,6 +349,11 @@ /// interfaces for the concrete operation. template void attachInterface() { + // Handle the case where the models resolve a promised interface. 
+ (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( + *getDialect(), Models::Interface::getInterfaceID()), + ...); + getImpl()->getInterfaceMap().insertModels(); } diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -14,6 +14,7 @@ #define MLIR_IR_STORAGEUNIQUERSUPPORT_H #include "mlir/IR/AttrTypeSubElements.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/Support/InterfaceSupport.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StorageUniquer.h" @@ -160,6 +161,11 @@ llvm::report_fatal_error("Registering an interface for an attribute/type " "that is not itself registered."); + // Handle the case where the models resolve a promised interface. + (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( + abstract->getDialect(), IfaceModels::Interface::getInterfaceID()), + ...); + (checkInterfaceTarget(), ...); abstract->interfaceMap.template insertModels(); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -269,6 +269,14 @@ private: /// Returns the impl interface instance for the given type. static typename InterfaceBase::Concept *getInterfaceFor(Type type) { +#ifndef NDEBUG + // Check that the current interface isn't an unresolved promise for the + // given type. 
+ dialect_extension_detail::handleUseOfUndefinedPromisedInterface( + type.getDialect(), ConcreteType::getInterfaceID(), + llvm::getTypeName()); +#endif + return type.getAbstractType().getInterface(); } diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/InitAllExtensions.h @@ -0,0 +1,34 @@ +//===- InitAllExtensions.h - MLIR Extension Registration --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialect +// extensions to the system. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INITALLEXTENSIONS_H_ +#define MLIR_INITALLEXTENSIONS_H_ + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" + +#include + +namespace mlir { + +/// This function may be called to register all MLIR dialect extensions with the +/// provided registry. +/// If you're building a compiler, you generally shouldn't use this: you would +/// individually register the specific extensions that are useful for the +/// pipelines and transformations you are using. 
+inline void registerAllExtensions(DialectRegistry &registry) { + func::registerAllExtensions(registry); +} + +} // namespace mlir + +#endif // MLIR_INITALLEXTENSIONS_H_ diff --git a/mlir/lib/Dialect/Func/CMakeLists.txt b/mlir/lib/Dialect/Func/CMakeLists.txt --- a/mlir/lib/Dialect/Func/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/CMakeLists.txt @@ -1,2 +1,3 @@ +add_subdirectory(Extensions) add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp @@ -0,0 +1,16 @@ +//===- AllExtensions.cpp - All Func Dialect Extensions --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" + +using namespace mlir; + +void mlir::func::registerAllExtensions(DialectRegistry &registry) { + registerInlinerExtension(registry); +} diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt @@ -0,0 +1,27 @@ +set(LLVM_OPTIONAL_SOURCES + AllExtensions.cpp + InlinerExtension.cpp + ) + +add_mlir_extension_library(MLIRFuncInlinerExtension + InlinerExtension.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions + + LINK_LIBS PUBLIC + MLIRControlFlowDialect + MLIRInferTypeOpInterface + MLIRIR + MLIRFuncDialect + ) + +add_mlir_extension_library(MLIRFuncAllExtensions + AllExtensions.cpp + + ADDITIONAL_HEADER_DIRS +
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions + + LINK_LIBS PUBLIC + MLIRFuncInlinerExtension + ) diff --git a/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp b/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp @@ -0,0 +1,90 @@ +//===- InlinerExtension.cpp - Func Inliner Extension ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::func; + +//===----------------------------------------------------------------------===// +// FuncDialect Interfaces +//===----------------------------------------------------------------------===// +namespace { +/// This class defines the interface for handling inlining with func operations. +struct FuncInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All call operations can be inlined. + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + + /// All operations can be inlined. + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } + + /// All functions can be inlined. 
+  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, Block *newDest) const final { + // Only return needs to be handled here. + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + + // Replace the return with a branch to the dest. + OpBuilder builder(op); + builder.create(op->getLoc(), newDest, returnOp.getOperands()); + op->erase(); + } + + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only return needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +void mlir::func::registerInlinerExtension(DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) { + dialect->addInterfaces(); + + // The inliner extension relies on the ControlFlow dialect.
+ ctx->getOrLoadDialect(); + }); +} diff --git a/mlir/lib/Dialect/Func/IR/CMakeLists.txt b/mlir/lib/Dialect/Func/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Func/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/IR/CMakeLists.txt @@ -9,7 +9,6 @@ LINK_LIBS PUBLIC MLIRCallInterfaces - MLIRControlFlowDialect MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -8,8 +8,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionImplementation.h" @@ -33,67 +31,6 @@ using namespace mlir; using namespace mlir::func; -//===----------------------------------------------------------------------===// -// FuncDialect Interfaces -//===----------------------------------------------------------------------===// -namespace { -/// This class defines the interface for handling inlining with func operations. -struct FuncInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// All call operations can be inlined. - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - return true; - } - - /// All operations can be inlined. - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - - /// All functions can be inlined. 
- bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { - return true; - } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, Block *newDest) const final { - // Only return needs to be handled here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the return with a branch to the dest. - OpBuilder builder(op); - builder.create(op->getLoc(), newDest, returnOp.getOperands()); - op->erase(); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { - // Only return needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } -}; -} // namespace - //===----------------------------------------------------------------------===// // FuncDialect //===----------------------------------------------------------------------===// @@ -103,7 +40,7 @@ #define GET_OP_LIST #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" >(); - addInterfaces(); + declarePromisedInterface(); } /// Materialize a single constant operation from a given attribute value with diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -96,6 +96,9 @@ /// Register a set of dialect interfaces with this dialect instance. void Dialect::addInterface(std::unique_ptr interface) { + // Handle the case where the models resolve a promised interface. 
+ handleAdditionOfUndefinedPromisedInterface(interface->getID()); + auto it = registeredInterfaces.try_emplace(interface->getID(), std::move(interface)); (void)it; @@ -143,6 +146,16 @@ DialectExtensionBase::~DialectExtensionBase() = default; +void dialect_extension_detail::handleUseOfUndefinedPromisedInterface( + Dialect &dialect, TypeID interfaceID, StringRef interfaceName) { + dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName); +} + +void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface( + Dialect &dialect, TypeID interfaceID) { + dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID); +} + //===----------------------------------------------------------------------===// // DialectRegistry //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1725,10 +1725,10 @@ fprintf(stderr, "@registration\n"); // CHECK-LABEL: @registration - // CHECK: cf.cond_br is_registered: 1 - fprintf(stderr, "cf.cond_br is_registered: %d\n", + // CHECK: func.call is_registered: 1 + fprintf(stderr, "func.call is_registered: %d\n", mlirContextIsRegisteredOperation( - ctx, mlirStringRefCreateFromCString("cf.cond_br"))); + ctx, mlirStringRefCreateFromCString("func.call"))); // CHECK: func.not_existing_op is_registered: 0 fprintf(stderr, "func.not_existing_op is_registered: %d\n", @@ -1942,6 +1942,7 @@ registerAllUpstreamDialects(ctx); mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); + mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith")); MlirLocation loc = mlirLocationUnknownGet(ctx); MlirType indexType = mlirIndexTypeGet(ctx); MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ 
b/mlir/tools/mlir-opt/CMakeLists.txt @@ -4,6 +4,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(LLVM_LINK_COMPONENTS Core Support @@ -50,7 +51,9 @@ set(LIBS ${dialect_libs} ${conversion_libs} + ${extension_libs} ${test_libs} + MLIRAffineAnalysis MLIRAnalysis MLIRCastInterfaces diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -259,6 +260,8 @@ #endif DialectRegistry registry; registerAllDialects(registry); + registerAllExtensions(registry); + #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt --- a/mlir/unittests/Interfaces/CMakeLists.txt +++ b/mlir/unittests/Interfaces/CMakeLists.txt @@ -7,6 +7,7 @@ target_link_libraries(MLIRInterfacesTests PRIVATE + MLIRArithDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRDLTIDialect