diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -78,6 +78,7 @@ ip: An InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager). + infer_type: Whether to infer result types. Returns: A new "detached" Operation object. Detached operations can be added to blocks, which causes them to become "attached." @@ -1288,7 +1289,7 @@ std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const py::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1367,6 +1368,7 @@ if (!mlirOperands.empty()) mlirOperationStateAddOperands(&state, mlirOperands.size(), mlirOperands.data()); + state.enableResultTypeInference = inferType; if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), mlirResults.data()); @@ -1398,6 +1400,8 @@ // Construct the operation. MlirOperation operation = mlirOperationCreate(&state); + if (!operation.ptr) + throw py::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1441,51 +1445,10 @@ // PyOpView //------------------------------------------------------------------------------ -py::object -PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, - py::list operandList, std::optional attributes, - std::optional> successors, - std::optional regions, - DefaultingPyLocation location, - const py::object &maybeIp) { - PyMlirContextRef context = location->getContext(); - // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); - // Operand and result segment specs are either none, which does no - // variadic unpacking, or a list of ints with segment sizes, where each - // element is either a positive number (typically 1 for a scalar) or -1 to - // indicate that it is derived from the length of the same-indexed operand - // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - - std::vector operandSegmentLengths; - std::vector resultSegmentLengths; - - // Validate/determine region count. - auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); - int opMinRegionCount = std::get<0>(opRegionSpec); - bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); - if (!regions) { - regions = opMinRegionCount; - } - if (*regions < opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - - // Unpack results. - std::vector resultTypes; +static void populateResultTypes(StringRef name, py::list resultTypeList, + const py::object &resultSegmentSpecObj, + std::vector &resultSegmentLengths, + std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); if (resultSegmentSpecObj.is_none()) { // Non-variadic result unpacking. @@ -1556,7 +1519,8 @@ } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the - // above "casts". Just keep the scope above small and catch them all. + // above "casts". Just keep the scope above small and catch them + // all. throw py::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Types (" + @@ -1568,6 +1532,56 @@ } } } +} + +py::object PyOpView::buildGeneric( + const py::object &cls, std::optional resultTypeList, + py::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const py::object &maybeIp) { + PyMlirContextRef context = location->getContext(); + // Class level operation construction metadata. + std::string name = py::cast(cls.attr("OPERATION_NAME")); + // Operand and result segment specs are either none, which does no + // variadic unpacking, or a list of ints with segment sizes, where each + // element is either a positive number (typically 1 for a scalar) or -1 to + // indicate that it is derived from the length of the same-indexed operand + // or result (implying that it is a list at that position). + py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; + + // Validate/determine region count. + auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + int opMinRegionCount = std::get<0>(opRegionSpec); + bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); + if (!regions) { + regions = opMinRegionCount; + } + if (*regions < opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + if (opHasNoVariadicRegions && *regions > opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + + // Unpack results. + std::vector resultTypes; + if (resultTypeList.has_value()) { + populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, + resultSegmentLengths, resultTypes); + } // Unpack operands. std::vector operands; @@ -1694,7 +1708,8 @@ /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), - /*regions=*/*regions, location, maybeIp); + /*regions=*/*regions, location, maybeIp, + !resultTypeList); } pybind11::object PyOpView::constructDerived(const pybind11::object &cls, @@ -2854,7 +2869,7 @@ py::arg("attributes") = py::none(), py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) + py::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -655,7 +655,8 @@ std::optional> operands, std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip); + DefaultingPyLocation location, const pybind11::object &ip, + bool inferType = false); /// Creates an OpView suitable for this operation. pybind11::object createOpView(); @@ -704,13 +705,12 @@ pybind11::object getOperationObject() { return operationObject; } - static pybind11::object - buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, - pybind11::list operandList, - std::optional attributes, - std::optional> successors, - std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + static pybind11::object buildGeneric( + const pybind11::object &cls, std::optional resultTypeList, + pybind11::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const pybind11::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -245,14 +245,10 @@ // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { // CHECK: def __init__(self, *, loc=None, ip=None): - // CHECK: operands = [] - // CHECK: results = [] // CHECK: _ods_context = _ods_get_default_loc_context(loc) - // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes( - // CHECK: operands=operands, - // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - // CHECK: context=_ods_context, - // CHECK: loc=loc) + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) let results = (outs I32:$i32, F32:$f32); } @@ -260,13 +256,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { // CHECK: def __init__(self, *, loc=None, ip=None): // CHECK: operands = [] - // CHECK: results = [] - // CHECK: _ods_context = _ods_get_default_loc_context(loc) - // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes( - // CHECK: operands=operands, - // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - // CHECK: context=_ods_context, - // CHECK: loc=loc) + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) let results = (outs AnyType, AnyType, AnyType); } diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -494,7 +494,7 @@ regions = None {1} super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, + attributes=attributes,{2} operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) )Py"; @@ -755,17 +755,6 @@ /// Python code template appending {0} type {1} times to the results list. constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; -/// Python code template for inferring the operation results using the -/// corresponding interface: -/// - {0} is the name of the class for which the types are inferred. -constexpr const char *inferTypeInterfaceTemplate = - R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( - operands=operands, - attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - context=_ods_context, - loc=loc) -)PY"; - /// Appends the given multiline string as individual strings into /// `builderLines`. static void appendLineByLine(StringRef string, @@ -805,12 +794,8 @@ return; } - if (hasInferTypeInterface(op)) { - appendLineByLine( - llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), - builderLines); + if (hasInferTypeInterface(op)) return; - } // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { @@ -935,7 +920,8 @@ functionArgs.push_back("loc=None"); functionArgs.push_back("ip=None"); os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), - llvm::join(builderLines, "\n ")); + llvm::join(builderLines, "\n "), + hasInferTypeInterface(op) ? "" : " results=results,"); } static void emitSegmentSpec(