diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -634,6 +634,13 @@ [DeclareOpInterfaceMethods]> { ... } ``` +Once the interfaces have been defined, the C++ header and source files can be +generated using the `--gen-<attr|op|type>-interface-decls` and +`--gen-<attr|op|type>-interface-defs` options with mlir-tblgen. Note that when +generating interfaces, mlir-tblgen will only generate interfaces defined in +the top-level input `.td` file. This means that any interfaces that are +defined within include files will not be considered for generation. + Note: Existing operation interfaces defined in C++ can be accessed in the ODS framework via the `OpInterfaceTrait` class. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -62,12 +62,19 @@ /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector -getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { +getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, + StringRef name) { std::vector defs = - recordKeeper.getAllDerivedDefinitions("OpInterface"); - - llvm::erase_if(defs, [](const llvm::Record *def) { - return def->isSubClassOf("DeclareOpInterfaceMethods"); + recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); + + std::string declareName = ("Declare" + name + "InterfaceMethods").str(); + llvm::erase_if(defs, [&](const llvm::Record *def) { + // Ignore any "declare methods" interfaces. + if (def->isSubClassOf(declareName)) + return true; + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); }); return defs; } @@ -110,8 +117,7 @@ /// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator { AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; valueTemplate = "ConcreteAttr"; @@ -125,7 +131,7 @@ /// A specialized generator for operation interfaces. struct OpInterfaceGenerator : public InterfaceGenerator { OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; @@ -140,8 +146,7 @@ /// A specialized generator for type interfaces. struct TypeInterfaceGenerator : public InterfaceGenerator { TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType";