diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -285,6 +285,14 @@
 private:
   /// Returns the impl interface instance for the given type.
   static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
+#ifndef NDEBUG
+    // Check that the current interface is not an unresolved promise for the
+    // given attribute.
+    dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
+        attr.getDialect(), ConcreteType::getInterfaceID(),
+        llvm::getTypeName<ConcreteType>());
+#endif // NDEBUG
+
     return attr.getAbstractAttribute().getInterface<ConcreteType>();
   }
 
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -159,11 +159,20 @@
   /// Lookup an interface for the given ID if one is registered, otherwise
   /// nullptr.
   DialectInterface *getRegisteredInterface(TypeID interfaceID) {
+#ifndef NDEBUG
+    handleUseOfUndefinedPromisedInterface(interfaceID);
+#endif // NDEBUG
+
     auto it = registeredInterfaces.find(interfaceID);
     return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
   }
 
   template <typename InterfaceT> InterfaceT *getRegisteredInterface() {
+#ifndef NDEBUG
+    handleUseOfUndefinedPromisedInterface(InterfaceT::getInterfaceID(),
+                                          llvm::getTypeName<InterfaceT>());
+#endif // NDEBUG
+
     return static_cast<InterfaceT *>(
         getRegisteredInterface(InterfaceT::getInterfaceID()));
   }
@@ -196,6 +205,36 @@
     return *interface;
   }
 
+  /// Declare that the given interface will be implemented, but has a delayed
+  /// registration. The promised interface type can be an interface of any type
+  /// not just a dialect interface, i.e. it may also be an
+  /// AttributeInterface/OpInterface/TypeInterface/etc.
+  template <typename InterfaceT> void declarePromisedInterface() {
+    unresolvedPromisedInterfaces.insert(InterfaceT::getInterfaceID());
+  }
+
+  /// Checks if the given interface, which is attempting to be used, is a
+  /// promised interface of this dialect that has yet to be implemented. If so,
+  /// emits a fatal error. `interfaceName` is an optional string that contains a
+  /// more user-readable name for the interface (such as the class name).
+  void handleUseOfUndefinedPromisedInterface(TypeID interfaceID,
+                                             StringRef interfaceName = "") {
+    if (unresolvedPromisedInterfaces.count(interfaceID)) {
+      llvm::report_fatal_error(
+          "checking for an interface (`" + interfaceName +
+          "`) that was promised by dialect '" + getNamespace() +
+          "' but never implemented. This is generally an indication "
+          "that the dialect extension implementing the interface was never "
+          "registered.");
+    }
+  }
+  /// Checks if the given interface, which is attempting to be attached to a
+  /// construct owned by this dialect, is a promised interface of this dialect
+  /// that has yet to be implemented. If so, it resolves the interface promise.
+  void handleAdditionOfUndefinedPromisedInterface(TypeID interfaceID) {
+    unresolvedPromisedInterfaces.erase(interfaceID);
+  }
+
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
   /// context to bind to.
@@ -289,6 +328,11 @@
   /// A collection of registered dialect interfaces.
   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
 
+  /// A set of interfaces that the dialect (or its constructs, i.e.
+  /// Attributes/Operations/Types/etc.) has promised to implement, but has yet
+  /// to provide an implementation for.
+  DenseSet<TypeID> unresolvedPromisedInterfaces;
+
   friend class DialectRegistry;
   friend void registerDialect();
   friend class MLIRContext;
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -97,6 +97,22 @@
   }
 };
 
+namespace dialect_extension_detail {
+
+/// Checks if the given interface, which is attempting to be used, is a
+/// promised interface of the given dialect that has yet to be implemented. If
+/// so, emits a fatal error.
+void handleUseOfUndefinedPromisedInterface(Dialect &dialect, TypeID interfaceID,
+                                           StringRef interfaceName);
+
+/// Checks if the given interface, which is attempting to be attached, is a
+/// promised interface of the given dialect that has yet to be implemented. If
+/// so, the promised interface is marked as resolved.
+void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
+                                                TypeID interfaceID);
+
+} // namespace dialect_extension_detail
+
 //===----------------------------------------------------------------------===//
 // DialectRegistry
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1909,6 +1909,16 @@
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     OperationName name = op->getName();
 
+#ifndef NDEBUG
+    // Check that the current interface is not an unresolved promise for the
+    // given operation.
+    if (Dialect *dialect = name.getDialect()) {
+      dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
+          *dialect, ConcreteType::getInterfaceID(),
+          llvm::getTypeName<ConcreteType>());
+    }
+#endif // NDEBUG
+
     // Access the raw interface from the operation info.
     if (std::optional<RegisteredOperationName> rInfo =
             name.getRegisteredInfo()) {
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -286,6 +286,11 @@
   /// interfaces for the concrete operation.
   template <typename... Models>
   void attachInterface() {
+    // Handle the case where the models resolve a promised interface.
+    (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
+         *getDialect(), Models::Interface::getInterfaceID()),
+     ...);
+
     getImpl()->getInterfaceMap().insertModels<Models...>();
   }
 
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -14,6 +14,7 @@
 #define MLIR_IR_STORAGEUNIQUERSUPPORT_H
 
 #include "mlir/IR/AttrTypeSubElements.h"
+#include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/InterfaceSupport.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/StorageUniquer.h"
@@ -161,6 +162,12 @@
                                "that is not itself registered.");
 
     (checkInterfaceTarget<IfaceModels>(), ...);
+
+    // Handle the case where the models resolve a promised interface.
+    (dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
+         abstract->getDialect(), IfaceModels::Interface::getInterfaceID()),
+     ...);
+
     abstract->interfaceMap.template insertModels<IfaceModels...>();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -268,6 +268,14 @@
 private:
   /// Returns the impl interface instance for the given type.
   static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
+#ifndef NDEBUG
+    // Check that the current interface isn't an unresolved promise for the
+    // given type.
+    dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
+        type.getDialect(), ConcreteType::getInterfaceID(),
+        llvm::getTypeName<ConcreteType>());
+#endif // NDEBUG
+
     return type.getAbstractType().getInterface<ConcreteType>();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -119,6 +120,9 @@
       >(namedStructuredOpRegionBuilders);
 
   addInterfaces<LinalgInlinerInterface>();
+
+  declarePromisedInterface<TilingInterface>();
+  declarePromisedInterface<PartialReductionOpInterface>();
 }
 
 LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -19,6 +19,7 @@
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRRuntimeVerifiableOpInterface
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include <optional>
@@ -40,7 +41,9 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
       >();
+
   addInterfaces<MemRefInlinerInterface>();
+  declarePromisedInterface<RuntimeVerifiableOpInterface>();
 }
 
 /// Finds the unique dealloc operation (if one exists) for `allocValue`.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
@@ -10,6 +10,8 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -44,5 +46,9 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
       >();
+
   addInterfaces<TensorInlinerInterface>();
+
+  declarePromisedInterface<ReifyRankedShapedTypeOpInterface>();
+  declarePromisedInterface<TilingInterface>();
 }
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -96,6 +96,9 @@
 
 /// Register a set of dialect interfaces with this dialect instance.
 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
+  // Handle the case where this interface resolves a promised interface.
+  handleAdditionOfUndefinedPromisedInterface(interface->getID());
+
   auto it = registeredInterfaces.try_emplace(interface->getID(),
                                              std::move(interface));
   (void)it;
@@ -143,6 +146,16 @@
 
 DialectExtensionBase::~DialectExtensionBase() = default;
 
+void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
+    Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
+  dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
+}
+
+void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
+    Dialect &dialect, TypeID interfaceID) {
+  dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
+}
+
 //===----------------------------------------------------------------------===//
 // DialectRegistry
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9963,6 +9963,7 @@
         ":InferTypeOpInterface",
         ":MemRefBaseIncGen",
         ":MemRefOpsIncGen",
+        ":RuntimeVerifiableOpInterface",
         ":ShapedOpInterfaces",
         ":ViewLikeInterface",
         "//llvm:Support",