diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -384,6 +384,9 @@ - Additional C++ code that is generated in the declaration of the interface class. This allows for defining methods and more on the user facing interface class, that do not need to hook into the IR entity. + These declarations are _not_ implicitly visible in default + implementations of interface methods, but static declarations may be + accessed with full name qualification. `OpInterface` classes may additionally contain the following: @@ -430,6 +433,8 @@ - `ConcreteAttr`/`ConcreteOp`/`ConcreteType` is an implicitly defined `typename` that can be used to refer to the type of the derived IR entity currently being operated on. + - This may refer to static fields of the interface class using the + qualified name, e.g., `TestOpInterface::staticMethod()`. ODS also allows for generating declarations for the `InterfaceMethod`s of an operation if the operation specifies the interface with diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -78,7 +78,7 @@ ]; let verify = [{ - auto concreteOp = cast($_op); + auto concreteOp = cast($_op); for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { Optional operands = concreteOp.getSuccessorOperands(i); if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands))) @@ -154,7 +154,7 @@ ]; let verify = [{ - static_assert(!ConcreteOpType::template hasTrait(), + static_assert(!ConcreteOp::template hasTrait(), "expected operation to have non-zero regions"); return success(); }]; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -241,17 +241,10 @@ os << " };\n"; } + // Emit the template for the external model. os << " template\n"; os << " class ExternalModel : public FallbackModel {\n"; - - // Emit the template for the external model if there are no extra class - // declarations. - if (interface.getExtraClassDeclaration()) { - os << " };\n"; - return; - } - os << " public:\n"; // Emit declarations for methods that have default implementations. Other @@ -345,9 +338,6 @@ } // Emit default implementations for the external model. - if (interface.getExtraClassDeclaration()) - return; - for (auto &method : interface.getMethods()) { if (!method.getDefaultImplementation()) continue; @@ -427,11 +417,6 @@ os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - - // Emit a utility wrapper trait class. - os << llvm::formatv(" template \n" - " struct Trait : public {0}Trait<{1}> {{};\n", - interfaceName, valueTemplate); } void InterfaceGenerator::emitInterfaceDecl(Interface interface) { @@ -452,7 +437,13 @@ << "struct " << interfaceTraitsName << " {\n"; emitConceptDecl(interface); emitModelDecl(interface); - os << "};\n} // end namespace detail\n"; + os << "};"; + + // Emit the derived trait for the interface. + os << "template \n"; + os << "struct " << interface.getName() << "Trait;\n"; + + os << "\n} // end namespace detail\n"; // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" @@ -461,8 +452,10 @@ interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); - // Emit the derived trait for the interface. - emitTraitDecl(interface, interfaceName, interfaceTraitsName); + // Emit a utility wrapper trait class. + os << llvm::formatv(" template \n" + " struct Trait : public detail::{0}Trait<{1}> {{};\n", + interfaceName, valueTemplate); // Insert the method declarations. bool isOpInterface = isa(interface); @@ -479,6 +472,10 @@ os << "};\n"; + os << "namespace detail {\n"; + emitTraitDecl(interface, interfaceName, interfaceTraitsName); + os << "}// namespace detail\n"; + emitModelMethodsDef(interface); for (StringRef ns : llvm::reverse(namespaces))