diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1619,6 +1619,19 @@ reinterpret_cast(const_cast(pointer))); } + /// Attach the given models as implementations of the corresponding interfaces + /// for the concrete operation. Reports a fatal error if the operation is not registered in `context`. + template + static void attachInterface(MLIRContext &context) { + AbstractOperation *abstract = AbstractOperation::lookupMutable( + ConcreteType::getOperationName(), &context); + if (!abstract) + llvm::report_fatal_error( + "attempting to attach an interface to an unregistered operation " + + ConcreteType::getOperationName()); + abstract->interfaceMap.insert(); + } + private: /// Trait to check if T provides a 'fold' method for a single result op. template diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -162,7 +162,9 @@ /// Look up the specified operation in the specified MLIRContext and return a /// pointer to it if present. Otherwise, return a null pointer. static const AbstractOperation *lookup(StringRef opName, - MLIRContext *context); + MLIRContext *context) { + return lookupMutable(opName, context); + } /// This constructor is used by Dialect objects when they register the list of /// operations they contain. @@ -194,6 +196,15 @@ GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait); + /// Give Op access to lookupMutable. + template class... Traits> + friend class Op; + + /// Look up the specified operation in the specified MLIRContext and return a + /// mutable pointer to it if present. Otherwise, return a null pointer. + static AbstractOperation *lookupMutable(StringRef opName, + MLIRContext *context); + /// A map of interfaces that were registered to this operation. 
detail::InterfaceMap interfaceMap; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -696,8 +696,8 @@ /// Look up the specified operation in the operation set and return a pointer /// to it if present. Otherwise, return a null pointer. -const AbstractOperation *AbstractOperation::lookup(StringRef opName, - MLIRContext *context) { +AbstractOperation *AbstractOperation::lookupMutable(StringRef opName, + MLIRContext *context) { auto &impl = context->getImpl(); auto it = impl.registeredOperations.find(opName); if (it != impl.registeredOperations.end()) diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -89,6 +89,24 @@ ]; } +def TestExternalOpInterface : OpInterface<"TestExternalOpInterface"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"Returns the length of the operation name plus arg.", + "unsigned", "getNameLengthPlusArg", (ins "unsigned":$arg)>, + StaticInterfaceMethod< + "Returns the length of the operation name plus arg twice.", "unsigned", + "getNameLengthPlusArgTwice", (ins "unsigned":$arg)>, + InterfaceMethod< + "Returns the length of the product of the operation name and arg.", + "unsigned", "getNameLengthTimesArg", (ins "unsigned":$arg), "", + [{return arg * $_op->getName().getStringRef().size();}]>, + StaticInterfaceMethod<"Returns the length of the operation name minus arg.", + "unsigned", "getNameLengthMinusArg", (ins "unsigned":$arg), "", + [{return ConcreteOp::getOperationName().size() - arg;}]>, + ]; +} + def TestEffectOpInterface : EffectOpInterfaceBase<"TestEffectOpInterface", "::mlir::TestEffects::Effect"> { diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ 
b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -12,10 +12,12 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" #include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" #include "../../test/lib/Dialect/Test/TestTypes.h" using namespace mlir; @@ -150,4 +152,72 @@ EXPECT_EQ(iface.getSomeNumber(), 42); } +/// External interface model for the module operation. Only provides non-default +/// methods. +struct TestExternalOpModel + : public TestExternalOpInterface::ExternalModel { + unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { + return op->getName().getStringRef().size() + arg; + } + + static unsigned getNameLengthPlusArgTwice(unsigned arg) { + return ModuleOp::getOperationName().size() + 2 * arg; + } +}; + +/// External interface model for the func operation. Provides non-default and +/// overrides default methods. +struct TestExternalOpOverridingModel + : public TestExternalOpInterface::FallbackModel< + TestExternalOpOverridingModel> { + unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { + return op->getName().getStringRef().size() + arg; + } + + static unsigned getNameLengthPlusArgTwice(unsigned arg) { + return FuncOp::getOperationName().size() + 2 * arg; + } + + unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { + return 42; + } + + static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } +}; + +TEST(InterfaceAttachment, Operation) { + MLIRContext context; + + // Initially, the operation doesn't have the interface. + auto moduleOp = ModuleOp::create(UnknownLoc::get(&context)); + ASSERT_FALSE(isa(moduleOp.getOperation())); + + // We can attach an external interface and now the operation has it. 
+ ModuleOp::attachInterface(context); + auto iface = dyn_cast(moduleOp.getOperation()); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getNameLengthPlusArg(10), 16u); + EXPECT_EQ(iface.getNameLengthTimesArg(3), 18u); + EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 42u); + EXPECT_EQ(iface.getNameLengthMinusArg(5), 1u); + + // Default implementation can be overridden. + auto funcOp = FuncOp::create(UnknownLoc::get(&context), "function", + FunctionType::get(&context, {}, {})); + ASSERT_FALSE(isa(funcOp.getOperation())); + FuncOp::attachInterface(context); + iface = dyn_cast(funcOp.getOperation()); + ASSERT_TRUE(iface != nullptr); + EXPECT_EQ(iface.getNameLengthPlusArg(10), 14u); + EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); + EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 20u); + EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); + + // Another context doesn't have the interfaces registered. + MLIRContext other; + auto otherModuleOp = ModuleOp::create(UnknownLoc::get(&other)); + ASSERT_FALSE(isa(otherModuleOp.getOperation())); +} + } // end namespace