diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/Support/TypeID.h" #include +#include namespace mlir { class DialectAsmParser; @@ -285,7 +286,7 @@ SmallVector, 2> dialectInterfaces; /// Attribute/Operation/Type interfaces. - SmallVector, 2> + SmallVector, 2> objectInterfaces; }; @@ -367,7 +368,8 @@ void addOpInterface() { StringRef opName = OpTy::getOperationName(); StringRef dialectName = opName.split('.').first; - addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(), + addObjectInterface(dialectName, TypeID::get(), + ModelTy::Interface::getInterfaceID(), [](MLIRContext *context) { OpTy::template attachInterface(*context); }); @@ -401,14 +403,16 @@ /// Add an attribute/operation/type interface constructible with the given /// allocation function to the dialect identified by its namespace. - void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID, + void addObjectInterface(StringRef dialectName, TypeID objectID, + TypeID interfaceTypeID, ObjectInterfaceAllocatorFunction allocator); /// Add an external model for an attribute/type interface to the dialect /// identified by its namespace. template void addStorageUserInterface(StringRef dialectName) { - addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(), + addObjectInterface(dialectName, TypeID::get(), + ModelTy::Interface::getInterfaceID(), [](MLIRContext *context) { ObjectTy::template attachInterface(*context); }); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -58,16 +58,19 @@ } void DialectRegistry::addObjectInterface( - StringRef dialectName, TypeID interfaceTypeID, + StringRef dialectName, TypeID objectID, TypeID interfaceTypeID, ObjectInterfaceAllocatorFunction allocator) { assert(allocator && "unexpected null interface allocation function"); + auto it = registry.find(dialectName.str()); assert(it != registry.end() && "adding an interface for an op from an unregistered dialect"); - auto &ifaces = interfaces[it->second.first]; - for (const auto &kvp : ifaces.objectInterfaces) { - if (kvp.first == interfaceTypeID) { + auto dialectID = it->second.first; + auto &ifaces = interfaces[dialectID]; + + for (const auto &info : ifaces.objectInterfaces) { + if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) { LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] repeated interface object interface registration"); @@ -75,7 +78,7 @@ } } - ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator); + ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator); } DialectAllocatorFunctionRef @@ -110,8 +113,8 @@ } // Add attribute, operation and type interfaces. - for (const auto &kvp : it->getSecond().objectInterfaces) - kvp.second(dialect->getContext()); + for (const auto &info : it->getSecond().objectInterfaces) + std::get<2>(info)(dialect->getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -321,15 +321,16 @@ ASSERT_FALSE(isa(otherModuleOp.getOperation())); } +template struct TestExternalTestOpModel - : public TestExternalOpInterface::ExternalModel { + : public TestExternalOpInterface::ExternalModel< + TestExternalTestOpModel, ConcreteOp> { unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { return op->getName().getStringRef().size() + arg; } static unsigned getNameLengthPlusArgTwice(unsigned arg) { - return test::OpJ::getOperationName().size() + 2 * arg; + return ConcreteOp::getOperationName().size() + 2 * arg; } }; @@ -337,39 +338,61 @@ DialectRegistry registry; registry.insert(); registry.addOpInterface(); - registry.addOpInterface(); + registry.addOpInterface>(); + registry.addOpInterface>(); // Construct the context directly from a registry. The interfaces are expected // to be readily available on operations. MLIRContext context(registry); context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module); - auto op = + auto opJ = builder.create(builder.getUnknownLoc(), builder.getI32Type()); + auto opH = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + auto opI = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + EXPECT_TRUE(isa(module.getOperation())); - EXPECT_TRUE(isa(op.getOperation())); + EXPECT_TRUE(isa(opJ.getOperation())); + EXPECT_TRUE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); } TEST(InterfaceAttachment, OperationDelayedContextAppend) { DialectRegistry registry; registry.insert(); registry.addOpInterface(); - registry.addOpInterface(); + registry.addOpInterface>(); + registry.addOpInterface>(); // Construct the context, create ops, and only then append the registry. The // interfaces are expected to be available after appending the registry. MLIRContext context; context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module); - auto op = + auto opJ = builder.create(builder.getUnknownLoc(), builder.getI32Type()); + auto opH = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + auto opI = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + EXPECT_FALSE(isa(module.getOperation())); - EXPECT_FALSE(isa(op.getOperation())); + EXPECT_FALSE(isa(opJ.getOperation())); + EXPECT_FALSE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); + context.appendDialectRegistry(registry); + EXPECT_TRUE(isa(module.getOperation())); - EXPECT_TRUE(isa(op.getOperation())); + EXPECT_TRUE(isa(opJ.getOperation())); + EXPECT_TRUE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); } } // end namespace