diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -137,6 +137,17 @@ getRegisteredInterface(InterfaceT::getInterfaceID())); } + /// Returns this dialect's fallback concept for op interface `interfaceID` on + /// `op`, or nullptr (the default) when the dialect provides none. + virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, Operation *op) { + return nullptr; + } + template + typename InterfaceT::Concept *getRegisteredInterfaceForOp(Operation *op) { + return static_cast( + getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), op)); + } + protected: /// The constructor takes a unique namespace for this dialect as well as the /// context to bind to. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -278,6 +278,9 @@ // If this dialect overrides the hook for verifying region result attributes. bit hasRegionResultAttrVerify = 0; + + // If this dialect overrides the hook providing fallback op interfaces. + bit hasOperationInterfaceFallback = 0; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -19,6 +19,7 @@ #ifndef MLIR_IR_OPDEFINITION_H #define MLIR_IR_OPDEFINITION_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" @@ -1742,7 +1743,15 @@ static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { // Access the raw interface from the abstract operation. auto *abstractOp = op->getAbstractOperation(); - return abstractOp ?
abstractOp->getInterface() : nullptr; + if (abstractOp) { + auto *opIface = abstractOp->getInterface(); + if (opIface) + return opIface; + } + // Otherwise fall back to the op's dialect, if any, for an implementation. + if (op->getDialect()) + return op->getDialect()->getRegisteredInterfaceForOp(op); + return nullptr; } /// Allow access to `getInterfaceFor`. diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -186,6 +186,12 @@ return reinterpret_cast(inst); } + /// Returns an instance of the concept object for the given interface if it + /// was registered to this map, null otherwise. + void *lookup(TypeID typeID) const { + return interfaces ? interfaces->lookup(typeID) : nullptr; + } + private: InterfaceMap() = default; InterfaceMap(MutableArrayRef> elements) diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -63,6 +63,9 @@ /// Returns true if this dialect has a region result attribute verifier. bool hasRegionResultAttrVerify() const; + /// Returns true if this dialect has fallback interfaces for its operations. + bool hasOperationInterfaceFallback() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record.
bool operator==(const Dialect &other) const; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -77,6 +77,10 @@ return def->getValueAsBit("hasRegionResultAttrVerify"); } +bool Dialect::hasOperationInterfaceFallback() const { + return def->getValueAsBit("hasOperationInterfaceFallback"); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir --- a/mlir/test/IR/test-side-effects.mlir +++ b/mlir/test/IR/test-side-effects.mlir @@ -30,3 +30,9 @@ %4 = "test.side_effect_op"() { effect_parameter = affine_map<(i, j) -> (j, i)> } : () -> i32 + +// This op is unregistered: its effect must come from the test dialect fallback. +// expected-remark@+1 {{found a parametric effect with affine_map<(d0, d1) -> (d1, d0)>}} +%5 = "test.unregistered_side_effect_op"() { + effect_parameter = affine_map<(i, j) -> (j, i)> +} : () -> i32 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -164,6 +164,29 @@ // TestDialect //===----------------------------------------------------------------------===// +namespace { +void testSideEffectOpGetEffect( + Operation *op, SmallVectorImpl<::mlir::SideEffects::EffectInstance< + ::mlir::TestEffects::Effect>> &effects); +} + +// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
+struct TestOpEffectInterfaceFallback { + TestOpEffectInterfaceFallback(Operation *op) : op(op) {} + static bool classof(Operation *op) { + assert(op->getName().getStringRef() == "test.unregistered_side_effect_op"); + return true; + } + + void getEffects(SmallVectorImpl<::mlir::SideEffects::EffectInstance< + ::mlir::TestEffects::Effect>> &effects) { + testSideEffectOpGetEffect(op, effects); + } + +private: + Operation *op; +}; + void TestDialect::initialize() { addOperations< #define GET_OP_LIST @@ -176,6 +199,16 @@ #include "TestTypeDefs.cpp.inc" >(); allowUnknownOperations(); + + // Instantiate the fallback op interface model we use for one specific + // unregistered op; it is freed in ~TestDialect(). + fallbackEffectOpInterfaces = + new TestEffectOpInterface::Model(); +} +TestDialect::~TestDialect() { + delete static_cast< + TestEffectOpInterface::Model *>( + fallbackEffectOpInterfaces); } Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, @@ -183,6 +216,15 @@ return builder.create(loc, type, value); } +void *TestDialect::getRegisteredInterfaceForOp(mlir::TypeID typeID, + mlir::Operation *op) { + StringRef opName = op->getName().getStringRef(); + if (opName == "test.unregistered_side_effect_op" && + typeID == TypeID::get()) + return fallbackEffectOpInterfaces; + return nullptr; +} + static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; @@ -726,12 +768,22 @@ struct TestResource : public SideEffects::Resource::Base { StringRef getName() final { return ""; } }; + +void testSideEffectOpGetEffect( + Operation *op, SmallVectorImpl<::mlir::SideEffects::EffectInstance< + ::mlir::TestEffects::Effect>> &effects) { + auto effectsAttr = op->getAttrOfType("effect_parameter"); + if (!effectsAttr) + return; + + effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); +} } // end anonymous namespace void SideEffectOp::getEffects( SmallVectorImpl &effects) { // Check for an effects attribute on the op instance.
ArrayAttr effectsAttr = (*this)->getAttrOfType("effects"); if (!effectsAttr) return; @@ -766,11 +818,7 @@ void SideEffectOp::getEffects( SmallVectorImpl &effects) { - auto effectsAttr = (*this)->getAttrOfType("effect_parameter"); - if (!effectsAttr) - return; - - effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); + testSideEffectOpGetEffect(getOperation(), effects); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -27,6 +27,13 @@ let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; + let hasOperationInterfaceFallback = 1; + code extraClassDeclaration = [{ + ~TestDialect(); + private: + // Storage for a custom fallback interface. + void *fallbackEffectOpInterfaces; + }]; } class TEST_Op traits = []> : diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -141,6 +141,14 @@ ::mlir::NamedAttribute attribute) override; )"; +/// The code block for the op interface fallback hook. Types are fully +/// qualified because the dialect class may be declared outside `mlir`. +static const char *const operationInterfaceFallbackDecl = R"( + /// Provides a hook for op interface. + void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID, + ::mlir::Operation *op) override; +)"; + /// Generate the declaration for the given dialect class. static void emitDialectDecl(Dialect &dialect, iterator_range dialectAttrs, @@ -179,6 +187,8 @@ os << regionArgAttrVerifierDecl; if (dialect.hasRegionResultAttrVerify()) os << regionResultAttrVerifierDecl; + if (dialect.hasOperationInterfaceFallback()) + os << operationInterfaceFallbackDecl; if (llvm::Optional extraDecl = dialect.getExtraClassDeclaration()) os << *extraDecl;