diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -67,7 +67,7 @@ }; /// Register the interface with the dialect. -AffineDialect::AffineDialect(MLIRContext *context) ... { +void AffineDialect::initialize() { addInterfaces(); } ``` @@ -204,6 +204,71 @@ llvm::errs() << "hook returned = " << example.exampleInterfaceHook() << "\n"; ``` +#### Fallback Dialect Interfaces for Unregistered Operations + +The above mechanism requires that the IR entity be registered to implement the +interface; however, some dialects allow for unregistered operations. To +support operation interfaces in such situations, a derived `OpInterface` may +provide a dialect interface (named `DialectInterface`) that provides a fallback +implementation. This interface acts similarly to the `Model` traits class, but +must also include an additional hook: `bool isOpSupported(OperationName op) +const` that filters which operations are supported by the fallback. An example +is shown below: + +```c++ +/// Define the main interface class that analyses and transformations will +/// interface with. +class ExampleOpInterface : public OpInterface { +public: + /// ... The details from the earlier example are elided here ... + + /// This class represents a dialect fallback implementation for unregistered + /// operations. + class DialectInterface + : public mlir::DialectInterface::Base, + public ExampleOpInterfaceTraits::Concept { + public: + /// Returns true if the given unregistered operation is supported by this + /// fallback interface. + virtual bool isOpSupported(OperationName op) const = 0; + }; +}; + + +/// This class defines a fallback dialect interface for handling +/// `ExampleOpInterface`.
+struct ExampleDialectFallbackInterface + : public ExampleOpInterface::DialectInterface { + using ExampleOpInterface::DialectInterface::DialectInterface; + + /// Returns true if the given unregistered operation is supported by this + /// fallback interface. + bool isOpSupported(OperationName op) const final { + return op.getStringRef() == "example.unregistered_op"; + } + + /// Implementations of the interface concept hooks. + unsigned exampleInterfaceHook(Operation *op) const final { + // ... + } + unsigned exampleStaticInterfaceHook(OperationName op) const final { + // ... + } +}; + +/// Register the interface with the dialect. +void ExampleDialect::initialize() { + addInterfaces(); +} + +/// The use of a fallback interface is transparent to users, with no changes +/// required for the user-facing code. +Operation *op = ...; +if (ExampleOpInterface example = dyn_cast(op)) + // ... ; +``` + #### Utilizing the ODS Framework Note: Before reading this section, the reader should have some familiarity with @@ -260,6 +325,10 @@ `OpInterface` classes may additionally contain the following: +* GenerateDialectFallback (`generateDialectFallback`) + - Signal that a dialect interface fallback class should be generated for + this interface, allowing dialects to provide an implementation for + unregistered operations. * Verifier (`verify`) - A C++ code block containing additional verification applied to the operation that the interface is attached to.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -849,7 +849,8 @@ [{ BlockAndValueMapping map; unsigned numRegions = $_op->getNumRegions(); - Operation *res = create(b, loc, resultTypes, operands, $_op.getAttrs()); + Operation *res = b.create(loc, resultTypes, operands, + $_op.getAttrs()); assert(res->getNumRegions() == numRegions && "inconsistent # regions"); for (unsigned ridx = 0; ridx < numRegions; ++ridx) $_op->getRegion(ridx).cloneInto( diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1828,6 +1828,12 @@ // An optional code block containing extra declarations to place in the // interface trait declaration. code extraTraitClassDeclaration = ""; + + // Specify that a fallback dialect interface should be generated for this + // interface. If set, a dialect interface class named `DialectInterface` will + // be generated within the interface class, enabling dialects to provide a + // fallback for unregistered operations. + bit generateDialectFallback = 0; } // This class represents a single, optionally static, interface method. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -19,6 +19,7 @@ #ifndef MLIR_IR_OPDEFINITION_H #define MLIR_IR_OPDEFINITION_H +#include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" @@ -1738,11 +1739,44 @@ using InterfaceBase::InterfaceBase; protected: + /// Trait to check if T provides a 'DialectInterface' class.
+ template + using has_dialect_interface = + decltype(std::declval()); + template + using detect_has_dialect_interface = + llvm::is_detected; + /// Returns the impl interface instance for the given operation. - static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { + static const typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { // Access the raw interface from the abstract operation. - auto *abstractOp = op->getAbstractOperation(); - return abstractOp ? abstractOp->getInterface() : nullptr; + if (auto *abstractOp = op->getAbstractOperation()) + return abstractOp->getInterface(); + return getDialectInterfaceFor(op); + } + /// This overload is selected when the interface does not have a dialect + /// interface fallback. + template + static std::enable_if_t::value, + const typename InterfaceT::Concept *> + getDialectInterfaceFor(Operation *op) { + return nullptr; + } + /// This overload is selected when the interface does have a dialect interface + /// fallback. + template + static std::enable_if_t::value, + const typename InterfaceT::Concept *> + getDialectInterfaceFor(Operation *op) { + // If this operation isn't registered, try the dialect as a fallback. + MLIRContext *ctx = op->getContext(); + if (Dialect *dialect = ctx->getLoadedDialect(op->getName().getDialect())) { + auto *interface = dialect->getRegisteredInterface< + typename ConcreteType::DialectInterface>(); + if (interface && interface->isOpSupported(op->getName())) + return interface; + } + return nullptr; } /// Allow access to `getInterfaceFor`. diff --git a/mlir/include/mlir/IR/RegionKindInterface.td b/mlir/include/mlir/IR/RegionKindInterface.td --- a/mlir/include/mlir/IR/RegionKindInterface.td +++ b/mlir/include/mlir/IR/RegionKindInterface.td @@ -31,23 +31,21 @@ let methods = [ StaticInterfaceMethod< /*desc=*/[{ - Return the kind of the region with the given index inside this operation. + Return the kind of the region with the given index inside this + operation. 
}], /*retTy=*/"RegionKind", /*methodName=*/"getRegionKind", /*args=*/(ins "unsigned":$index) >, - StaticInterfaceMethod< - /*desc=*/"Return true if the kind of the given region requires the " - "SSA-Dominance property", - /*retTy=*/"bool", - /*methodName=*/"hasSSADominance", - /*args=*/(ins "unsigned":$index), - /*methodBody=*/[{ - return getRegionKind(index) == RegionKind::SSACFG; - }] - >, ]; + let extraClassDeclaration = [{ + /// Return true if the kind of the given region requires the SSA-Dominance + /// property. + bool hasSSADominance(unsigned index) { + return getRegionKind(index) == RegionKind::SSACFG; + } + }]; } #endif // MLIR_IR_REGIONKINDINTERFACE diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -14,6 +14,7 @@ #ifndef MLIR_INTERFACES_SIDEEFFECTS_H #define MLIR_INTERFACES_SIDEEFFECTS_H +#include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpDefinition.h" namespace mlir { diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td @@ -31,6 +31,9 @@ an operation. }]; let cppNamespace = "::mlir"; + + // Generate a fallback dialect interface for providing memory effects. + let generateDialectFallback = 1; } // The base class for defining specific memory effects. diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h --- a/mlir/include/mlir/Support/InterfaceSupport.h +++ b/mlir/include/mlir/Support/InterfaceSupport.h @@ -108,11 +108,10 @@ protected: /// Get the raw concept in the correct derived concept type. const Concept *getImpl() const { return impl; } - Concept *getImpl() { return impl; } private: /// A pointer to the impl concept object. 
- Concept *impl; + const Concept *impl; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h --- a/mlir/include/mlir/TableGen/Interfaces.h +++ b/mlir/include/mlir/TableGen/Interfaces.h @@ -94,6 +94,10 @@ // Return the verify method body if it has one. llvm::Optional getVerify() const; + // Returns true if this interface requested that a dialect interface fallback + // be generated. + bool shouldGenerateDialectFallback() const; + // Returns the Tablegen definition this interface was constructed from. const llvm::Record &getDef() const { return *def; } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -977,7 +977,7 @@ // TODO: Handle failure. SmallVector inferredTypes; if (failed(concept->inferReturnTypes( - state.getContext(), state.location, state.operands, + state.name, state.getContext(), state.location, state.operands, state.attributes.getDictionary(state.getContext()), state.regions, inferredTypes))) return; diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp --- a/mlir/lib/TableGen/Interfaces.cpp +++ b/mlir/lib/TableGen/Interfaces.cpp @@ -119,6 +119,15 @@ return value.empty() ? llvm::Optional() : value; } +// Returns true if this interface requested that a dialect interface fallback +// be generated. +bool Interface::shouldGenerateDialectFallback() const { + // Only OpInterface supports the generateDialectFallback method. 
+ if (!isa(this)) + return false; + return def->getValueAsBit("generateDialectFallback"); +} + //===----------------------------------------------------------------------===// // AttrInterface //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/test-side-effects.mlir b/mlir/test/IR/test-side-effects.mlir --- a/mlir/test/IR/test-side-effects.mlir +++ b/mlir/test/IR/test-side-effects.mlir @@ -30,3 +30,30 @@ %4 = "test.side_effect_op"() { effect_parameter = affine_map<(i, j) -> (j, i)> } : () -> i32 + +/// The following tests the same as above, but on an unregistered operation in +/// the test dialect. + +// expected-remark@+1 {{operation has no memory effects}} +%5 = "test.unknown_side_effect_op"() {} : () -> i32 + +// expected-remark@+2 {{found an instance of 'read' on resource ''}} +// expected-remark@+1 {{found an instance of 'free' on resource ''}} +%6 = "test.unknown_side_effect_op"() {effects = [ + {effect="read"}, {effect="free"} +]} : () -> i32 + +// expected-remark@+1 {{found an instance of 'write' on resource ''}} +%7 = "test.unknown_side_effect_op"() {effects = [ + {effect="write", test_resource} +]} : () -> i32 + +// expected-remark@+1 {{found an instance of 'allocate' on a value, on resource ''}} +%8 = "test.unknown_side_effect_op"() {effects = [ + {effect="allocate", on_result, test_resource} +]} : () -> i32 + +// expected-remark@+1 {{found an instance of 'read' on a symbol '@foo_ref', on resource ''}} +"test.unknown_side_effect_op"() {effects = [ + {effect="read", on_reference = @foo_ref, test_resource} +]} : () -> i32 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -29,57 +29,13 @@ // TestDialect Interfaces //===----------------------------------------------------------------------===// -namespace { - -// Test support for interacting with the 
AsmPrinter. -struct TestOpAsmInterface : public OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - - LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { - StringAttr strAttr = attr.dyn_cast(); - if (!strAttr) - return failure(); - - // Check the contents of the string attribute to see what the test alias - // should be named. - Optional aliasName = - StringSwitch>(strAttr.getValue()) - .Case("alias_test:dot_in_name", StringRef("test.alias")) - .Case("alias_test:trailing_digit", StringRef("test_alias0")) - .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) - .Case("alias_test:sanitize_conflict_a", - StringRef("test_alias_conflict0")) - .Case("alias_test:sanitize_conflict_b", - StringRef("test_alias_conflict0_")) - .Default(llvm::None); - if (!aliasName) - return failure(); - - os << *aliasName; - return success(); - } - - void getAsmResultNames(Operation *op, - OpAsmSetValueNameFn setNameFn) const final { - if (auto asmOp = dyn_cast(op)) - setNameFn(asmOp, "result"); - } - - void getAsmBlockArgumentNames(Block *block, - OpAsmSetValueNameFn setNameFn) const final { - auto op = block->getParentOp(); - auto arrayAttr = op->getAttrOfType("arg_names"); - if (!arrayAttr) - return; - auto args = block->getArguments(); - auto e = std::min(arrayAttr.size(), args.size()); - for (unsigned i = 0; i < e; ++i) { - if (auto strAttr = arrayAttr[i].dyn_cast()) - setNameFn(args[i], strAttr.getValue()); - } - } -}; +/// Collect a set of side effects for the given operation, whose side effects +/// are defined in an attached `effects` ArrayAttr. +static void +getTestSideEffectsFor(Operation *op, + SmallVectorImpl &effects); +namespace { struct TestDialectFoldInterface : public DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; @@ -92,9 +48,8 @@ } }; -/// This class defines the interface for handling inlining with standard -/// operations. 
-struct TestInlinerInterface : public DialectInlinerInterface { +/// This class defines the interface for handling inlining with test operations. +struct TestDialectInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// @@ -158,6 +113,77 @@ return builder.create(conversionLoc, resultType, input); } }; + +/// Test support for interacting with the AsmPrinter. +struct TestDialectOpAsmInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + + LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { + StringAttr strAttr = attr.dyn_cast(); + if (!strAttr) + return failure(); + + // Check the contents of the string attribute to see what the test alias + // should be named. + Optional aliasName = + StringSwitch>(strAttr.getValue()) + .Case("alias_test:dot_in_name", StringRef("test.alias")) + .Case("alias_test:trailing_digit", StringRef("test_alias0")) + .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) + .Case("alias_test:sanitize_conflict_a", + StringRef("test_alias_conflict0")) + .Case("alias_test:sanitize_conflict_b", + StringRef("test_alias_conflict0_")) + .Default(llvm::None); + if (!aliasName) + return failure(); + + os << *aliasName; + return success(); + } + + void getAsmResultNames(Operation *op, + OpAsmSetValueNameFn setNameFn) const final { + if (auto asmOp = dyn_cast(op)) + setNameFn(asmOp, "result"); + } + + void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const final { + auto op = block->getParentOp(); + auto arrayAttr = op->getAttrOfType("arg_names"); + if (!arrayAttr) + return; + auto args = block->getArguments(); + auto e = std::min(arrayAttr.size(), args.size()); + for (unsigned i = 0; i < e; ++i) { + if (auto strAttr = arrayAttr[i].dyn_cast()) + setNameFn(args[i], strAttr.getValue()); + } + } +}; + +/// This class defines the 
fallback dialect interface for handling operation +/// side effects. +struct TestDialectSideEffectInterface + : public MemoryEffectOpInterface::DialectInterface { + using MemoryEffectOpInterface::DialectInterface::DialectInterface; + + /// Returns true if the given unregistered operation is supported by this + /// fallback implementation. + bool isOpSupported(OperationName op) const override { + return op.getStringRef() == "test.unknown_side_effect_op"; + } + + /// Collects all of the operation's effects into `effects`. + void + getEffects(Operation *op, + SmallVectorImpl> + &effects) const override { + assert(isOpSupported(op->getName()) && "expected supported operation"); + return getTestSideEffectsFor(op, effects); + } +}; } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -169,8 +195,8 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addInterfaces(); + addInterfaces(); addTypes &effects) { +/// Collect a set of side effects for the given operation, whose side effects +/// are defined in an attached `effects` ArrayAttr. +static void +getTestSideEffectsFor(Operation *op, + SmallVectorImpl &effects) { // Check for an effects attribute on the op instance. - ArrayAttr effectsAttr = getAttrOfType("effects"); + ArrayAttr effectsAttr = op->getAttrOfType("effects"); if (!effectsAttr) return; @@ -751,7 +780,7 @@ // Check for a result to affect. 
if (effectElement.get("on_result")) - effects.emplace_back(effect, getResult(), resource); + effects.emplace_back(effect, op->getResult(0), resource); else if (Attribute ref = effectElement.get("on_reference")) effects.emplace_back(effect, ref.cast(), resource); else @@ -759,6 +788,11 @@ } } +void SideEffectOp::getEffects( + SmallVectorImpl &effects) { + getTestSideEffectsFor(*this, effects); +} + void SideEffectOp::getEffects( SmallVectorImpl &effects) { auto effectsAttr = getAttrOfType("effect_parameter"); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -36,25 +36,6 @@ return os; } -/// Emit the method name and argument list for the given method. If 'addThisArg' -/// is true, then an argument is added to the beginning of the argument list for -/// the concrete value. -static void emitMethodNameAndArgs(const InterfaceMethod &method, - raw_ostream &os, StringRef valueType, - bool addThisArg, bool addConst) { - os << method.getName() << '('; - if (addThisArg) - emitCPPType(valueType, os) - << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); - llvm::interleaveComma(method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { - os << arg.type << " " << arg.name; - }); - os << ')'; - if (addConst) - os << " const"; -} - /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector @@ -72,6 +53,8 @@ /// This struct is the base generator used when processing tablegen interfaces. class InterfaceGenerator { public: + using Argument = InterfaceMethod::Argument; + bool emitInterfaceDefs(); bool emitInterfaceDecls(); bool emitInterfaceDocs(); @@ -80,12 +63,28 @@ InterfaceGenerator(std::vector &&defs, raw_ostream &os) : defs(std::move(defs)), os(os) {} + /// Emit the function type used to store the given method. 
+ void emitMethodType(const InterfaceMethod &method); + + /// Emit a call to the given method using the names of its arguments as + /// parameters. + void emitCallToMethod(const InterfaceMethod &method, bool addThisArg); + + /// Emit the argument list for the given method. If 'addThisArg' is true, then + /// an argument is added to the beginning of the argument list for the + /// concrete value. + void emitMethodArgs(const InterfaceMethod &method, bool addThisArg, + bool addConst); + + void emitInterfaceDef(Interface interface); void emitConceptDecl(Interface &interface); void emitModelDecl(Interface &interface); void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName); + void emitDialectInterfaceModelDecl(Interface &interface); void emitInterfaceDecl(Interface interface); + void emitInterfaceDoc(Interface interface); /// The set of interface records to emit. std::vector defs; @@ -93,6 +92,9 @@ raw_ostream &os; /// The C++ value type of the interface, e.g. Operation*. StringRef valueType; + /// The C++ value type of the interface used when the method is static, e.g. + /// OperationName. + StringRef staticValueType; /// The C++ base interface type. StringRef interfaceBaseType; /// The name of the typename for the value template. 
@@ -108,6 +110,7 @@ : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), os) { valueType = "::mlir::Attribute"; + staticValueType = valueType; interfaceBaseType = "AttributeInterface"; valueTemplate = "ConcreteAttr"; StringRef castCode = "(tablegen_opaque_val.cast())"; @@ -121,6 +124,7 @@ OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { valueType = "::mlir::Operation *"; + staticValueType = "::mlir::OperationName"; interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; StringRef castCode = "(llvm::cast(tablegen_opaque_val))"; @@ -134,6 +138,7 @@ : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), os) { valueType = "::mlir::Type"; + staticValueType = valueType; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType"; StringRef castCode = "(tablegen_opaque_val.cast())"; @@ -144,12 +149,44 @@ }; } // end anonymous namespace +void InterfaceGenerator::emitMethodType(const InterfaceMethod &method) { + os << method.getReturnType() << "(" + << (method.isStatic() ? staticValueType : valueType) + << (method.arg_empty() ? "" : ", "); + llvm::interleaveComma(method.getArguments(), os, + [&](const Argument &arg) { os << arg.type; }); + os << ")"; +} + +void InterfaceGenerator::emitCallToMethod(const InterfaceMethod &method, + bool addThisArg) { + os << method.getName() << '('; + if (addThisArg) + os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); + llvm::interleaveComma(method.getArguments(), os, + [&](const Argument &arg) { os << arg.name; }); + os << ')'; +} + +void InterfaceGenerator::emitMethodArgs(const InterfaceMethod &method, + bool addThisArg, bool addConst) { + if (addThisArg) { + emitCPPType(method.isStatic() ? staticValueType : valueType, os); + os << "tablegen_opaque_val" << (method.arg_empty() ? 
"" : ", "); + } + llvm::interleaveComma(method.getArguments(), os, [&](const Argument &arg) { + os << arg.type << " " << arg.name; + }); + os << ')'; + if (addConst) + os << " const"; +} + //===----------------------------------------------------------------------===// // GEN: Interface definitions //===----------------------------------------------------------------------===// -static void emitInterfaceDef(Interface interface, StringRef valueType, - raw_ostream &os) { +void InterfaceGenerator::emitInterfaceDef(Interface interface) { StringRef interfaceName = interface.getName(); StringRef cppNamespace = interface.getCppNamespace(); cppNamespace.consume_front("::"); @@ -160,19 +197,19 @@ emitCPPType(method.getReturnType(), os); if (!cppNamespace.empty()) os << cppNamespace << "::"; - os << interfaceName << "::"; - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, - /*addConst=*/!isOpInterface); + os << interfaceName << "::" << method.getName() << '('; + emitMethodArgs(method, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); // Forward to the method on the concrete operation type. os << " {\n return getImpl()->" << method.getName() << '('; - if (!method.isStatic()) { - os << (isOpInterface ? "getOperation()" : "*this"); - os << (method.arg_empty() ? "" : ", "); - } - llvm::interleaveComma( - method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); + if (isOpInterface) + os << "getOperation()" << (method.isStatic() ? "->getName()" : ""); + else + os << "*this"; + os << (method.arg_empty() ? "" : ", "); + llvm::interleaveComma(method.getArguments(), os, + [&](const Argument &arg) { os << arg.name; }); os << ");\n }\n"; } } @@ -181,7 +218,7 @@ llvm::emitSourceFileHeader("Interface Definitions", os); for (const auto *def : defs) - emitInterfaceDef(Interface(def), valueType, os); + emitInterfaceDef(Interface(def)); return false; } @@ -194,48 +231,42 @@ // Insert each of the pure virtual concept methods. 
for (auto &method : interface.getMethods()) { - os << " "; - emitCPPType(method.getReturnType(), os); - os << "(*" << method.getName() << ")("; - if (!method.isStatic()) - emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); - llvm::interleaveComma( - method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); - os << ");\n"; + os << " llvm::function_ref<"; + emitMethodType(method); + os << "> " << method.getName() << ";\n"; } os << " };\n"; } void InterfaceGenerator::emitModelDecl(Interface &interface) { - os << " template\n"; - os << " class Model : public Concept {\n public:\n"; - os << " Model() : Concept{"; - llvm::interleaveComma( - interface.getMethods(), os, - [&](const InterfaceMethod &method) { os << method.getName(); }); - os << "} {}\n\n"; + os << " template\n" + << " class Model : public Concept {\n public:\n Model();\n"; // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " static inline "); - emitMethodNameAndArgs(method, os, valueType, - /*addThisArg=*/!method.isStatic(), - /*addConst=*/false); + emitCPPType(method.getReturnType(), os << " static inline ") + << method.getName() << '('; + emitMethodArgs(method, /*addThisArg=*/true, /*addConst=*/false); os << ";\n"; } os << " };\n"; } void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { + os << "template\n" + << "detail::" << interface.getName() << "InterfaceTraits::Model<" + << valueTemplate << ">::Model::Model() : Concept{"; + llvm::interleaveComma( + interface.getMethods(), os, + [&](const InterfaceMethod &method) { os << method.getName(); }); + os << "} {}\n"; + for (auto &method : interface.getMethods()) { os << "template\n"; emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::Model<" - << valueTemplate << ">::"; - emitMethodNameAndArgs(method, os, valueType, - /*addThisArg=*/!method.isStatic(), - 
/*addConst=*/false); + emitMethodArgs(method, /*addThisArg=*/true, /*addConst=*/false); os << " {\n "; // Check for a provided body to the function. @@ -255,11 +286,8 @@ os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); // Add the arguments to the call. - os << method.getName() << '('; - llvm::interleaveComma( - method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); - os << ");\n}\n"; + emitCallToMethod(method, /*addThisArg=*/false); + os << ";\n}\n"; } } @@ -286,9 +314,9 @@ continue; os << " " << (method.isStatic() ? "static " : ""); - emitCPPType(method.getReturnType(), os); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, - /*addConst=*/!isOpInterface); + emitCPPType(method.getReturnType(), os) << method.getName() << '('; + emitMethodArgs(method, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) << "\n }\n"; } @@ -313,6 +341,80 @@ interfaceName, valueTemplate); } +/// Returns true if the given method body can be reused by a dialect interface +/// fallback. Returns false if the method will need to be reimplemented by the +/// derived interface. +static bool canUseMethodBodyInDialectInterface(StringRef body, bool isStatic, + StringRef valueTemplate) { + // Check that the concrete value template type isn't referenced in the body. + if (body.contains(valueTemplate)) + return false; + + // Check that the replacements for the current value aren't used. + return isStatic || (!body.contains("$_self") && !body.contains("$_op")); +} + +void InterfaceGenerator::emitDialectInterfaceModelDecl(Interface &interface) { + // Emit the dialect interface definition.
+ os << " class DialectInterface : public " + "::mlir::DialectInterface::Base, public Concept {\n" + " public:\n" + " DialectInterface(::mlir::Dialect *dialect) : " + "Base(dialect), Concept{"; + + // We forward to an `__impl_` wrapper so that the Concept can dispatch to a + // virtual method declared specifically for the dialect interface. + llvm::interleaveComma(interface.getMethods(), os, + [&](const InterfaceMethod &method) { + os << "__impl_" << method.getName(); + }); + os << "} {}\n\n"; + + // Implement the dialect interface hooks for the Concept. + for (auto &method : interface.getMethods()) { + os << " virtual "; + emitCPPType(method.getReturnType(), os) << method.getName() << '('; + emitMethodArgs(method, /*addThisArg=*/true, /*addConst=*/true); + + // Check for a provided body to the function. + if (Optional body = method.getBody()) { + if (canUseMethodBodyInDialectInterface(*body, method.isStatic(), + valueTemplate)) { + os << " {\n "; + if (method.isStatic()) + os << body->trim(); + else + os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); + os << "\n}\n"; + continue; + } + } + + // Otherwise, this method must be implemented by the derived interface. + os << " = 0;\n"; + } + + // Emit the selector method that the dialect can use to filter supported + // unregistered operations. + os << " /// Returns true if the given unregistered operation is\n" + " /// supported by this fallback implementation.\n" + " virtual bool isOpSupported(::mlir::OperationName op) const = 0;\n" + << "\n private:\n"; + + // Emit the internal implementations of the Concept methods that forward + // to the dialect interface variants.
+ for (auto &method : interface.getMethods()) { + os << " std::function<"; + emitMethodType(method); + os << "> __impl_" << method.getName() << " = \n [this]("; + emitMethodArgs(method, /*addThisArg=*/true, /*addConst=*/false); + os << " {\n return "; + emitCallToMethod(method, /*addThisArg=*/true); + os << ";\n };\n"; + } + os << " };\n"; +} + void InterfaceGenerator::emitInterfaceDecl(Interface interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -346,9 +448,9 @@ // Insert the method declarations. bool isOpInterface = isa(interface); for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " "); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, - /*addConst=*/!isOpInterface); + emitCPPType(method.getReturnType(), os << " ") << method.getName() << '('; + emitMethodArgs(method, /*addThisArg=*/false, + /*addConst=*/!isOpInterface); os << ";\n"; } @@ -356,6 +458,8 @@ if (Optional extraDecls = interface.getExtraClassDeclaration()) os << *extraDecls << "\n"; + if (interface.shouldGenerateDialectFallback()) + emitDialectInterfaceModelDecl(interface); os << "};\n"; emitModelMethodsDef(interface); @@ -376,12 +480,10 @@ // GEN: Interface documentation //===----------------------------------------------------------------------===// -static void emitInterfaceDoc(const llvm::Record &interfaceDef, - raw_ostream &os) { - Interface interface(&interfaceDef); - +void InterfaceGenerator::emitInterfaceDoc(Interface interface) { // Emit the interface name followed by the description. 
- os << "## " << interface.getName() << " (" << interfaceDef.getName() << ")"; + os << "## " << interface.getName() << " (" << interface.getDef().getName() + << ")"; if (auto description = interface.getDescription()) mlir::tblgen::emitDescription(*description, os); @@ -395,11 +497,8 @@ if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os) << method.getName() << '('; - llvm::interleaveComma(method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { - emitCPPType(arg.type, os) << arg.name; - }); - os << ");\n```\n"; + emitMethodArgs(method, /*addThisArg=*/false, /*addConst=*/false); + os << ";\n```\n"; // Emit the description. if (auto description = method.getDescription()) @@ -416,7 +515,7 @@ os << "# " << interfaceBaseType << " definitions\n"; for (const auto *def : defs) - emitInterfaceDoc(*def, os); + emitInterfaceDoc(Interface(def)); return false; }