diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -36,7 +36,7 @@
       which an Operation would be created (e.g., as used in Operation::create)
       and the regions of the op.
       }],
-      /*retTy=*/"LogicalResult",
+      /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"inferReturnTypes",
       /*args=*/(ins "::mlir::MLIRContext *":$context,
                     "::llvm::Optional<::mlir::Location>":$location,
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -290,11 +290,16 @@
   // Generates the traits used by the object.
   void genTraits();
 
-  // Generate the OpInterface methods.
+  // Generate the OpInterface methods for all interfaces.
   void genOpInterfaceMethods();
 
-  // Generate op interface method.
-  void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait);
+  // Generate op interface methods for the given interface.
+  void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+
+  // Generate op interface method for the given interface method. If
+  // 'declaration' is true, generates a declaration, else a definition.
+  OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
+                                 bool declaration = true);
 
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
@@ -1588,7 +1593,7 @@
   }
 }
 
-void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
   auto interface = opTrait->getOpInterface();
 
   // Get the set of methods that should always be declared.
@@ -1606,23 +1611,29 @@
     if (method.getDefaultImplementation() &&
         !alwaysDeclaredMethods.count(method.getName()))
       continue;
-
-    SmallVector<OpMethodParameter, 4> paramList;
-    for (const InterfaceMethod::Argument &arg : method.getArguments())
-      paramList.emplace_back(arg.type, arg.name);
-
-    auto properties = method.isStatic() ? OpMethod::MP_StaticDeclaration
-                                        : OpMethod::MP_Declaration;
-    opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
-                              properties, std::move(paramList));
+    genOpInterfaceMethod(method);
   }
 }
 
+OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
+                                          bool declaration) {
+  SmallVector<OpMethodParameter, 4> paramList;
+  for (const InterfaceMethod::Argument &arg : method.getArguments())
+    paramList.emplace_back(arg.type, arg.name);
+
+  auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+  if (declaration)
+    properties =
+        static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
+  return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
+                                   properties, std::move(paramList));
+}
+
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
       if (opTrait->shouldDeclareMethods())
-        genOpInterfaceMethod(opTrait);
+        genOpInterfaceMethods(opTrait);
   }
 }
 
@@ -1727,18 +1738,21 @@
 void OpEmitter::genTypeInterfaceMethods() {
   if (!op.allResultTypesKnown())
     return;
-
-  SmallVector<OpMethodParameter, 4> paramList;
-  paramList.emplace_back("::mlir::MLIRContext *", "context");
-  paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location");
-  paramList.emplace_back("::mlir::ValueRange", "operands");
-  paramList.emplace_back("::mlir::DictionaryAttr", "attributes");
-  paramList.emplace_back("::mlir::RegionRange", "regions");
-  paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&",
-                         "inferredReturnTypes");
-  auto *method =
-      opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes",
-                                OpMethod::MP_Static, std::move(paramList));
+  // Generate the 'inferReturnTypes' method definition, taking its signature
+  // from the method declared in the 'InferTypeOpInterface' op interface. The
+  // op must carry that interface's trait; `cast` asserts if it does not.
+  const auto *trait = cast<tblgen::InterfaceOpTrait>(
+      op.getTrait("::mlir::InferTypeOpInterface::Trait"));
+  auto interface = trait->getOpInterface();
+  OpMethod *method = nullptr;
+  for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
+    if (interfaceMethod.getName() == "inferReturnTypes") {
+      method = genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
+      break;
+    }
+  }
+  assert(method && "failed to generate 'inferReturnTypes' from the "
+                   "InferTypeOpInterface method declaration");
 
   auto &body = method->body();
   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";