diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -207,6 +207,45 @@ llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n"; ``` +#### Dialect Fallback for OpInterface + +Some dialects have an open ecosystem and don't register all of the possible +operations. In such cases it is still possible to provide support for +implementing an `OpInterface` for these operations. When an operation isn't +registered or does not provide an implementation for an interface, the query +will fall back to the operation's dialect, if one is available. + +A second model, named `FallbackModel`, is generated automatically for such +cases when using ODS (see below). This model can be implemented +for a particular dialect: + +```c++ +// This is the implementation of a dialect fallback for `ExampleOpInterface`. +struct FallbackExampleOpInterface + : public ExampleOpInterface::FallbackModel< + FallbackExampleOpInterface> { + static bool classof(Operation *op) { return true; } + + unsigned exampleInterfaceHook(Operation *op) const; + static unsigned exampleStaticInterfaceHook(); +}; +``` + +A dialect can then instantiate this implementation and return it for specific +operations by overriding the `getRegisteredInterfaceForOp` method: + +```c++ +void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, + OperationName opName) { + if (typeID == TypeID::get<ExampleOpInterface>()) { + if (isSupported(opName)) + return fallbackExampleOpInterface; + return nullptr; + } + return nullptr; +} +``` + #### Utilizing the ODS Framework Note: Before reading this section, the reader should have some familiarity with diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -158,6 +158,19 @@ getRegisteredInterface(InterfaceT::getInterfaceID())); } + /// Lookup an op interface for the given ID if one is registered, otherwise + /// nullptr.
+ virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, + OperationName opName) { + return nullptr; + } + template <typename InterfaceT> + typename InterfaceT::Concept * + getRegisteredInterfaceForOp(OperationName opName) { + return static_cast<typename InterfaceT::Concept *>( + getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); + } + protected: /// The constructor takes a unique namespace for this dialect as well as the /// context to bind to. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -272,6 +272,9 @@ // If this dialect overrides the hook for verifying region result attributes. bit hasRegionResultAttrVerify = 0; + + // If this dialect overrides the hook for op interface fallback. + bit hasOperationInterfaceFallback = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -19,6 +19,7 @@ #ifndef MLIR_IR_OPDEFINITION_H #define MLIR_IR_OPDEFINITION_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" @@ -1721,7 +1722,20 @@ static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { // Access the raw interface from the abstract operation. auto *abstractOp = op->getAbstractOperation(); - return abstractOp ? abstractOp->getInterface() : nullptr; + if (abstractOp) { + if (auto *opIface = abstractOp->getInterface<ConcreteType>()) + return opIface; + // Fall back to the dialect to provide it with a chance to implement this + // interface for this operation. + return abstractOp->dialect.getRegisteredInterfaceForOp<ConcreteType>( + op->getName()); + } + // No abstract operation: fall back to the op's dialect, if any, to give it + // a chance to implement this interface for this operation. + Dialect *dialect = op->getName().getDialect(); + return dialect ?
dialect->getRegisteredInterfaceForOp<ConcreteType>( + op->getName()) + : nullptr; } /// Allow access to `getInterfaceFor`. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -181,7 +181,7 @@ auto *concept = getInterfaceFor(op); if (!concept) return false; - return !concept->isOptionalSymbol(op) || + return !concept->isOptionalSymbol(concept, op) || op->getAttr(::mlir::SymbolTable::getSymbolAttrName()); } }]; diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -71,6 +71,8 @@ public: using Concept = typename Traits::Concept; template using Model = typename Traits::template Model; + template <typename T> + using FallbackModel = typename Traits::template FallbackModel<T>; using InterfaceBase = Interface; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -63,6 +63,9 @@ /// Returns true if this dialect has a region result attribute verifier. bool hasRegionResultAttrVerify() const; + /// Returns true if this dialect has fallback interfaces for its operations. + bool hasOperationInterfaceFallback() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record.
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -77,6 +77,10 @@ return def->getValueAsBit("hasRegionResultAttrVerify"); } +bool Dialect::hasOperationInterfaceFallback() const { + return def->getValueAsBit("hasOperationInterfaceFallback"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir --- a/mlir/test/IR/test-side-effects.mlir +++ b/mlir/test/IR/test-side-effects.mlir @@ -30,3 +30,9 @@ %4 = "test.side_effect_op"() { effect_parameter = affine_map<(i, j) -> (j, i)> } : () -> i32 + +// Check with this unregistered operation to test the fallback on the dialect. +// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}} +%5 = "test.unregistered_side_effect_op"() { + effect_parameter = affine_map<(i, j) -> (j, i)> +} : () -> i32 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -166,6 +166,29 @@ // TestDialect //===----------------------------------------------------------------------===// +static void testSideEffectOpGetEffect( + Operation *op, + SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects); + +// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
+struct TestOpEffectInterfaceFallback + : public TestEffectOpInterface::FallbackModel< + TestOpEffectInterfaceFallback> { + static bool classof(Operation *op) { + bool isSupportedOp = + op->getName().getStringRef() == "test.unregistered_side_effect_op"; + assert(isSupportedOp && "Unexpected dispatch"); + return isSupportedOp; + } + + void + getEffects(Operation *op, + SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> + &effects) const { + testSideEffectOpGetEffect(op, effects); + } +}; + void TestDialect::initialize() { registerAttributes(); registerTypes(); @@ -176,6 +199,14 @@ addInterfaces(); allowUnknownOperations(); + + // Instantiate our fallback op interface that we'll use on specific + // unregistered ops. + fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback; +} +TestDialect::~TestDialect() { + delete static_cast<TestOpEffectInterfaceFallback *>( + fallbackEffectOpInterfaces); } Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, @@ -183,6 +214,14 @@ return builder.create(loc, type, value); } +void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, + OperationName opName) { + if (typeID == TypeID::get<TestEffectOpInterface>() && + opName.getIdentifier() == "test.unregistered_side_effect_op") + return fallbackEffectOpInterfaces; + return nullptr; +} + LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") @@ -716,6 +755,17 @@ }; } // end anonymous namespace +static void testSideEffectOpGetEffect( + Operation *op, + SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> + &effects) { + auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"); + if (!effectsAttr) + return; + + effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); +} + void SideEffectOp::getEffects( SmallVectorImpl &effects) { // Check for an effects attribute on the op instance.
@@ -754,11 +804,7 @@ void SideEffectOp::getEffects( SmallVectorImpl &effects) { - auto effectsAttr = (*this)->getAttrOfType("effect_parameter"); - if (!effectsAttr) - return; - - effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); + testSideEffectOpGetEffect(getOperation(), effects); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -29,6 +29,7 @@ let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; + let hasOperationInterfaceFallback = 1; let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ @@ -45,6 +46,12 @@ getParseOperationHook(StringRef opName) const override; LogicalResult printOperation(Operation *op, OpAsmPrinter &printer) const override; + + ~TestDialect(); + private: + // Owned TestOpEffectInterfaceFallback instance (allocated in initialize(), deleted in ~TestDialect()). + void *fallbackEffectOpInterfaces = nullptr; + }]; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -143,6 +143,13 @@ ::mlir::NamedAttribute attribute) override; )"; +/// The code block for the op interface fallback hook. +static const char *const operationInterfaceFallbackDecl = R"( + /// Provides a hook for op interface. + void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID, + ::mlir::OperationName opName) override; +)"; + /// Generate the declaration for the given dialect class.
static void emitDialectDecl(Dialect &dialect, iterator_range dialectAttrs, @@ -181,6 +188,8 @@ os << regionArgAttrVerifierDecl; if (dialect.hasRegionResultAttrVerify()) os << regionResultAttrVerifierDecl; + if (dialect.hasOperationInterfaceFallback()) + os << operationInterfaceFallbackDecl; if (llvm::Optional extraDecl = dialect.getExtraClassDeclaration()) os << *extraDecl; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -43,9 +43,11 @@ raw_ostream &os, StringRef valueType, bool addThisArg, bool addConst) { os << method.getName() << '('; - if (addThisArg) + if (addThisArg) { + os << "const Concept *impl, "; emitCPPType(valueType, os) << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); + } llvm::interleaveComma(method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.type << " " << arg.name; @@ -124,7 +126,9 @@ interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; StringRef castCode = "(llvm::cast(tablegen_opaque_val))"; - nonStaticMethodFmt.withOp(castCode).withSelf(castCode); + nonStaticMethodFmt.addSubst("_this", "impl") + .withOp(castCode) + .withSelf(castCode); traitMethodFmt.withOp("(*static_cast(this))"); } }; @@ -167,6 +171,7 @@ // Forward to the method on the concrete operation type. os << " {\n return getImpl()->" << method.getName() << '('; if (!method.isStatic()) { + os << "getImpl(), "; os << (isOpInterface ? "getOperation()" : "*this"); os << (method.arg_empty() ? "" : ", "); } @@ -197,8 +202,10 @@ os << " "; emitCPPType(method.getReturnType(), os); os << "(*" << method.getName() << ")("; - if (!method.isStatic()) + if (!method.isStatic()) { + os << "const Concept *impl, "; emitCPPType(valueType, os) << (method.arg_empty() ?
"" : ", "); + } llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); @@ -208,23 +215,25 @@ } void InterfaceGenerator::emitModelDecl(Interface &interface) { - os << " template\n"; - os << " class Model : public Concept {\n public:\n"; - os << " Model() : Concept{"; - llvm::interleaveComma( - interface.getMethods(), os, - [&](const InterfaceMethod &method) { os << method.getName(); }); - os << "} {}\n\n"; - - // Insert each of the virtual method overrides. - for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " static inline "); - emitMethodNameAndArgs(method, os, valueType, - /*addThisArg=*/!method.isStatic(), - /*addConst=*/false); - os << ";\n"; + for (const char *modelClass : {"Model", "FallbackModel"}) { + os << " template<typename " << valueTemplate << ">\n"; + os << " class " << modelClass << " : public Concept {\n public:\n"; + os << " " << modelClass << "() : Concept{"; + llvm::interleaveComma( + interface.getMethods(), os, + [&](const InterfaceMethod &method) { os << method.getName(); }); + os << "} {}\n\n"; + + // Insert each of the virtual method overrides.
+ for (auto &method : interface.getMethods()) { + emitCPPType(method.getReturnType(), os << " static inline "); + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), + /*addConst=*/false); + os << ";\n"; + } + os << " };\n"; } - os << " };\n"; } void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { @@ -261,6 +270,32 @@ [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); os << ");\n}\n"; } + + for (auto &method : interface.getMethods()) { + os << "template<typename " << valueTemplate << ">\n"; + emitCPPType(method.getReturnType(), os); + os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" + << valueTemplate << ">::"; + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), + /*addConst=*/false); + os << " {\n "; + + // Forward to the method on the fallback implementation object (`impl`). + if (method.isStatic()) + os << "return " << valueTemplate << "::"; + else + os << "return static_cast<const " << valueTemplate << "*>(impl)->"; + + // Add the arguments to the call. + os << method.getName() << '('; + if (!method.isStatic()) + os << "tablegen_opaque_val" << (method.arg_empty() ?
"" : ", "); + llvm::interleaveComma( + method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); + os << ");\n}\n"; + } } void InterfaceGenerator::emitTraitDecl(Interface &interface, diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -171,7 +171,7 @@ StringRef methodName = availability.getQueryFnName(); os << availability.getQueryFnRetType() << " " << availability.getInterfaceClassName() << "::" << methodName << "() {\n" - << " return getImpl()->" << methodName << "(getOperation());\n" + << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" << "}\n"; } @@ -208,23 +208,35 @@ << " virtual ~Concept() = default;\n" << " virtual " << availability.getQueryFnRetType() << " " << availability.getQueryFnName() - << "(Operation *tblgen_opaque_op) const = 0;\n" + << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n" << " };\n"; } static void emitModelDecl(const Availability &availability, raw_ostream &os) { - os << " template\n"; - os << " class Model : public Concept {\n" - << " public:\n" - << " " << availability.getQueryFnRetType() << " " - << availability.getQueryFnName() - << "(Operation *tblgen_opaque_op) const final {\n" - << " auto op = llvm::cast(tblgen_opaque_op);\n" - << " (void)op;\n" - // Forward to the method on the concrete operation type. - << " return op."
<< availability.getQueryFnName() << "();\n" - << " }\n" - << " };\n"; + os << " template<typename ConcreteOp>\n"; + os << " class Model : public Concept {\n" + << " public:\n" + << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" + << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n" + << " (void)op;\n" + // Forward to the method on the concrete operation type. + << " return op." << availability.getQueryFnName() << "();\n" + << " }\n" + << " };\n"; + // The fallback model is implemented by a dialect-provided object rather + // than by an op class: forward through `impl` instead of casting the op. + os << " template<typename ConcreteOp>\n"; + os << " class FallbackModel : public Concept {\n" + << " public:\n" + << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" + << " return static_cast<const ConcreteOp *>(impl)->" + << availability.getQueryFnName() << "(tblgen_opaque_op);\n" + << " }\n" + << " };\n"; } static void emitInterfaceDecl(const Availability &availability,