diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -76,15 +76,16 @@ // DECL: /// some function comment // DECL: int foo(int input); -// DECL: template -// DECL: int detail::TestOpInterfaceInterfaceTraits::Model::foo - // DECL-LABEL: struct TestOpInterfaceVerifyTrait // DECL: verifyTrait // DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait // DECL: verifyRegionTrait +// Method implementations come last, after all class definitions. +// DECL: template +// DECL: int detail::TestOpInterfaceInterfaceTraits::Model::foo + // OP_DECL-LABEL: class DeclareMethodsOp : public // OP_DECL: int foo(int input); // OP_DECL-NOT: int default_foo(int input); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -289,6 +289,11 @@ } void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + for (auto &method : interface.getMethods()) { os << "template\n"; emitCPPType(method.getReturnType(), os); @@ -384,6 +389,9 @@ method.isStatic() ? &ctx : &nonStaticMethodFmt); os << "\n}\n"; } + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitTraitDecl(const Interface &interface, @@ -498,8 +506,6 @@ emitTraitDecl(interface, interfaceName, interfaceTraitsName); os << "}// namespace detail\n"; - emitModelMethodsDef(interface); - for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } @@ -507,8 +513,10 @@ bool InterfaceGenerator::emitInterfaceDecls() { llvm::emitSourceFileHeader("Interface Declarations", os); - for (const auto *def : defs) + for (const llvm::Record *def : defs) emitInterfaceDecl(Interface(def)); + for (const llvm::Record *def : defs) + emitModelMethodsDef(Interface(def)); return false; }