diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -425,7 +425,8 @@
                  pybind11::object parentKeepAlive = pybind11::object());
 
   /// Gets the backing operation.
-  MlirOperation get() {
+  operator MlirOperation() const { return get(); }
+  MlirOperation get() const {
     checkValid();
     return operation;
   }
@@ -440,7 +441,7 @@
     assert(!attached && "operation already attached");
     attached = true;
   }
-  void checkValid();
+  void checkValid() const;
 
   /// Gets the owning block or raises an exception if the operation has no
   /// owning block.
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -23,6 +23,8 @@
 using namespace mlir::python;
 
 using llvm::SmallVector;
+using llvm::StringRef;
+using llvm::Twine;
 
 //------------------------------------------------------------------------------
 // Docstrings (trivial, non-duplicated docstrings are included inline).
@@ -631,7 +633,7 @@
       getContext()->get(), {canonKey->data(), canonKey->size()});
   if (mlirDialectIsNull(dialect)) {
     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
-                     llvm::Twine("Dialect '") + key + "' not found");
+                     Twine("Dialect '") + key + "' not found");
   }
   return dialect;
 }
@@ -793,7 +795,7 @@
   return created;
 }
 
-void PyOperation::checkValid() {
+void PyOperation::checkValid() const {
   if (!valid) {
     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
   }
@@ -817,7 +819,7 @@
 
   PyFileAccumulator accum(fileObject, binary);
   py::gil_scoped_release();
-  mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(),
+  mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
                               accum.getUserData());
   mlirOpPrintingFlagsDestroy(flags);
 }
@@ -975,7 +977,7 @@
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
-      llvm::StringRef(identStr.data, identStr.length));
+      StringRef(identStr.data, identStr.length));
   if (opViewClass)
     return (*opViewClass)(getRef().getObject());
   return py::cast(PyOpView(getRef().getObject()));
@@ -1044,7 +1046,7 @@
     (*refOperation)->checkValid();
     beforeOp = (*refOperation)->get();
   }
-  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get());
+  mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
   operation.setAttached();
 }
 
@@ -1158,7 +1160,7 @@
   static MlirValue castFrom(PyValue &orig) {
     if (!DerivedTy::isaFunction(orig.get())) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast value to ") +
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
                                              DerivedTy::pyClassName +
                                              " (from " + origRepr + ")");
     }
@@ -1416,9 +1418,9 @@
   static MlirAttribute castFrom(PyAttribute &orig) {
     if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError,
-                       llvm::Twine("Cannot cast attribute to ") +
-                           DerivedTy::pyClassName + " (from " + origRepr + ")");
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
+                                             DerivedTy::pyClassName +
+                                             " (from " + origRepr + ")");
     }
     return orig;
   }
@@ -1449,7 +1451,7 @@
           // in C API.
           if (mlirAttributeIsNull(attr)) {
             throw SetPyError(PyExc_ValueError,
-                             llvm::Twine("invalid '") +
+                             Twine("invalid '") +
                                  py::repr(py::cast(type)).cast<std::string>() +
                                  "' and expected floating point type.");
           }
@@ -1943,7 +1945,7 @@
   static MlirType castFrom(PyType &orig) {
     if (!DerivedTy::isaFunction(orig)) {
       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
-      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+      throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
                                              DerivedTy::pyClassName +
                                              " (from " + origRepr + ")");
     }
@@ -2142,7 +2144,7 @@
           }
           throw SetPyError(
               PyExc_ValueError,
-              llvm::Twine("invalid '") +
+              Twine("invalid '") +
                   py::repr(py::cast(elementType)).cast<std::string>() +
                   "' and expected floating point or integer type.");
         },
@@ -2247,7 +2249,7 @@
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point or integer type.");
           }
@@ -2278,7 +2280,7 @@
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2309,7 +2311,7 @@
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2344,7 +2346,7 @@
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2390,7 +2392,7 @@
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
-                llvm::Twine("invalid '") +
+                Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
@@ -2544,7 +2546,7 @@
                 self.get(), {name.data(), name.size()});
             if (mlirDialectIsNull(dialect)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Dialect '") + name + "' not found");
+                               Twine("Dialect '") + name + "' not found");
             }
             return PyDialectDescriptor(self.getRef(), dialect);
           },
@@ -2763,6 +2765,26 @@
             return PyOpResultList(self.getOperation().getRef());
           },
           "Returns the list of Operation results.")
+      .def_property_readonly(
+          "result",
+          [](PyOperationBase &self) {
+            auto &operation = self.getOperation();
+            auto numResults = mlirOperationGetNumResults(operation);
+            if (numResults != 1) {
+              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+              throw SetPyError(
+                  PyExc_ValueError,
+                  Twine("Cannot call .result on operation ") +
+                      StringRef(name.data, name.length) + " which has " +
+                      Twine(numResults) +
+                      " results (it is only valid for operations with a "
+                      "single result)");
+            }
+            return PyOpResult(operation.getRef(),
+                              mlirOperationGetResult(operation, 0));
+          },
+          "Shortcut to get an op result if it has only one (throws an error "
+          "otherwise).")
       .def("__iter__",
            [](PyOperationBase &self) {
              return PyRegionIterator(self.getOperation().getRef());
@@ -2931,7 +2953,7 @@
             // in C API.
             if (mlirAttributeIsNull(type)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse attribute: '") +
+                               Twine("Unable to parse attribute: '") +
                                    attrSpec + "'");
             }
             return PyAttribute(context->getRef(), type);
@@ -3042,8 +3064,8 @@
             // in C API.
             if (mlirTypeIsNull(type)) {
               throw SetPyError(PyExc_ValueError,
-                               llvm::Twine("Unable to parse type: '") +
-                                   typeSpec + "'");
+                               Twine("Unable to parse type: '") + typeSpec +
+                                   "'");
             }
             return PyType(context->getRef(), type);
           },
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -474,6 +474,7 @@
 
 run(testOperationPrint)
 
+# CHECK-LABEL: TEST: testKnownOpView
 def testKnownOpView():
   with Context(), Location.unknown():
     Context.current.allow_unregistered_dialects = True
@@ -503,3 +504,37 @@
     print(repr(custom))
 
 run(testKnownOpView)
+
+
+# CHECK-LABEL: TEST: testSingleResultProperty
+def testSingleResultProperty():
+  with Context(), Location.unknown():
+    Context.current.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      "custom.no_result"() : () -> ()
+      %0:2 = "custom.two_result"() : () -> (f32, f32)
+      %1 = "custom.one_result"() : () -> f32
+    """)
+    print(module)
+
+  try:
+    module.body.operations[0].result
+  except ValueError as e:
+    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+  try:
+    module.body.operations[1].result
+  except ValueError as e:
+    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
+    print(e)
+  else:
+    assert False, "Expected exception"
+
+  # .result only succeeds on the single-result op; exercise that path too.
+  # CHECK: %1 = "custom.one_result"() : () -> f32
+  print(module.body.operations[2].result)
+
+run(testSingleResultProperty)