diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -28,6 +28,15 @@ InterfaceMethod<"Prints the type name.", "void", "printTypeC", (ins "Location":$loc) >, + // It should be possible to use the interface type name as result type + // as well as in the implementation. + InterfaceMethod<"Prints the type name and returns the type as interface.", + "TestTypeInterface", "printTypeRet", (ins "Location":$loc), + [{}], /*defaultImplementation=*/[{ + emitRemark(loc) << $_type << " - TestRet"; + return $_type; + }] + >, ]; let extraClassDeclaration = [{ /// Prints the type name. diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp --- a/mlir/test/lib/IR/TestInterfaces.cpp +++ b/mlir/test/lib/IR/TestInterfaces.cpp @@ -25,6 +25,10 @@ testInterface.printTypeB(op->getLoc()); testInterface.printTypeC(op->getLoc()); testInterface.printTypeD(op->getLoc()); + // Just check that we can assign the result to a variable of interface + // type. + TestTypeInterface result = testInterface.printTypeRet(op->getLoc()); + (void)result; } if (auto testType = type.dyn_cast()) testType.printTypeE(op->getLoc()); diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir --- a/mlir/test/mlir-tblgen/interfaces.mlir +++ b/mlir/test/mlir-tblgen/interfaces.mlir @@ -4,6 +4,7 @@ // expected-remark@below {{'!test.test_type' - TestB}} // expected-remark@below {{'!test.test_type' - TestC}} // expected-remark@below {{'!test.test_type' - TestD}} +// expected-remark@below {{'!test.test_type' - TestRet}} // expected-remark@below {{'!test.test_type' - TestE}} %foo0 = "foo.test"() : () -> (!test.test_type) diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -41,9 +41,11 @@ // DECL-LABEL: TestOpInterfaceInterfaceTraits // DECL: class TestOpInterface : public ::mlir::OpInterface + // DECL: int foo(int input); -// DECL-NOT: TestOpInterface +// DECL: template +// DECL: int detail::TestOpInterfaceInterfaceTraits::Model::foo // OP_DECL-LABEL: class DeclareMethodsOp : public // OP_DECL: int foo(int input); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -82,6 +82,7 @@ void emitConceptDecl(Interface &interface); void emitModelDecl(Interface &interface); + void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName); void emitInterfaceDecl(Interface interface); @@ -217,11 +218,25 @@ // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " static "); + emitCPPType(method.getReturnType(), os << " static inline "); emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); - os << " {\n "; + os << ";\n"; + } + os << " };\n"; +} + +void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { + for (auto &method : interface.getMethods()) { + os << "template\n"; + emitCPPType(method.getReturnType(), os); + os << "detail::" << interface.getName() << "InterfaceTraits::Model<" + << valueTemplate << ">::"; + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), + /*addConst=*/false); + os << " {\n "; // Check for a provided body to the function. if (Optional body = method.getBody()) { @@ -229,7 +244,7 @@ os << body->trim(); else os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); - os << "\n }\n"; + os << "\n}\n"; continue; } @@ -244,9 +259,8 @@ llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); - os << ");\n }\n"; + os << ");\n}\n"; } - os << " };\n"; } void InterfaceGenerator::emitTraitDecl(Interface &interface, @@ -308,6 +322,10 @@ StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); + // Emit a forward declaration of the interface class so that it becomes usable + // in the signature of its methods. + os << "class " << interfaceName << ";\n"; + // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" << "struct " << interfaceTraitsName << " {\n"; @@ -340,6 +358,8 @@ os << "};\n"; + emitModelMethodsDef(interface); + for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; }