diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -189,17 +189,19 @@ //===----------------------------------------------------------------------===// void InterfaceGenerator::emitConceptDecl(Interface &interface) { - os << " class Concept {\n" - << " public:\n" - << " virtual ~Concept() = default;\n"; + os << " struct Concept {\n"; // Insert each of the pure virtual concept methods. for (auto &method : interface.getMethods()) { - os << " virtual "; + os << " "; emitCPPType(method.getReturnType(), os); - emitMethodNameAndArgs(method, os, valueType, - /*addThisArg=*/!method.isStatic(), /*addConst=*/true); - os << " = 0;\n"; + os << "(*" << method.getName() << ")("; + if (!method.isStatic()) + emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); + llvm::interleaveComma( + method.getArguments(), os, + [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); + os << ");\n"; } os << " };\n"; } @@ -207,13 +209,19 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { os << " template\n"; os << " class Model : public Concept {\n public:\n"; + os << " Model() : Concept{"; + llvm::interleaveComma( + interface.getMethods(), os, + [&](const InterfaceMethod &method) { os << method.getName(); }); + os << "} {}\n\n"; // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " "); + emitCPPType(method.getReturnType(), os << " static "); emitMethodNameAndArgs(method, os, valueType, - /*addThisArg=*/!method.isStatic(), /*addConst=*/true); - os << " final {\n "; + /*addThisArg=*/!method.isStatic(), + /*addConst=*/false); + os << " {\n "; // Check for a provided body to the function. if (Optional body = method.getBody()) {