diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1511,7 +1511,6 @@ // Generate builder that infers type too. // TODO: Subsume this with general checking if type can be // inferred automatically. - // TODO: Expand to handle regions. body << formatv(R"( ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), @@ -1588,7 +1587,7 @@ // ambiguous function detection will elide those ones. for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); - if (canInferType(op) && op.getNumRegions() == 0) + if (canInferType(op)) emit(attrType, TypeParamKind::None, /*inferType=*/true); emit(attrType, TypeParamKind::Collective, /*inferType=*/false); }