diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -140,6 +140,17 @@
         getRegisteredInterface(InterfaceT::getInterfaceID()));
   }
 
+  /// Lookup an op interface for the given ID if one is registered, otherwise
+  /// nullptr.
+  virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, Operation *op) {
+    return nullptr;
+  }
+  template <typename InterfaceT>
+  typename InterfaceT::Concept *getRegisteredInterfaceForOp(Operation *op) {
+    return static_cast<typename InterfaceT::Concept *>(
+        getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), op));
+  }
+
 protected:
   /// The constructor takes a unique namespace for this dialect as well as the
   /// context to bind to.
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -272,6 +272,9 @@
 
   // If this dialect overrides the hook for verifying region result attributes.
   bit hasRegionResultAttrVerify = 0;
+
+  // If this dialect overrides the hook for op interface fallback.
+  bit hasOperationInterfaceFallback = 0;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -19,6 +19,7 @@
 #ifndef MLIR_IR_OPDEFINITION_H
 #define MLIR_IR_OPDEFINITION_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
@@ -1735,7 +1736,18 @@
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     // Access the raw interface from the abstract operation.
     auto *abstractOp = op->getAbstractOperation();
-    return abstractOp ? abstractOp->getInterface<ConcreteType>() : nullptr;
+    if (abstractOp) {
+      if (auto *opIface = abstractOp->getInterface<ConcreteType>())
+        return opIface;
+      // Fallback to the dialect to provide it with a chance to implement this
+      // interface for this operation.
+      return abstractOp->dialect.getRegisteredInterfaceForOp<ConcreteType>(op);
+    }
+    // Fallback to the dialect to provide it with a chance to implement this
+    // interface for this operation.
+    Dialect *dialect = op->getName().getIdentifier().getDialect();
+    return dialect ? dialect->getRegisteredInterfaceForOp<ConcreteType>(op)
+                   : nullptr;
   }
 
   /// Allow access to `getInterfaceFor`.
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -71,6 +71,8 @@
 public:
   using Concept = typename Traits::Concept;
   template <typename T> using Model = typename Traits::template Model<T>;
+  template <typename T>
+  using FallbackModel = typename Traits::template FallbackModel<T>;
   using InterfaceBase =
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
 
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -63,6 +63,9 @@
   /// Returns true if this dialect has a region result attribute verifier.
   bool hasRegionResultAttrVerify() const;
 
+  /// Returns true if this dialect has fallback interfaces for its operations.
+  bool hasOperationInterfaceFallback() const;
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -77,6 +77,10 @@
   return def->getValueAsBit("hasRegionResultAttrVerify");
 }
 
+bool Dialect::hasOperationInterfaceFallback() const {
+  return def->getValueAsBit("hasOperationInterfaceFallback");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir
--- a/mlir/test/IR/test-side-effects.mlir
+++ b/mlir/test/IR/test-side-effects.mlir
@@ -30,3 +30,9 @@
 %4 = "test.side_effect_op"() {
   effect_parameter = affine_map<(i, j) -> (j, i)>
 } : () -> i32
+
+// Check with this unregistered operation to test the fallback on the dialect.
+// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}}
+%5 = "test.unregistered_side_effect_op"() {
+  effect_parameter = affine_map<(i, j) -> (j, i)>
+} : () -> i32
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -166,6 +166,29 @@
 // TestDialect
 //===----------------------------------------------------------------------===//
 
+static void testSideEffectOpGetEffect(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
+
+// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
+struct TestOpEffectInterfaceFallback
+    : public TestEffectOpInterface::FallbackModel<
+          TestOpEffectInterfaceFallback> {
+  static bool classof(Operation *op) {
+    bool isSupportedOp =
+        op->getName().getStringRef() == "test.unregistered_side_effect_op";
+    assert(isSupportedOp && "Unexpected dispatch");
+    return isSupportedOp;
+  }
+
+  void
+  getEffects(Operation *op,
+             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+                 &effects) const {
+    testSideEffectOpGetEffect(op, effects);
+  }
+};
+
 void TestDialect::initialize() {
   registerAttributes();
   registerTypes();
@@ -176,6 +199,14 @@
   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                 TestInlinerInterface>();
   allowUnknownOperations();
+
+  // Instantiate our fallback op interface that we'll use on specific
+  // unregistered op.
+  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
+}
+TestDialect::~TestDialect() {
+  delete static_cast<TestOpEffectInterfaceFallback *>(
+      fallbackEffectOpInterfaces);
 }
 
 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
@@ -183,6 +214,15 @@
   return builder.create<TestOpConstant>(loc, type, value);
 }
 
+void *TestDialect::getRegisteredInterfaceForOp(mlir::TypeID typeID,
+                                               mlir::Operation *op) {
+  StringRef opName = op->getName().getStringRef();
+  if (opName == "test.unregistered_side_effect_op" &&
+      typeID == TypeID::get<TestEffectOpInterface>())
+    return fallbackEffectOpInterfaces;
+  return nullptr;
+}
+
 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute namedAttr) {
   if (namedAttr.first == "test.invalid_attr")
@@ -696,6 +736,17 @@
 };
 } // end anonymous namespace
 
+static void testSideEffectOpGetEffect(
+    Operation *op,
+    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+        &effects) {
+  auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
+  if (!effectsAttr)
+    return;
+
+  effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+}
+
 void SideEffectOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   // Check for an effects attribute on the op instance.
@@ -734,11 +785,7 @@
 
 void SideEffectOp::getEffects(
     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
-  auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
-  if (!effectsAttr)
-    return;
-
-  effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+  testSideEffectOpGetEffect(getOperation(), effects);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -29,6 +29,7 @@
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;
   let hasRegionResultAttrVerify = 1;
+  let hasOperationInterfaceFallback = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
 
   let extraClassDeclaration = [{
@@ -39,6 +40,11 @@
                              Type type) const override;
     void printAttribute(Attribute attr,
                         DialectAsmPrinter &printer) const override;
+    ~TestDialect();
+  private:
+    // Storage for a custom fallback interface.
+    void *fallbackEffectOpInterfaces;
+
   }];
 }
 
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -143,6 +143,13 @@
       ::mlir::NamedAttribute attribute) override;
 )";
 
+/// The code block for the op interface fallback hook.
+static const char *const operationInterfaceFallbackDecl = R"(
+  /// Provides a hook for op interface.
+  void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID,
+                                    mlir::Operation *op) override;
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect,
                             iterator_range<DialectFilterIterator> dialectAttrs,
@@ -181,6 +188,8 @@
     os << regionArgAttrVerifierDecl;
   if (dialect.hasRegionResultAttrVerify())
     os << regionResultAttrVerifierDecl;
+  if (dialect.hasOperationInterfaceFallback())
+    os << operationInterfaceFallbackDecl;
   if (llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
     os << *extraDecl;
 
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -215,23 +215,25 @@
 }
 
 void InterfaceGenerator::emitModelDecl(Interface &interface) {
-  os << "  template<typename " << valueTemplate << ">\n";
-  os << "  class Model : public Concept {\n  public:\n";
-  os << "    Model() : Concept{";
-  llvm::interleaveComma(
-      interface.getMethods(), os,
-      [&](const InterfaceMethod &method) { os << method.getName(); });
-  os << "} {}\n\n";
-
-  // Insert each of the virtual method overrides.
-  for (auto &method : interface.getMethods()) {
-    emitCPPType(method.getReturnType(), os << "    static inline ");
-    emitMethodNameAndArgs(method, os, valueType,
-                          /*addThisArg=*/!method.isStatic(),
-                          /*addConst=*/false);
-    os << ";\n";
+  for (auto modelClass : {"Model", "FallbackModel"}) {
+    os << "  template<typename " << valueTemplate << ">\n";
+    os << "  class " << modelClass << " : public Concept {\n  public:\n";
+    os << "    " << modelClass << "() : Concept{";
+    llvm::interleaveComma(
+        interface.getMethods(), os,
+        [&](const InterfaceMethod &method) { os << method.getName(); });
+    os << "} {}\n\n";
+
+    // Insert each of the virtual method overrides.
+    for (auto &method : interface.getMethods()) {
+      emitCPPType(method.getReturnType(), os << "    static inline ");
+      emitMethodNameAndArgs(method, os, valueType,
+                            /*addThisArg=*/!method.isStatic(),
+                            /*addConst=*/false);
+      os << ";\n";
+    }
+    os << "  };\n";
   }
-  os << "  };\n";
 }
 
 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
@@ -268,6 +270,32 @@
         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
     os << ");\n}\n";
   }
+
+  for (auto &method : interface.getMethods()) {
+    os << "template<typename " << valueTemplate << ">\n";
+    emitCPPType(method.getReturnType(), os);
+    os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
+       << valueTemplate << ">::";
+    emitMethodNameAndArgs(method, os, valueType,
+                          /*addThisArg=*/!method.isStatic(),
+                          /*addConst=*/false);
+    os << " {\n  ";
+
+    // Forward to the method on the concrete Model implementation.
+    if (method.isStatic())
+      os << "return " << valueTemplate << "::";
+    else
+      os << "return static_cast<const " << valueTemplate << "*>(impl)->";
+
+    // Add the arguments to the call.
+    os << method.getName() << '(';
+    if (!method.isStatic())
+      os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
+    llvm::interleaveComma(
+        method.getArguments(), os,
+        [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
+    os << ");\n}\n";
+  }
 }
 
 void InterfaceGenerator::emitTraitDecl(Interface &interface,
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -171,7 +171,7 @@
   StringRef methodName = availability.getQueryFnName();
   os << availability.getQueryFnRetType() << " "
      << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
-     << "  return getImpl()->" << methodName << "(getOperation());\n"
+     << "  return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
      << "}\n";
 }
 
@@ -208,23 +208,25 @@
      << "    virtual ~Concept() = default;\n"
      << "    virtual " << availability.getQueryFnRetType() << " "
      << availability.getQueryFnName()
-     << "(Operation *tblgen_opaque_op) const = 0;\n"
+     << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
      << "  };\n";
 }
 
 static void emitModelDecl(const Availability &availability, raw_ostream &os) {
-  os << "  template<typename ConcreteOp>\n";
-  os << "  class Model : public Concept {\n"
-     << "  public:\n"
-     << "    " << availability.getQueryFnRetType() << " "
-     << availability.getQueryFnName()
-     << "(Operation *tblgen_opaque_op) const final {\n"
-     << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
-     << "      (void)op;\n"
-     // Forward to the method on the concrete operation type.
-     << "      return op." << availability.getQueryFnName() << "();\n"
-     << "    }\n"
-     << "  };\n";
+  for (auto modelClass : {"Model", "FallbackModel"}) {
+    os << "  template<typename ConcreteOp>\n";
+    os << "  class " << modelClass << " : public Concept {\n"
+       << "  public:\n"
+       << "    " << availability.getQueryFnRetType() << " "
+       << availability.getQueryFnName()
+       << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
+       << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
+       << "      (void)op;\n"
+       // Forward to the method on the concrete operation type.
+       << "      return op." << availability.getQueryFnName() << "();\n"
+       << "    }\n"
+       << "  };\n";
+  }
 }
 
 static void emitInterfaceDecl(const Availability &availability,