diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -439,8 +439,9 @@ #### Builders Presently, only a single, default builder is mapped to the `__init__` method. -Generalizing this facility is under active development. It currently accepts -arguments: +The intent is that this `__init__` method represents the *most specific* of +the builders typically generated for C++; however, currently it is just the +generic form below. * One argument for each declared result: * For single-valued results: Each will accept an `mlir.ir.Type`. @@ -453,7 +454,11 @@ * `loc`: An explicit `mlir.ir.Location` to use. Defaults to the location bound to the thread (i.e. `with Location.unknown():`) or an error if none is bound nor specified. - * `context`: An explicit `mlir.ir.Context` to use. Default to the context - bound to the thread (i.e. `with Context():` or implicitly via `Location` or - `InsertionPoint` context managers) or an error if none is bound nor - specified. + * `ip`: An explicit `mlir.ir.InsertionPoint` to use. Defaults to the insertion + point bound to the thread (i.e. `with InsertionPoint(...):`). + +In addition, each `OpView` inherits a `build_generic` method which allows +construction via a (nested in the case of variadic) sequence of `results` and +`operands`. This can be used to get some default construction semantics for +operations that are otherwise unsupported in Python, at the expense of having +a very generic signature. 
diff --git a/mlir/examples/python/.style.yapf b/mlir/examples/python/.style.yapf deleted file mode 100644 --- a/mlir/examples/python/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py deleted file mode 100644 --- a/mlir/examples/python/linalg_matmul.py +++ /dev/null @@ -1,81 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# This is a work in progress example to do end2end build and code generation -# of a small linalg program with configuration options. It is currently non -# functional and is being used to elaborate the APIs. - -from typing import Tuple - -from mlir.ir import * -from mlir.dialects import linalg -from mlir.dialects import std - - -# TODO: This should be in the core API. -def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: - """Creates a |func| op. - TODO: This should really be in the MLIR API. - Returns: - (operation, entry_block) - """ - attrs = { - "type": TypeAttr.get(func_type), - "sym_name": StringAttr.get(name), - } - op = Operation.create("func", regions=1, attributes=attrs) - body_region = op.regions[0] - entry_block = body_region.blocks.append(*func_type.inputs) - return op, entry_block - - -def build_matmul_buffers_func(func_name, m, k, n, dtype): - lhs_type = MemRefType.get([m, k], dtype) - rhs_type = MemRefType.get([k, n], dtype) - result_type = MemRefType.get([m, n], dtype) - # TODO: There should be a one-liner for this. 
- func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) - _, entry = FuncOp(func_name, func_type) - lhs, rhs, result = entry.arguments - with InsertionPoint(entry): - op = linalg.MatmulOp([lhs, rhs], [result]) - # TODO: Implement support for SingleBlockImplicitTerminator - block = op.regions[0].blocks.append() - with InsertionPoint(block): - linalg.YieldOp(values=[]) - - std.ReturnOp([]) - - -def build_matmul_tensors_func(func_name, m, k, n, dtype): - lhs_type = RankedTensorType.get([m, k], dtype) - rhs_type = RankedTensorType.get([k, n], dtype) - result_type = RankedTensorType.get([m, n], dtype) - # TODO: There should be a one-liner for this. - func_type = FunctionType.get([lhs_type, rhs_type], [result_type]) - _, entry = FuncOp(func_name, func_type) - lhs, rhs = entry.arguments - with InsertionPoint(entry): - op = linalg.MatmulOp([lhs, rhs], results=[result_type]) - # TODO: Implement support for SingleBlockImplicitTerminator - block = op.regions[0].blocks.append() - with InsertionPoint(block): - linalg.YieldOp(values=[]) - std.ReturnOp([op.result]) - - -def run(): - with Context() as c, Location.unknown(): - module = Module.create() - # TODO: This at_block_terminator vs default construct distinction feels - # wrong and is error-prone. - with InsertionPoint.at_block_terminator(module.body): - build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get()) - build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get()) - - print(module) - - -if __name__ == '__main__': - run() diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -455,8 +455,8 @@ /// Creates an operation. See corresponding python docstring. 
static pybind11::object - create(std::string name, llvm::Optional> operands, - llvm::Optional> results, + create(std::string name, llvm::Optional> results, + llvm::Optional> operands, llvm::Optional attributes, llvm::Optional> successors, int regions, DefaultingPyLocation location, pybind11::object ip); @@ -498,12 +498,12 @@ pybind11::object getOperationObject() { return operationObject; } static pybind11::object - odsBuildDefault(pybind11::object cls, pybind11::list operandList, - pybind11::list resultTypeList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, DefaultingPyLocation location, - pybind11::object maybeIp); + buildGeneric(pybind11::object cls, pybind11::list resultTypeList, + pybind11::list operandList, + llvm::Optional attributes, + llvm::Optional> successors, + llvm::Optional regions, DefaultingPyLocation location, + pybind11::object maybeIp); private: PyOperation &operation; // For efficient, cast-free access from C++ diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -891,8 +891,8 @@ } py::object PyOperation::create( - std::string name, llvm::Optional> operands, - llvm::Optional> results, + std::string name, llvm::Optional> results, + llvm::Optional> operands, llvm::Optional attributes, llvm::Optional> successors, int regions, DefaultingPyLocation location, py::object maybeIp) { @@ -1039,12 +1039,12 @@ //------------------------------------------------------------------------------ py::object -PyOpView::odsBuildDefault(py::object cls, py::list operandList, - py::list resultTypeList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, - DefaultingPyLocation location, py::object maybeIp) { +PyOpView::buildGeneric(py::object cls, py::list resultTypeList, + py::list operandList, + llvm::Optional attributes, + llvm::Optional> successors, + llvm::Optional 
regions, + DefaultingPyLocation location, py::object maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. std::string name = py::cast(cls.attr("OPERATION_NAME")); @@ -1288,8 +1288,9 @@ } // Delegate to create. - return PyOperation::create(std::move(name), /*operands=*/std::move(operands), + return PyOperation::create(std::move(name), /*results=*/std::move(resultTypes), + /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), /*regions=*/*regions, location, maybeIp); @@ -1357,6 +1358,16 @@ // Insert before operation. (*refOperation)->checkValid(); beforeOp = (*refOperation)->get(); + } else { + // Insert at end (before null) is only valid if the block does not + // already end in a known terminator (violating this will cause assertion + // failures later). + if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { + throw py::index_error("Cannot insert operation at the end of a block " + "that already has a terminator. 
Did you mean to " + "use 'InsertionPoint.at_block_terminator(block)' " + "versus 'InsertionPoint(block)'?"); + } } mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); operation.setAttached(); @@ -3646,8 +3657,8 @@ py::class_(m, "Operation") .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("operands") = py::none(), py::arg("results") = py::none(), + py::arg("operands") = py::none(), py::arg("attributes") = py::none(), py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), @@ -3681,12 +3692,11 @@ opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); - opViewClass.attr("_ods_build_default") = classmethod( - &PyOpView::odsBuildDefault, py::arg("cls"), - py::arg("operands") = py::none(), py::arg("results") = py::none(), - py::arg("attributes") = py::none(), py::arg("successors") = py::none(), - py::arg("regions") = py::none(), py::arg("loc") = py::none(), - py::arg("ip") = py::none(), + opViewClass.attr("build_generic") = classmethod( + &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), + py::arg("operands") = py::none(), py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = py::none(), + py::arg("loc") = py::none(), py::arg("ip") = py::none(), "Builds a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py @@ -7,9 +7,9 @@ class ModuleOp: """Specialization for the module op class.""" - def __init__(self, loc=None, ip=None): + def __init__(self, *, loc=None, ip=None): 
super().__init__( - self._ods_build_default(operands=[], results=[], loc=loc, ip=ip)) + self.build_generic(results=[], operands=[], loc=loc, ip=ip)) body = self.regions[0].blocks.append() with InsertionPoint(body): Operation.create("module_terminator") @@ -25,7 +25,8 @@ def __init__(self, name, type, - visibility, + *, + visibility=None, body_builder=None, loc=None, ip=None): @@ -34,8 +35,8 @@ - `name` is a string representing the function name. - `type` is either a FunctionType or a pair of list describing inputs and results. - - `visibility` is a string matching `public`, `private`, or `nested`. The - empty string implies a private visibility. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. - `body_builder` is an optional callback, when provided a new entry block is created and the callback is invoked with the new op as argument within an InsertionPoint context already set for the block. The callback is @@ -50,7 +51,7 @@ type = TypeAttr.get(type) sym_visibility = StringAttr.get( str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc, ip) + super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) if body_builder: entry_block = self.add_entry_block() with InsertionPoint(entry_block): diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py @@ -11,11 +11,10 @@ raise ValueError( "Structured ops must have outputs or results, but not both.") super().__init__( - self._ods_build_default(operands=[list(inputs), - list(outputs)], - results=list(results), - loc=loc, - ip=ip)) + self.build_generic(results=list(results), + operands=[list(inputs), list(outputs)], + loc=loc, + ip=ip)) def select_opview_mixin(parent_opview_cls): diff --git a/mlir/test/Bindings/Python/dialects/linalg.py 
b/mlir/test/Bindings/Python/dialects/linalg.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/dialects/linalg.py @@ -0,0 +1,57 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import builtin +from mlir.dialects import linalg +from mlir.dialects import std + + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testStructuredOpOnTensors +def testStructuredOpOnTensors(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + tensor_type = RankedTensorType.get((2, 3, 4), f32) + with InsertionPoint.at_block_terminator(module.body): + func = builtin.FuncOp(name="matmul_test", + type=FunctionType.get( + inputs=[tensor_type, tensor_type], + results=[tensor_type])) + with InsertionPoint(func.add_entry_block()): + lhs, rhs = func.entry_block.arguments + result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result + std.ReturnOp([result]) + + # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + print(module) + + +run(testStructuredOpOnTensors) + + +# CHECK-LABEL: TEST: testStructuredOpOnBuffers +def testStructuredOpOnBuffers(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + memref_type = MemRefType.get((2, 3, 4), f32) + with InsertionPoint.at_block_terminator(module.body): + func = builtin.FuncOp(name="matmul_test", + type=FunctionType.get( + inputs=[memref_type, memref_type, memref_type], + results=[])) + with InsertionPoint(func.add_entry_block()): + lhs, rhs, result = func.entry_block.arguments + linalg.MatmulOp([lhs, rhs], outputs=[result]) + std.ReturnOp([]) + + # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>) + print(module) + + +run(testStructuredOpOnBuffers) diff --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py --- 
a/mlir/test/Bindings/Python/insertion_point.py +++ b/mlir/test/Bindings/Python/insertion_point.py @@ -125,6 +125,21 @@ run(test_insert_at_block_terminator_missing) +# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors +def test_insert_at_end_with_terminator_errors(): + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() # Module is created with a terminator. + with InsertionPoint(m.body): + try: + Operation.create("custom.op1", results=[], operands=[]) + except IndexError as e: + # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator. + print(f"ERROR: {e}") + +run(test_insert_at_end_with_terminator_errors) + + # CHECK-LABEL: TEST: test_insertion_point_context def test_insertion_point_context(): ctx = Context() diff --git a/mlir/test/Bindings/Python/ods_helpers.py b/mlir/test/Bindings/Python/ods_helpers.py --- a/mlir/test/Bindings/Python/ods_helpers.py +++ b/mlir/test/Bindings/Python/ods_helpers.py @@ -30,43 +30,43 @@ ctx.allow_unregistered_dialects = True m = Module.create() with InsertionPoint.at_block_terminator(m.body): - op = TestFixedRegionsOp._ods_build_default(operands=[], results=[]) + op = TestFixedRegionsOp.build_generic(results=[], operands=[]) # CHECK: NUM_REGIONS: 2 print(f"NUM_REGIONS: {len(op.regions)}") # Including a regions= that matches should be fine. - op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=2) + op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2) print(f"NUM_REGIONS: {len(op.regions)}") # Reject greater than. try: - op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=3) + op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3) except ValueError as e: # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3 print(f"ERROR:{e}") # Reject less than. 
try: - op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=1) + op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1) except ValueError as e: # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1 print(f"ERROR:{e}") # If no regions specified for a variadic region op, build the minimum. - op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[]) + op = TestVariadicRegionsOp.build_generic(results=[], operands=[]) # CHECK: DEFAULT_NUM_REGIONS: 2 print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}") # Should also accept an explicit regions= that matches the minimum. - op = TestVariadicRegionsOp._ods_build_default( - operands=[], results=[], regions=2) + op = TestVariadicRegionsOp.build_generic( + results=[], operands=[], regions=2) # CHECK: EQ_NUM_REGIONS: 2 print(f"EQ_NUM_REGIONS: {len(op.regions)}") # And accept greater than minimum. # Should also accept an explicit regions= that matches the minimum. - op = TestVariadicRegionsOp._ods_build_default( - operands=[], results=[], regions=3) + op = TestVariadicRegionsOp.build_generic( + results=[], operands=[], regions=3) # CHECK: GT_NUM_REGIONS: 3 print(f"GT_NUM_REGIONS: {len(op.regions)}") # Should reject less than minimum. 
try: - op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[], regions=1) + op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1) except ValueError as e: # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1 print(f"ERROR:{e}") @@ -89,7 +89,7 @@ v1 = add_dummy_value() t0 = IntegerType.get_signless(8) t1 = IntegerType.get_signless(16) - op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1]) + op = TestOp.build_generic(results=[t0, t1], operands=[v0, v1]) # CHECK: %[[V0:.+]] = "custom.value" # CHECK: %[[V1:.+]] = "custom.value" # CHECK: "custom.test_op"(%[[V0]], %[[V1]]) @@ -128,50 +128,50 @@ # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64> # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64> # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64) - op = TestOp._ods_build_default( - operands=[v0, [v1, v2], v3], - results=[[t0, t1], t2, t3]) + op = TestOp.build_generic( + results=[[t0, t1], t2, t3], + operands=[v0, [v1, v2], v3]) # Now test with optional omitted. # CHECK: "custom.test_op"(%[[V0]]) # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]> # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]> # CHECK-SAME: (i32) -> i64 - op = TestOp._ods_build_default( - operands=[v0, None, None], - results=[None, None, t3]) + op = TestOp.build_generic( + results=[None, None, t3], + operands=[v0, None, None]) print(m) # And verify that errors are raised for None in a required operand. try: - op = TestOp._ods_build_default( - operands=[None, None, None], - results=[None, None, t3]) + op = TestOp.build_generic( + results=[None, None, t3], + operands=[None, None, None]) except ValueError as e: # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional) print(f"OPERAND_CAST_ERROR:{e}") # And verify that errors are raised for None in a required result. 
try: - op = TestOp._ods_build_default( - operands=[v0, None, None], - results=[None, None, None]) + op = TestOp.build_generic( + results=[None, None, None], + operands=[v0, None, None]) except ValueError as e: # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional) print(f"RESULT_CAST_ERROR:{e}") # Variadic lists with None elements should reject. try: - op = TestOp._ods_build_default( - operands=[v0, [None], None], - results=[None, None, t3]) + op = TestOp.build_generic( + results=[None, None, t3], + operands=[v0, [None], None]) except ValueError as e: # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item) print(f"OPERAND_LIST_CAST_ERROR:{e}") try: - op = TestOp._ods_build_default( - operands=[v0, None, None], - results=[[None], None, t3]) + op = TestOp.build_generic( + results=[[None], None, t3], + operands=[v0, None, None]) except ValueError as e: # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item) print(f"RESULT_LIST_CAST_ERROR:{e}") @@ -193,16 +193,16 @@ t0 = IntegerType.get_signless(8) t1 = IntegerType.get_signless(16) try: - op = TestOp._ods_build_default( - operands=[None, v1], - results=[t0, t1]) + op = TestOp.build_generic( + results=[t0, t1], + operands=[None, v1]) except ValueError as e: # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value print(f"ERROR: {e}") try: - op = TestOp._ods_build_default( - operands=[v0, v1], - results=[t0, None]) + op = TestOp.build_generic( + results=[t0, None], + operands=[v0, v1]) except ValueError as e: # CHECK: Result 1 of operation "custom.test_op" must be a Type print(f"ERROR: {e}") diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -20,15 +20,15 
@@ // CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { - // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(variadic1) // CHECK: operands.append(non_variadic) // CHECK: if variadic2 is not None: operands.append(variadic2) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -61,15 +61,15 @@ // CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { - // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): + // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: if variadic1 is not None: results.append(variadic1) // CHECK: results.append(non_variadic) // CHECK: if variadic2 is not None: results.append(variadic2) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -103,7 +103,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOp : TestOp<"attributed_op"> { - // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None): + // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, *, loc=None, ip=None): // 
CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -112,8 +112,8 @@ // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -143,7 +143,7 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { - // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None): + // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -152,8 +152,8 @@ // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -173,19 +173,19 @@ // CHECK: class EmptyOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.empty" def EmptyOp : TestOp<"empty">; - // CHECK: def __init__(self, loc=None, ip=None): + // CHECK: def __init__(self, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: 
loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.missing_names" def MissingNamesOp : TestOp<"missing_names"> { - // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, loc=None, ip=None): + // CHECK: def __init__(self, i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -195,8 +195,8 @@ // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(f32) // CHECK: operands.append(_gen_arg_2) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -220,14 +220,14 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { - // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None): + // CHECK: def __init__(self, non_variadic, variadic, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(non_variadic) // CHECK: operands.extend(variadic) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -247,14 +247,14 @@ // CHECK-NOT: _ODS_OPERAND_SEGMENTS // CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicResultOp : TestOp<"one_variadic_result"> { - // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None): + // CHECK: def __init__(self, variadic, non_variadic, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: 
results = [] // CHECK: attributes = {} // CHECK: results.extend(variadic) // CHECK: results.append(non_variadic) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -273,13 +273,13 @@ // CHECK: class PythonKeywordOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.python_keyword" def PythonKeywordOp : TestOp<"python_keyword"> { - // CHECK: def __init__(self, in_, loc=None, ip=None): + // CHECK: def __init__(self, in_, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(in_) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -338,7 +338,7 @@ // CHECK: class SimpleOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.simple" def SimpleOp : TestOp<"simple"> { - // CHECK: def __init__(self, i64, f64, i32, f32, loc=None, ip=None): + // CHECK: def __init__(self, i64, f64, i32, f32, *, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} @@ -346,8 +346,8 @@ // CHECK: results.append(f64) // CHECK: operands.append(i32) // CHECK: operands.append(f32) - // CHECK: super().__init__(self._ods_build_default( - // CHECK: attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, // CHECK: loc=loc, ip=ip)) // CHECK: @property diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- 
a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -108,9 +108,8 @@ /// {3} is the position of the current group in the group list. constexpr const char *opOneOptionalTemplate = R"Py( @property - def {0}(self); - return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} - else None + def {0}(self): + return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None )Py"; /// Template for the variadic group accessor in the single variadic group case: @@ -277,7 +276,7 @@ static bool isODSReserved(StringRef str) { static llvm::StringSet<> reserved( {"attributes", "create", "context", "ip", "operands", "print", "get_asm", - "loc", "verify", "regions", "result", "results", "self", "operation", + "loc", "verify", "regions", "results", "self", "operation", "DIALECT_NAMESPACE", "OPERATION_NAME"}); return str.startswith("_ods_") || str.endswith("_ods") || reserved.contains(str); @@ -481,8 +480,8 @@ results = [] attributes = {{} {1} - super().__init__(self._ods_build_default( - attributes=attributes, operands=operands, results=results, + super().__init__(self.build_generic( + attributes=attributes, results=results, operands=operands, loc=loc, ip=ip)) )Py"; @@ -528,8 +527,15 @@ llvm::SmallVectorImpl &operandNames) { for (int i = 0, e = op.getNumResults(); i < e; ++i) { std::string name = op.getResultName(i).str(); - if (name.empty()) - name = llvm::formatv("_gen_res_{0}", i); + if (name.empty()) { + if (op.getNumResults() == 1) { + // Special case for one result, make the default name be 'result' + // to properly match the built-in result accessor. 
+ name = "result"; + } else { + name = llvm::formatv("_gen_res_{0}", i); + } + } name = sanitizeName(name); builderArgs.push_back(name); } @@ -637,6 +643,7 @@ op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), builderLines); + builderArgs.push_back("*"); builderArgs.push_back("loc=None"); builderArgs.push_back("ip=None"); os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "),