diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1288,7 +1288,7 @@ std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const py::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1367,6 +1367,7 @@ if (!mlirOperands.empty()) mlirOperationStateAddOperands(&state, mlirOperands.size(), mlirOperands.data()); + state.enableResultTypeInference = inferType; if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), mlirResults.data()); @@ -1398,6 +1399,8 @@ // Construct the operation. MlirOperation operation = mlirOperationCreate(&state); + if (!operation.ptr) + throw py::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1441,13 +1444,12 @@ // PyOpView //------------------------------------------------------------------------------ -py::object -PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, - py::list operandList, std::optional attributes, - std::optional> successors, - std::optional regions, - DefaultingPyLocation location, - const py::object &maybeIp) { +py::object PyOpView::buildGeneric( + const py::object &cls, std::optional resultTypeList, + py::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const py::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. std::string name = py::cast(cls.attr("OPERATION_NAME")); @@ -1486,49 +1488,15 @@ // Unpack results. std::vector resultTypes; - resultTypes.reserve(resultTypeList.size()); - if (resultSegmentSpecObj.is_none()) { - // Non-variadic result unpacking. - for (const auto &it : llvm::enumerate(resultTypeList)) { - try { - resultTypes.push_back(py::cast(it.value())); - if (!resultTypes.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Type (" + err.what() + ")") - .str()); - } - } - } else { - // Sized result unpacking. - auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); - if (resultSegmentSpec.size() != resultTypeList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + - "\" requires " + - llvm::Twine(resultSegmentSpec.size()) + - " result segments but was provided " + - llvm::Twine(resultTypeList.size())) - .str()); - } - resultSegmentLengths.reserve(resultTypeList.size()); - for (const auto &it : - llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { - int segmentSpec = std::get<1>(it.value()); - if (segmentSpec == 1 || segmentSpec == 0) { - // Unpack unary element. + if (resultTypeList.has_value()) { + resultTypes.reserve(resultTypeList->size()); + if (resultSegmentSpecObj.is_none()) { + // Non-variadic result unpacking. + for (const auto &it : llvm::enumerate(resultTypeList.value())) { try { - auto *resultType = py::cast(std::get<0>(it.value())); - if (resultType) { - resultTypes.push_back(resultType); - resultSegmentLengths.push_back(1); - } else if (segmentSpec == 0) { - // Allowed to be optional. - resultSegmentLengths.push_back(0); - } else { - throw py::cast_error("was None and result is not optional"); - } + resultTypes.push_back(py::cast(it.value())); + if (!resultTypes.back()) + throw py::cast_error(); } catch (py::cast_error &err) { throw py::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + @@ -1536,35 +1504,73 @@ ")") .str()); } - } else if (segmentSpec == -1) { - // Unpack sequence by appending. - try { - if (std::get<0>(it.value()).is_none()) { - // Treat it as an empty list. - resultSegmentLengths.push_back(0); - } else { - // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - resultTypes.push_back(py::cast(segmentItem)); - if (!resultTypes.back()) { - throw py::cast_error("contained a None item"); + } + } else { + // Sized result unpacking. + auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + if (resultSegmentSpec.size() != resultTypeList->size()) { + throw py::value_error((llvm::Twine("Operation \"") + name + + "\" requires " + + llvm::Twine(resultSegmentSpec.size()) + + " result segments but was provided " + + llvm::Twine(resultTypeList->size())) + .str()); + } + resultSegmentLengths.reserve(resultTypeList->size()); + for (const auto &it : llvm::enumerate( + llvm::zip(resultTypeList.value(), resultSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1 || segmentSpec == 0) { + // Unpack unary element. + try { + auto *resultType = py::cast(std::get<0>(it.value())); + if (resultType) { + resultTypes.push_back(resultType); + resultSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + resultSegmentLengths.push_back(0); + } else { + throw py::cast_error("was None and result is not optional"); + } + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (" + err.what() + ")") + .str()); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + if (std::get<0>(it.value()).is_none()) { + // Treat it as an empty list. + resultSegmentLengths.push_back(0); + } else { + // Unpack the list. + auto segment = py::cast(std::get<0>(it.value())); + for (py::object segmentItem : segment) { + resultTypes.push_back(py::cast(segmentItem)); + if (!resultTypes.back()) { + throw py::cast_error("contained a None item"); + } } + resultSegmentLengths.push_back(segment.size()); } - resultSegmentLengths.push_back(segment.size()); + } catch (std::exception &err) { + // NOTE: Sloppy to be using a catch-all here, but there are at least + // three different unrelated exceptions that can be thrown in the + // above "casts". Just keep the scope above small and catch them + // all. + throw py::value_error( + (llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Sequence of Types (" + err.what() + ")") + .str()); } - } catch (std::exception &err) { - // NOTE: Sloppy to be using a catch-all here, but there are at least - // three different unrelated exceptions that can be thrown in the - // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Result ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Sequence of Types (" + - err.what() + ")") - .str()); + } else { + throw py::value_error("Unexpected segment spec"); } - } else { - throw py::value_error("Unexpected segment spec"); } } } @@ -1694,7 +1700,8 @@ /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), - /*regions=*/*regions, location, maybeIp); + /*regions=*/*regions, location, maybeIp, + !resultTypeList); } pybind11::object PyOpView::constructDerived(const pybind11::object &cls, @@ -2854,7 +2861,7 @@ py::arg("attributes") = py::none(), py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) + py::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -655,7 +655,8 @@ std::optional> operands, std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip); + DefaultingPyLocation location, const pybind11::object &ip, + bool inferType = false); /// Creates an OpView suitable for this operation. pybind11::object createOpView(); @@ -704,13 +705,12 @@ pybind11::object getOperationObject() { return operationObject; } - static pybind11::object - buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, - pybind11::list operandList, - std::optional attributes, - std::optional> successors, - std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + static pybind11::object buildGeneric( + const pybind11::object &cls, std::optional resultTypeList, + pybind11::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const pybind11::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -245,14 +245,10 @@ // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { // CHECK: def __init__(self, *, loc=None, ip=None): - // CHECK: operands = [] - // CHECK: results = [] // CHECK: _ods_context = _ods_get_default_loc_context(loc) - // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes( - // CHECK: operands=operands, - // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - // CHECK: context=_ods_context, - // CHECK: loc=loc) + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) let results = (outs I32:$i32, F32:$f32); } @@ -260,13 +256,9 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { // CHECK: def __init__(self, *, loc=None, ip=None): // CHECK: operands = [] - // CHECK: results = [] - // CHECK: _ods_context = _ods_get_default_loc_context(loc) - // CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes( - // CHECK: operands=operands, - // CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - // CHECK: context=_ods_context, - // CHECK: loc=loc) + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) let results = (outs AnyType, AnyType, AnyType); } diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -494,7 +494,7 @@ regions = None {1} super().__init__(self.build_generic( - attributes=attributes, results=results, operands=operands, + attributes=attributes,{2} operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) )Py"; @@ -755,17 +755,6 @@ /// Python code template appending {0} type {1} times to the results list. constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; -/// Python code template for inferring the operation results using the -/// corresponding interface: -/// - {0} is the name of the class for which the types are inferred. -constexpr const char *inferTypeInterfaceTemplate = - R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes( - operands=operands, - attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context), - context=_ods_context, - loc=loc) -)PY"; - /// Appends the given multiline string as individual strings into /// `builderLines`. static void appendLineByLine(StringRef string, @@ -805,12 +794,8 @@ return; } - if (hasInferTypeInterface(op)) { - appendLineByLine( - llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), - builderLines); + if (hasInferTypeInterface(op)) return; - } // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { @@ -935,7 +920,8 @@ functionArgs.push_back("loc=None"); functionArgs.push_back("ip=None"); os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), - llvm::join(builderLines, "\n ")); + llvm::join(builderLines, "\n "), + hasInferTypeInterface(op) ? "" : " results=results,"); } static void emitSegmentSpec(