diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -31,7 +31,7 @@ // CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(), // CHECK: odsState.location, odsState.operands, // CHECK: odsState.attributes.getDictionary(odsState.getContext()), -// CHECK: /*regions=*/{}, inferredReturnTypes))) +// CHECK: odsState.regions, inferredReturnTypes))) // CHECK: odsState.addTypes(inferredReturnTypes); def OpC : NS_Op<"three_normal_result_op", []> { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1583,13 +1583,12 @@ // Generate builder that infers type too. // TODO: Subsume this with general checking if type can be // inferred automatically. - // TODO: Expand to handle regions. body << formatv(R"( ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), {1}.location, {1}.operands, {1}.attributes.getDictionary({1}.getContext()), - /*regions=*/{{}, inferredReturnTypes))) + {1}.regions, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else ::llvm::report_fatal_error("Failed to infer result type(s).");)", @@ -1660,7 +1659,7 @@ // ambiguous function detection will elide those ones. for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); - if (canInferType(op) && op.getNumRegions() == 0) + if (canInferType(op)) emit(attrType, TypeParamKind::None, /*inferType=*/true); emit(attrType, TypeParamKind::Collective, /*inferType=*/false); }