diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2152,10 +2152,6 @@
           },
           "Returns the source location the operation was defined or derived "
           "from.")
-      .def("__iter__",
-           [](PyOperationBase &self) {
-             return PyRegionIterator(self.getOperation().getRef());
-           })
       .def(
           "__str__",
           [](PyOperationBase &self) {
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py
@@ -195,8 +195,15 @@
           # Coerce return values, add ReturnOp and rewrite func type.
           if return_values is None:
             return_values = []
           elif isinstance(return_values, Value):
+            # Returning a single value is fine, coerce it into a list.
             return_values = [return_values]
+          elif isinstance(return_values, OpView):
+            # Returning a single op is fine, return all of its results.
+            return_values = return_values.operation.results
+          elif isinstance(return_values, Operation):
+            # Returning a single op is fine, return all of its results.
+            return_values = return_values.results
           else:
             return_values = list(return_values)
           std.ReturnOp(return_values)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -124,7 +124,8 @@
 
 
 def get_op_result_or_value(
-    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
+    arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value,
+                _cext.ir.OpResultList]
 ) -> _cext.ir.Value:
   """Returns the given value or the single result of the given op.
 
@@ -136,6 +137,12 @@
     return arg.operation.result
   elif isinstance(arg, _cext.ir.Operation):
     return arg.result
+  elif isinstance(arg, _cext.ir.OpResultList):
+    # Like `Operation.result`, only a list holding exactly one result is a
+    # single value; silently taking arg[0] would drop the other results.
+    if len(arg) != 1:
+      raise ValueError(f"Expected a single result, got {len(arg)} results")
+    return arg[0]
   else:
     assert isinstance(arg, _cext.ir.Value)
     return arg
diff --git a/mlir/test/python/dialects/builtin.py b/mlir/test/python/dialects/builtin.py
--- a/mlir/test/python/dialects/builtin.py
+++ b/mlir/test/python/dialects/builtin.py
@@ -15,6 +15,7 @@
 @run
 def testFromPyFunc():
   with Context() as ctx, Location.unknown() as loc:
+    ctx.allow_unregistered_dialects = True
     m = builtin.ModuleOp()
     f32 = F32Type.get()
     f64 = F64Type.get()
@@ -51,6 +52,14 @@
       def call_binary(a, b):
         return binary_return(a, b)
 
+      # We expect coercion of a single result operation to a returned value.
+      # CHECK-LABEL: func @single_result_op
+      # CHECK: %0 = "custom.op1"() : () -> f32
+      # CHECK: return %0 : f32
+      @builtin.FuncOp.from_py_func()
+      def single_result_op():
+        return Operation.create("custom.op1", results=[f32])
+
       # CHECK-LABEL: func @call_none
       # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
       # CHECK: return
diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py
--- a/mlir/test/python/dialects/math.py
+++ b/mlir/test/python/dialects/math.py
@@ -19,7 +19,7 @@
     return mlir_math.SqrtOp(arg)
 
   # CHECK-LABEL: func @emit_sqrt(
-  # CHECK-SAME: %[[ARG:.*]]: f32) {
+  # CHECK-SAME: %[[ARG:.*]]: f32) -> f32 {
   # CHECK: math.sqrt %[[ARG]] : f32
   # CHECK: return
   # CHECK: }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -40,7 +40,7 @@
   print(f".verify = {module.operation.verify()}")
 
   # Get the regions and blocks from the default collections.
-  default_regions = list(op)
+  default_regions = list(op.regions)
   default_blocks = list(default_regions[0])
   # They should compare equal regardless of how obtained.
   assert default_regions == regions
@@ -53,7 +53,7 @@
   assert default_operations == operations
 
   def walk_operations(indent, op):
-    for i, region in enumerate(op):
+    for i, region in enumerate(op.regions):
       print(f"{indent}REGION {i}:")
       for j, block in enumerate(region):
         print(f"{indent}  BLOCK {j}:")