diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -425,7 +425,8 @@ pybind11::object parentKeepAlive = pybind11::object()); /// Gets the backing operation. - MlirOperation get() { + operator MlirOperation() const { return get(); } + MlirOperation get() const { checkValid(); return operation; } @@ -440,7 +441,7 @@ assert(!attached && "operation already attached"); attached = true; } - void checkValid(); + void checkValid() const; /// Gets the owning block or raises an exception if the operation has no /// owning block. diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -793,7 +793,7 @@ return created; } -void PyOperation::checkValid() { +void PyOperation::checkValid() const { if (!valid) { throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); } @@ -817,7 +817,7 @@ PyFileAccumulator accum(fileObject, binary); py::gil_scoped_release(); - mlirOperationPrintWithFlags(operation.get(), flags, accum.getCallback(), + mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), accum.getUserData()); mlirOpPrintingFlagsDestroy(flags); } @@ -1044,7 +1044,7 @@ (*refOperation)->checkValid(); beforeOp = (*refOperation)->get(); } - mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation.get()); + mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); operation.setAttached(); } @@ -2763,6 +2763,20 @@ return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") + .def_property_readonly( + "result", + [](PyOperationBase &self) { + auto &operation = self.getOperation(); + if (mlirOperationGetNumResults(operation) != 1) { + throw SetPyError( + PyExc_ValueError, + "The .result property can only be used for operations that " + "have 
one result (otherwise, use the 'results' collection)."); + } + return PyOpResult(operation.getRef(), + mlirOperationGetResult(operation, 0)); + }, + "Shortcut to get an op result if it has only one (error otherwise).") .def("__iter__", [](PyOperationBase &self) { return PyRegionIterator(self.getOperation().getRef()); diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -474,6 +474,7 @@ run(testOperationPrint) +# CHECK-LABEL: TEST: testKnownOpView def testKnownOpView(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -503,3 +504,36 @@ print(repr(custom)) run(testKnownOpView) + + +# CHECK-LABEL: TEST: testSingleResultProperty +def testSingleResultProperty(): + with Context(), Location.unknown(): + Context.current.allow_unregistered_dialects = True + module = Module.parse(r""" + "custom.no_result"() : () -> () + %0:2 = "custom.two_result"() : () -> (f32, f32) + %1 = "custom.one_result"() : () -> f32 + """) + print(module) + + try: + module.body.operations[0].result + except ValueError as e: + # CHECK: The .result property can only be used for operations that have one result (otherwise, use the 'results' collection). + print(e) + else: + assert False, "Expected exception" + + try: + module.body.operations[1].result + except ValueError as e: + # CHECK: The .result property can only be used for operations that have one result (otherwise, use the 'results' collection). + print(e) + else: + assert False, "Expected exception" + + # CHECK: %1 = "custom.one_result"() : () -> f32 + print(module.body.operations[2]) + +run(testSingleResultProperty)