diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2315,11 +2315,11 @@
   let results = (outs AnyType:$result);
 }
 
-// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
-// Tests suppression of ambiguous build methods for operations with
-// SameOperandsAndResultType and InferTypeOpInterface.
-def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
-    [SameOperandsAndResultType, InferTypeOpInterface]> {
+// Base class for testing `build` methods for ops with InferTypeOpInterface.
+// `traits` are concatenated onto InferTypeOpInterface, which is always added.
+class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
+                                         list<OpTrait> traits = []>
+    : TEST_Op<mnemonic, [InferTypeOpInterface] # traits> {
   let arguments = (ins Variadic<AnyType>:$inputs);
   let results = (outs AnyType:$result);
 
@@ -2334,6 +2334,18 @@
    }];
 }
 
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
+    "tblgen_build_5", [SameOperandsAndResultType]>;
+
+// Op with InferTypeOpInterface (inherited from the base class) and a region.
+def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
+    "tblgen_build_6"> {
+  let regions = (region AnyRegion:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // Test BufferPlacement
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1220,8 +1220,7 @@
 }
 
 static bool canInferType(Operator &op) {
-  return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
-         op.getNumRegions() == 0;
+  return op.getTrait("::mlir::InferTypeOpInterface::Trait");
 }
 
 void OpEmitter::genSeparateArgParamBuilder() {
@@ -1304,7 +1303,7 @@
   // ambiguous function detection will elide those ones.
   for (auto attrType : attrBuilderType) {
     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType(op))
+    if (canInferType(op) && op.getNumRegions() == 0)
       emit(attrType, TypeParamKind::None, /*inferType=*/true);
     emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
@@ -1392,21 +1391,22 @@
 
   // Result types
   body << formatv(R"(
-    ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
-    if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
-                  {1}.location, operands,
-                  {1}.attributes.getDictionary({1}.getContext()),
-                  /*regions=*/{{}, inferredReturnTypes))) {{)",
+  ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
+  if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
+          {1}.location, operands,
+          {1}.attributes.getDictionary({1}.getContext()),
+          {1}.regions, inferredReturnTypes))) {{)",
                   opClass.getClassName(), builderOpState);
   if (numVariadicResults == 0 || numNonVariadicResults != 0)
-    body << "  assert(inferredReturnTypes.size()"
+    body << "\n    assert(inferredReturnTypes.size()"
          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
-         << "u && \"mismatched number of return types\");\n";
-  body << "  " << builderOpState << ".addTypes(inferredReturnTypes);";
+         << "u && \"mismatched number of return types\");";
+  body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";
 
   body << formatv(R"(
-    } else
-      ::llvm::report_fatal_error("Failed to infer result type(s).");)",
+  } else {{
+    ::llvm::report_fatal_error("Failed to infer result type(s).");
+  })",
                   opClass.getClassName(), builderOpState);
 }
 
@@ -1606,7 +1606,7 @@
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
   // Generate builder that infers type too.
-  // TODO: Expand to handle regions and successors.
+  // TODO: Expand to handle successors.
   if (canInferType(op) && op.getNumSuccessors() == 0)
     genInferredTypeCollectiveParamBuilder();
 }
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -219,4 +219,11 @@
   testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
 }
 
+TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
+  auto op = builder.create<test::TableGenBuildOp6>(
+      loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
+  ASSERT_EQ(op->getNumRegions(), 1u);
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstF32}, noAttrs);
+}
+
 } // namespace mlir