diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -50,6 +50,12 @@ return interfaceMap.lookup(); } + /// Returns true if the attribute has the interface with the given ID + /// registered. + bool hasInterface(TypeID interfaceID) const { + return interfaceMap.contains(interfaceID); + } + /// Return the unique identifier representing the concrete attribute class. TypeID getTypeID() const { return typeID; } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -27,8 +27,9 @@ using DialectAllocatorFunction = std::function; using DialectAllocatorFunctionRef = function_ref; -using InterfaceAllocatorFunction = +using DialectInterfaceAllocatorFunction = std::function(Dialect *)>; +using ObjectInterfaceAllocatorFunction = std::function; /// Dialects are groups of MLIR operations, types and attributes, as well as /// behavior associated with the entire group. For example, hooks into other @@ -278,11 +279,19 @@ /// dialects loaded in the Context. The parser in particular will lazily load /// dialects in the Context as operations are encountered. class DialectRegistry { + /// Lists of interfaces that need to be registered when the dialect is loaded. + struct DelayedInterfaces { + /// Dialect interfaces. + SmallVector, 2> + dialectInterfaces; + /// Attribute/Operation/Type interfaces. + SmallVector, 2> + objectInterfaces; + }; + using MapTy = std::map>; - using InterfaceMapTy = - DenseMap, 2>>; + using InterfaceMapTy = DenseMap; public: explicit DialectRegistry() {} @@ -336,7 +345,7 @@ /// the registry. 
template void addDialectInterface(TypeID interfaceTypeID, - InterfaceAllocatorFunction allocator) { + DialectInterfaceAllocatorFunction allocator) { addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID, allocator); } @@ -351,6 +360,36 @@ }); } + /// Add an external op interface model for an op that belongs to a dialect, + /// both provided as template parameters. The dialect must be present in the + /// registry. + template + void addOpInterface() { + StringRef opName = OpTy::getOperationName(); + StringRef dialectName = opName.split('.').first; + addObjectInterface(dialectName == opName ? "" : dialectName, + ModelTy::Interface::getInterfaceID(), + [](MLIRContext *context) { + OpTy::template attachInterface(*context); + }); + } + + /// Add an external attribute interface model for an attribute type `AttrTy` + /// that is going to belong to `DialectTy`. The dialect must be present in the + /// registry. + template + void addAttrInterface() { + addStorageUserInterface(DialectTy::getDialectNamespace()); + } + + /// Add an external type interface model for a type class `TypeTy` that is + /// going to belong to `DialectTy`. The dialect must be present in the + /// registry. + template + void addTypeInterface() { + addStorageUserInterface(DialectTy::getDialectNamespace()); + } + /// Register any interfaces required for the given dialect (based on its /// TypeID). Users are not expected to call this directly. void registerDelayedInterfaces(Dialect *dialect) const; @@ -359,7 +398,22 @@ /// Add an interface constructed with the given allocation function to the /// dialect identified by its namespace. void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID, - InterfaceAllocatorFunction allocator); + DialectInterfaceAllocatorFunction allocator); + + /// Add an attribute/operation/type interface constructible with the given + /// allocation function to the dialect identified by its namespace.
+ void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID, + ObjectInterfaceAllocatorFunction allocator); + + /// Add an external model for an attribute/type interface to the dialect + /// identified by its namespace. + template + void addStorageUserInterface(StringRef dialectName) { + addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(), + [](MLIRContext *context) { + ObjectTy::template attachInterface(*context); + }); + } MapTy registry; InterfaceMapTy interfaces; diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -58,6 +58,11 @@ return interfaceMap.lookup(); } + /// Returns true if the type has the interface with the given ID. + bool hasInterface(TypeID interfaceID) const { + return interfaceMap.contains(interfaceID); + } + /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -16,6 +16,7 @@ #include "mlir/Support/TypeID.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/TypeName.h" namespace mlir { @@ -236,8 +237,10 @@ llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) { return compare(it.first, id); }); - if (it != interfaces.end() && it->first == id) - llvm::report_fatal_error("Interface already registered"); + if (it != interfaces.end() && it->first == id) { + LLVM_DEBUG(llvm::dbgs() << "Ignoring repeated interface registration"); + continue; + } interfaces.insert(it, element); } } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -7,6 +7,7 @@ 
//===----------------------------------------------------------------------===// #include "mlir/IR/Dialect.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" @@ -31,7 +32,7 @@ void DialectRegistry::addDialectInterface( StringRef dialectName, TypeID interfaceTypeID, - InterfaceAllocatorFunction allocator) { + DialectInterfaceAllocatorFunction allocator) { assert(allocator && "unexpected null interface allocation function"); auto it = registry.find(dialectName.str()); assert(it != registry.end() && @@ -40,8 +41,8 @@ // Bail out if the interface with the given ID is already in the registry for // the given dialect. We expect a small number (dozens) of interfaces so a // linear search is fine here. - auto &dialectInterfaces = interfaces[it->second.first]; - for (const auto &kvp : dialectInterfaces) { + auto &ifaces = interfaces[it->second.first]; + for (const auto &kvp : ifaces.dialectInterfaces) { if (kvp.first == interfaceTypeID) { LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE @@ -51,7 +52,36 @@ } } - dialectInterfaces.emplace_back(interfaceTypeID, allocator); + ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator); +} + +void DialectRegistry::addObjectInterface( + StringRef dialectName, TypeID interfaceTypeID, + ObjectInterfaceAllocatorFunction allocator) { + assert(allocator && "unexpected null interface allocation function"); + + // Builtin dialect has an empty prefix and is always registered. 
+ TypeID dialectTypeID; + if (!dialectName.empty()) { + auto it = registry.find(dialectName.str()); + assert(it != registry.end() && + "adding an interface for an op from an unregistered dialect"); + dialectTypeID = it->second.first; + } else { + dialectTypeID = TypeID::get(); + } + + auto &ifaces = interfaces[dialectTypeID]; + for (const auto &kvp : ifaces.objectInterfaces) { + if (kvp.first == interfaceTypeID) { + LLVM_DEBUG(llvm::dbgs() + << "[" DEBUG_TYPE + "] repeated interface object interface registration"); + return; + } + } + + ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator); } DialectAllocatorFunctionRef @@ -79,11 +109,15 @@ return; // Add an interface if it is not already present. - for (const auto &kvp : it->second) { + for (const auto &kvp : it->getSecond().dialectInterfaces) { if (dialect->getRegisteredInterface(kvp.first)) continue; dialect->addInterface(kvp.second(dialect)); } + + // Add attribute, operation and type interfaces. + for (const auto &kvp : it->getSecond().objectInterfaces) + kvp.second(dialect->getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -356,12 +356,12 @@ printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); } - // Ensure the builtin dialect is always pre-loaded. - getOrLoadDialect(); - // Pre-populate the registry. registry.appendTo(impl->dialectsRegistry); + // Ensure the builtin dialect is always pre-loaded. + getOrLoadDialect(); + // Initialize several common attributes and types to avoid the need to lock // the context when accessing them. 
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" @@ -87,6 +88,74 @@ EXPECT_FALSE(i8other.isa()); } +/// External interface model for the test type from the test dialect. +struct TestTypeModel + : public TestExternalTypeInterface::ExternalModel { + unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } + + static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } +}; + +TEST(InterfaceAttachment, TypeDelayedContextConstruct) { + // Put the interface in the registry. + DialectRegistry registry; + registry.insert(); + registry.addTypeInterface(); + + // Check that when a context is constructed with the given registry, the type + // interface gets registered. + MLIRContext context(registry); + context.loadDialect(); + test::TestType testType = test::TestType::get(&context); + auto iface = testType.dyn_cast(); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); + EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); +} + +TEST(InterfaceAttachment, TypeDelayedContextAppend) { + // Put the interface in the registry. + DialectRegistry registry; + registry.insert(); + registry.addTypeInterface(); + + // Check that when the registry gets appended to the context, the interface + // becomes available for objects in loaded dialects. 
+ MLIRContext context; + context.loadDialect(); + test::TestType testType = test::TestType::get(&context); + EXPECT_FALSE(testType.isa()); + context.appendDialectRegistry(registry); + EXPECT_TRUE(testType.isa()); +} + +TEST(InterfaceAttachment, RepeatedRegistration) { + DialectRegistry registry; + registry.addTypeInterface(); + MLIRContext context(registry); + + // Shouldn't fail on repeated registration through the dialect registry. + context.appendDialectRegistry(registry); +} + +TEST(InterfaceAttachment, TypeBuiltinDelayed) { + // Builtin dialect needs no registration or loading, but delayed interface + // registration must still work. + DialectRegistry registry; + registry.addTypeInterface(); + + MLIRContext context(registry); + IntegerType i16 = IntegerType::get(&context, 16); + EXPECT_TRUE(i16.isa()); + + MLIRContext initiallyEmpty; + IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); + EXPECT_FALSE(i32.isa()); + initiallyEmpty.appendDialectRegistry(registry); + EXPECT_TRUE(i32.isa()); +} + /// The interface provides a default implementation that expects /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this /// just derives from the ExternalModel. @@ -128,9 +197,9 @@ } /// External model for attribute interfaces. -struct TextExternalIntegerAttrModel +struct TestExternalIntegerAttrModel : public TestExternalAttrInterface::ExternalModel< - TextExternalIntegerAttrModel, IntegerAttr> { + TestExternalIntegerAttrModel, IntegerAttr> { const Dialect *getDialectPtr(Attribute attr) const { return &attr.cast().getDialect(); } @@ -145,13 +214,45 @@ // that the basics work for attributes.
IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); ASSERT_FALSE(attr.isa()); - IntegerAttr::attachInterface(context); + IntegerAttr::attachInterface(context); auto iface = attr.dyn_cast(); ASSERT_TRUE(iface != nullptr); EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); EXPECT_EQ(iface.getSomeNumber(), 42); } +/// External model for an interface attachable to a non-builtin attribute. +struct TestExternalSimpleAAttrModel + : public TestExternalAttrInterface::ExternalModel< + TestExternalSimpleAAttrModel, test::SimpleAAttr> { + const Dialect *getDialectPtr(Attribute attr) const { + return &attr.getDialect(); + } + + static int getSomeNumber() { return 21; } +}; + +TEST(InterfaceAttachmentTest, AttributeDelayed) { + // Attribute interfaces use the exact same mechanism as types, so just check + // that the delayed registration works for attributes. + DialectRegistry registry; + registry.insert(); + registry.addAttrInterface(); + + MLIRContext context(registry); + context.loadDialect(); + auto attr = test::SimpleAAttr::get(&context); + EXPECT_TRUE(attr.isa()); + + MLIRContext initiallyEmpty; + initiallyEmpty.loadDialect(); + attr = test::SimpleAAttr::get(&initiallyEmpty); + EXPECT_FALSE(attr.isa()); + initiallyEmpty.appendDialectRegistry(registry); + EXPECT_TRUE(attr.isa()); +} + /// External interface model for the module operation. Only provides non-default /// methods.
struct TestExternalOpModel @@ -220,4 +321,55 @@ ASSERT_FALSE(isa(otherModuleOp.getOperation())); } +struct TestExternalTestOpModel + : public TestExternalOpInterface::ExternalModel { + unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { + return op->getName().getStringRef().size() + arg; + } + + static unsigned getNameLengthPlusArgTwice(unsigned arg) { + return test::OpJ::getOperationName().size() + 2 * arg; + } +}; + +TEST(InterfaceAttachment, OperationDelayedContextConstruct) { + DialectRegistry registry; + registry.insert(); + registry.addOpInterface(); + registry.addOpInterface(); + + // Construct the context directly from a registry. The interfaces are expected + // to be readily available on operations. + MLIRContext context(registry); + context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module); + auto op = + builder.create(builder.getUnknownLoc(), builder.getI32Type()); + EXPECT_TRUE(isa(module.getOperation())); + EXPECT_TRUE(isa(op.getOperation())); +} + +TEST(InterfaceAttachment, OperationDelayedContextAppend) { + DialectRegistry registry; + registry.insert(); + registry.addOpInterface(); + registry.addOpInterface(); + + // Construct the context, create ops, and only then append the registry. The + // interfaces are expected to be available after appending the registry. + MLIRContext context; + context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module); + auto op = + builder.create(builder.getUnknownLoc(), builder.getI32Type()); + EXPECT_FALSE(isa(module.getOperation())); + EXPECT_FALSE(isa(op.getOperation())); + context.appendDialectRegistry(registry); + EXPECT_TRUE(isa(module.getOperation())); + EXPECT_TRUE(isa(op.getOperation())); +} + } // end namespace