diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py new file mode 100644 --- /dev/null +++ b/mlir/examples/python/linalg_matmul.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This is a work in progress example to do end2end build and code generation +# of a small linalg program with configuration options. It is currently non +# functional and is being used to elaborate the APIs. + +from typing import Tuple + +from mlir.ir import * +from mlir.dialects import linalg +from mlir.dialects import std + + +# TODO: This should be in the core API. +def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: + """Creates a |func| op. + TODO: This should really be in the MLIR API. + Returns: + (operation, entry_block) + """ + attrs = { + "type": TypeAttr.get(func_type), + "sym_name": StringAttr.get(name), + } + op = Operation.create("func", regions=1, attributes=attrs) + body_region = op.regions[0] + entry_block = body_region.blocks.append(*func_type.inputs) + return op, entry_block + + +# TODO: Generate custom builder vs patching one in. +def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None): + super(linalg.MatmulOp, self).__init__( + self._ods_build_default(operands=[[lhs, rhs], [result]], + results=[], + loc=loc, + ip=ip)) + + +linalg.MatmulOp.__init__ = PatchMatmulOpInit + + +def build_matmul_func(func_name, m, k, n, dtype): + lhs_type = MemRefType.get(dtype, [m, k]) + rhs_type = MemRefType.get(dtype, [k, n]) + result_type = MemRefType.get(dtype, [m, n]) + # TODO: There should be a one-liner for this.
+ func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) + _, entry = FuncOp(func_name, func_type) + lhs, rhs, result = entry.arguments + with InsertionPoint(entry): + linalg.MatmulOp(lhs, rhs, result) + std.ReturnOp([]) + + +def run(): + with Context() as c, Location.unknown(): + module = Module.create() + # TODO: This at_block_terminator vs default construct distinction feels + # wrong and is error-prone. + with InsertionPoint.at_block_terminator(module.body): + build_matmul_func('main', 18, 32, 96, F32Type.get()) + + print(module) + + +if __name__ == '__main__': run() diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -497,6 +497,12 @@ pybind11::object getOperationObject() { return operationObject; } + static pybind11::object odsBuildDefault( + pybind11::object cls, pybind11::list operandList, + pybind11::list resultTypeList, llvm::Optional attributes, + llvm::Optional> successors, int regions, + DefaultingPyLocation location, pybind11::object maybeIp); + private: PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -130,6 +130,13 @@ // Utilities. //------------------------------------------------------------------------------ +// Helper for creating an @classmethod; takes ownership of the new reference. +template +py::object classmethod(Func f, Args... args) { + py::object cf = py::cpp_function(f, args...); + return py::reinterpret_steal((PyClassMethod_New(cf.ptr()))); +} + /// Checks whether the given type is an integer or float type.
static int mlirTypeIsAIntegerOrFloat(MlirType type) { return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || @@ -1027,6 +1034,202 @@ return py::cast(PyOpView(getRef().getObject())); } +//------------------------------------------------------------------------------ +// PyOpView +//------------------------------------------------------------------------------ + +py::object PyOpView::odsBuildDefault( + py::object cls, py::list operandList, py::list resultTypeList, + llvm::Optional attributes, + llvm::Optional> successors, int regions, + DefaultingPyLocation location, py::object maybeIp) { + auto context = location->getContext(); + // Class level operation construction metadata. + std::string name = py::cast(cls.attr("OPERATION_NAME")); + // Operand and result segment specs are either none, which does no + // variadic unpacking, or a list of ints with segment sizes, where each + // element is either 1 (a single, non-sequence element) or -1 to indicate + // that it is derived from the length of the same-indexed operand or + // result (implying that it is a list at that position). + py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; + + // Use static op-level region count. + if (regions < 0) { + regions = py::cast(cls.attr("_ODS_REGIONS")); + } + + // Unpack results. + std::vector resultTypes; + resultTypes.reserve(resultTypeList.size()); + if (resultSegmentSpecObj.is_none()) { + // Non-variadic result unpacking.
+ for (auto it : llvm::enumerate(resultTypeList)) { + try { + resultTypes.push_back(py::cast(it.value())); + if (!resultTypes.back()) + throw py::cast_error(); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + "\" must be a Type (" + + err.what() + ")"); + } + } + } else { + // Variadic result unpacking. + auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + if (resultSegmentSpec.size() != resultTypeList.size()) { + throw SetPyError(PyExc_RuntimeError, + "Unexpected mismatch in result and segment arity"); + } + resultSegmentLengths.reserve(resultTypeList.size()); + for (auto it : + llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1) { + // Unpack unary element. + try { + resultTypes.push_back(py::cast(std::get<0>(it.value()))); + if (!resultTypes.back()) + throw py::cast_error(); + resultSegmentLengths.push_back(std::get<1>(it.value())); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (" + err.what() + ")"); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + auto segment = + py::cast>(std::get<0>(it.value())); + // A None inside the sequence casts to a null pointer; reject it. + if (llvm::is_contained(segment, nullptr)) + throw py::cast_error(); + resultTypes.insert(resultTypes.end(), segment.begin(), segment.end()); + resultSegmentLengths.push_back(segment.size()); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Sequence of Types (" + err.what() + + ")"); + } + } else { + throw SetPyError(PyExc_RuntimeError, "Unexpected segment spec"); + } + } + } + + // Unpack operands. + std::vector operands; + operands.reserve(operandList.size()); + if (operandSegmentSpecObj.is_none()) { + // Non-variadic operand unpacking.
+ for (auto it : llvm::enumerate(operandList)) { + try { + operands.push_back(py::cast(it.value())); + if (!operands.back()) + throw py::cast_error(); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (" + err.what() + ")"); + } + } + } else { + // Variadic operand unpacking. + auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); + if (operandSegmentSpec.size() != operandList.size()) { + throw SetPyError(PyExc_RuntimeError, + "Unexpected mismatch in operand and segment arity"); + } + operandSegmentLengths.reserve(operandList.size()); + for (auto it : + llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1) { + // Unpack unary element. + try { + operands.push_back(py::cast(std::get<0>(it.value()))); + if (!operands.back()) + throw py::cast_error(); + operandSegmentLengths.push_back(std::get<1>(it.value())); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (" + err.what() + ")"); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + auto segment = + py::cast>(std::get<0>(it.value())); + // A None inside the sequence casts to a null pointer; reject it. + if (llvm::is_contained(segment, nullptr)) + throw py::cast_error(); + operands.insert(operands.end(), segment.begin(), segment.end()); + operandSegmentLengths.push_back(segment.size()); + } catch (py::cast_error &err) { + throw SetPyError(PyExc_ValueError, + llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Sequence of Values (" + + err.what() + ")"); + } + } else { + throw SetPyError(PyExc_RuntimeError, "Unexpected segment spec"); + } + } + } + + // Merge operand/result segment lengths into attributes if needed. + if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { + // Copy so that the caller's dict is not mutated.
+ if (attributes) { + attributes = py::dict(*attributes); + } else { + attributes = py::dict(); + } + + // Add result_segment_sizes attribute. + if (!resultSegmentLengths.empty()) { + int64_t size = resultSegmentLengths.size(); + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( + mlirRankedTensorTypeGet(1, &size, + mlirIntegerTypeGet(context->get(), 64)), + resultSegmentLengths.size(), resultSegmentLengths.data()); + (*attributes)["result_segment_sizes"] = + PyAttribute(context, segmentLengthAttr); + } + + // Add operand_segment_sizes attribute. + if (!operandSegmentLengths.empty()) { + int64_t size = operandSegmentLengths.size(); + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( + mlirRankedTensorTypeGet(1, &size, + mlirIntegerTypeGet(context->get(), 64)), + operandSegmentLengths.size(), operandSegmentLengths.data()); + (*attributes)["operand_segment_sizes"] = + PyAttribute(context, segmentLengthAttr); + } + } + + // Delegate to create. + return PyOperation::create(std::move(name), /*operands=*/std::move(operands), + /*results=*/std::move(resultTypes), + /*attributes=*/std::move(attributes), + /*successors=*/std::move(successors), + /*regions=*/regions, location, maybeIp); +} + PyOpView::PyOpView(py::object operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass.
@@ -3397,17 +3600,29 @@ "Context that owns the Operation") .def_property_readonly("opview", &PyOperation::createOpView); - py::class_(m, "OpView") - .def(py::init()) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly( - "context", - [](PyOpView &self) { - return self.getOperation().getContext().getObject(); - }, - "Context that owns the Operation") - .def("__str__", - [](PyOpView &self) { return py::str(self.getOperationObject()); }); + auto OpViewClass = + py::class_(m, "OpView") + .def(py::init()) + .def_property_readonly("operation", &PyOpView::getOperationObject) + .def_property_readonly( + "context", + [](PyOpView &self) { + return self.getOperation().getContext().getObject(); + }, + "Context that owns the Operation") + .def("__str__", [](PyOpView &self) { + return py::str(self.getOperationObject()); + }); + OpViewClass.attr("_ODS_REGIONS") = 0; + OpViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); + OpViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); + OpViewClass.attr("_ods_build_default") = classmethod( + &PyOpView::odsBuildDefault, py::arg("cls"), + py::arg("operands") = py::list(), py::arg("results") = py::list(), + py::arg("attributes") = py::none(), py::arg("successors") = py::none(), + py::arg("regions") = -1, py::arg("loc") = py::none(), + py::arg("ip") = py::none(), + "Builds a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- // Mapping of PyRegion.
diff --git a/mlir/test/Bindings/Python/ods_helpers.py b/mlir/test/Bindings/Python/ods_helpers.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ods_helpers.py @@ -0,0 +1,158 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +def add_dummy_value(): + return Operation.create( + "custom.value", + results=[IntegerType.get_signless(32)]).result + + +def testOdsBuildDefaultImplicitRegions(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_REGIONS = 2 + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + op = TestOp._ods_build_default(operands=[], results=[]) + # CHECK: NUM_REGIONS: 2 + print(f"NUM_REGIONS: {len(op.regions)}") + +run(testOdsBuildDefaultImplicitRegions) + + +def testOdsBuildDefaultNonVariadic(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1]) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]]) + # CHECK-NOT: operand_segment_sizes + # CHECK-NOT: result_segment_sizes + # CHECK-SAME: : (i32, i32) -> (i8, i16) + print(m) + +run(testOdsBuildDefaultNonVariadic) + + +def testOdsBuildDefaultSizedVariadic(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_OPERAND_SEGMENTS = [1, -1, 1] + _ODS_RESULT_SEGMENTS = [-1, 1, 1] + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with
InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + v2 = add_dummy_value() + v3 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + t2 = IntegerType.get_signless(32) + t3 = IntegerType.get_signless(64) + op = TestOp._ods_build_default( + operands=[v0, [v1, v2], v3], + results=[[t0, t1], t2, t3]) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: %[[V2:.+]] = "custom.value" + # CHECK: %[[V3:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) + # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : tensor<3xi64> + # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : tensor<3xi64> + # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64) + print(m) + + +run(testOdsBuildDefaultSizedVariadic) + + +def testOdsBuildDefaultSizedVariadicCastError(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_OPERAND_SEGMENTS = [1, -1] + _ODS_RESULT_SEGMENTS = [-1, 1] + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + try: + op = TestOp._ods_build_default( + operands=[v0, v1], + results=[[t0], t1]) + except ValueError as e: + # CHECK: ERROR: Operand 1 of operation "custom.test_op" must be a Sequence of Values + print(f"ERROR: {e}") + try: + op = TestOp._ods_build_default( + operands=[v0, [v1]], + results=[t0, t1]) + except ValueError as e: + # CHECK: ERROR: Result 0 of operation "custom.test_op" must be a Sequence of Types + print(f"ERROR: {e}") + +run(testOdsBuildDefaultSizedVariadicCastError) + + +def testOdsBuildDefaultCastError(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects
= True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + try: + op = TestOp._ods_build_default( + operands=[None, v1], + results=[t0, t1]) + except ValueError as e: + # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value + print(f"ERROR: {e}") + try: + op = TestOp._ods_build_default( + operands=[v0, v1], + results=[t0, None]) + except ValueError as e: + # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type + print(f"ERROR: {e}") + +run(testOdsBuildDefaultCastError) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -17,23 +17,18 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedOperandsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands" +// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operand_segment_sizes_ods = _ods_array.array('L') - // CHECK: operands += [*variadic1] - // CHECK: operand_segment_sizes_ods.append(len(variadic1)) + // CHECK: operands.append(variadic1) // CHECK: operands.append(non_variadic) - // CHECK: operand_segment_sizes_ods.append(1) - // CHECK: if variadic2 is not None: operands.append(variadic2) + // CHECK: operands.append([variadic2] if variadic2 is not None else []) - // CHECK: operand_segment_sizes_ods.append(0 if variadic2 is None else 1) - // CHECK: attributes["operand_segment_sizes"] = _ods_ir.DenseElementsAttr.get(operand_segment_sizes_ods, - // CHECK: context=_ods_get_default_loc_context(loc)) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK:
"test.attr_sized_operands", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -63,23 +58,18 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" +// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: result_segment_sizes_ods = _ods_array.array('L') - // CHECK: if variadic1 is not None: results.append(variadic1) + // CHECK: results.append([variadic1] if variadic1 is not None else []) - // CHECK: result_segment_sizes_ods.append(0 if variadic1 is None else 1) // CHECK: results.append(non_variadic) - // CHECK: result_segment_sizes_ods.append(1) # non_variadic - // CHECK: if variadic2 is not None: results.append(variadic2) + // CHECK: results.append([variadic2] if variadic2 is not None else []) - // CHECK: result_segment_sizes_ods.append(0 if variadic2 is None else 1) - // CHECK: attributes["result_segment_sizes"] = _ods_ir.DenseElementsAttr.get(result_segment_sizes_ods, - // CHECK: context=_ods_get_default_loc_context(loc)) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -110,6 +100,8 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attributed_op" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOp : TestOp<"attributed_op"> { // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr,
in_, loc=None, ip=None): // CHECK: operands = [] @@ -120,8 +112,8 @@ // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -148,6 +140,8 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None): // CHECK: operands = [] @@ -158,8 +152,8 @@ // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -183,8 +177,8 @@ // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.empty", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -201,8 +195,8 @@ // CHECK: operands.append(_gen_arg_0)
// CHECK: operands.append(f32) // CHECK: operands.append(_gen_arg_2) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -223,15 +217,17 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(non_variadic) - // CHECK: operands += [*variadic] - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results, + // CHECK: operands.extend(variadic) + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -248,15 +244,17 @@ // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicResultOp : TestOp<"one_variadic_result"> { // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: results += [*variadic] + // CHECK: results.extend(variadic) // CHECK: results.append(non_variadic) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.one_variadic_result", attributes=attributes,
operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -280,8 +278,8 @@ // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(in_) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -348,8 +346,8 @@ // CHECK: results.append(f64) // CHECK: operands.append(i32) // CHECK: operands.append(f32) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.simple", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -26,7 +26,6 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -import array as _ods_array from . import _cext as _ods_cext from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context _ods_ir = _ods_cext.ir @@ -51,6 +50,18 @@ OPERATION_NAME = "{1}" )Py"; +/// Template for class level declarations of operand and result +/// segment specs.
+/// {0} is either "OPERAND" or "RESULT" +/// {1} is the segment spec +/// Each segment spec is either None (default) or an array of integers +/// where: +/// 1 = single element (expect non sequence operand/result) +/// -1 = operand/result is a sequence (variadic, or optional as a 0/1 list) +constexpr const char *opClassSizedSegmentsTemplate = R"Py( + _ODS_{0}_SEGMENTS = {1} +)Py"; + /// Template for single-element accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; @@ -446,18 +457,17 @@ } /// Template for the default auto-generated builder. -/// {0} is the operation name; -/// {1} is a comma-separated list of builder arguments, including the trailing +/// {0} is a comma-separated list of builder arguments, including the trailing /// `loc` and `ip`; -/// {2} is the code populating `operands`, `results` and `attributes` fields. +/// {1} is the code populating `operands`, `results` and `attributes` fields. constexpr const char *initTemplate = R"Py( - def __init__(self, {1}): + def __init__(self, {0}): operands = [] results = [] attributes = {{} - {2} - super().__init__(_ods_ir.Operation.create( - "{0}", attributes=attributes, operands=operands, results=results, + {1} + super().__init__(self._ods_build_default( + attributes=attributes, operands=operands, results=results, loc=loc, ip=ip)) )Py"; @@ -472,37 +482,17 @@ constexpr const char *optionalAppendTemplate = "if {1} is not None: {0}s.append({1})"; -/// Template for appending a variadic element to the operand/result list. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]"; - -/// Template for setting up the segment sizes buffer. -constexpr const char *segmentDeclarationTemplate = - "{0}_segment_sizes_ods = _ods_array.array('L')"; - -/// Template for attaching segment sizes to the attribute list.
-constexpr const char *segmentAttributeTemplate = - R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods, - context=_ods_get_default_loc_context(loc)))Py"; - -/// Template for appending the unit size to the segment sizes. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *singleElementSegmentTemplate = - "{0}_segment_sizes_ods.append(1) # {1}"; - -/// Template for appending 0/1 for an optional element to the segment sizes. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *optionalSegmentTemplate = - "{0}_segment_sizes_ods.append(0 if {1} is None else 1)"; - -/// Template for appending the length of a variadic group to the segment sizes. +/// Template for appending a list of elements to the operand/result list. /// {0} is either 'operand' or 'result'; /// {1} is the field name. -constexpr const char *variadicSegmentTemplate = - "{0}_segment_sizes_ods.append(len({1}))"; +constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})"; + +/// Template for appending an optional element as a list of zero or one +/// elements, used when the op has attr-sized segments. +/// {0} is either 'operand' or 'result'; +/// {1} is the field name. +constexpr const char *optionalSegmentAppendTemplate = + "{0}s.append([{1}] if {1} is not None else [])"; /// Template for setting an attribute in the operation builder. /// {0} is the attribute name; @@ -584,11 +574,7 @@ llvm::function_ref getNumElements, llvm::function_ref getElement) { - // The segment sizes buffer only has to be populated if there attr-sized - // segments trait is present. - bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; - if (includeSegments) - builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind)); + bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; // For each element, find or generate a name. 
for (int i = 0, e = getNumElements(op); i < e; ++i) { @@ -596,28 +582,30 @@ std::string name = names[i]; // Choose the formatting string based on the element kind.
- llvm::StringRef formatString, segmentFormatString; + llvm::StringRef formatString; if (!element.isVariableLength()) { formatString = singleElementAppendTemplate; - segmentFormatString = singleElementSegmentTemplate; } else if (element.isOptional()) { - formatString = optionalAppendTemplate; - segmentFormatString = optionalSegmentTemplate; + // With sized segments, optionals are passed as a 0/1-element list. + formatString = sizedSegments ? optionalSegmentAppendTemplate + : optionalAppendTemplate; } else { assert(element.isVariadic() && "unhandled element group type"); - formatString = variadicAppendTemplate; - segmentFormatString = variadicSegmentTemplate; + // If emitting with sizedSegments, then we add the actual list typed + // element using the singleElementAppendTemplate. Otherwise, we extend + // the actual operands. + if (sizedSegments) { + // Append the list as is. + formatString = singleElementAppendTemplate; + } else { + // Append the list elements. + formatString = multiElementAppendTemplate; + } } // Add the lines. builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); - if (includeSegments) - builderLines.push_back( - llvm::formatv(segmentFormatString.data(), kind, name)); } - - if (includeSegments) - builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind)); } /// Emits a default builder constructing an operation from the list of its @@ -645,8 +633,7 @@ builderArgs.push_back("loc=None"); builderArgs.push_back("ip=None"); - os << llvm::formatv(initTemplate, op.getOperationName(), - llvm::join(builderArgs, ", "), + os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), llvm::join(builderLines, "\n ")); } @@ -659,12 +646,43 @@ } } +static void emitSegmentSpec( + const Operator &op, const char *kind, + llvm::function_ref getNumElements, + llvm::function_ref + getElement, + raw_ostream &os) { + std::string segmentSpec("["); + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + if (element.isVariableLength()) { + // Variadic and optional elements are both passed to the builder as a + // sequence (optionals as a list of zero or one elements). 
+ segmentSpec.append("-1,"); + } else {
+ segmentSpec.append("1,"); + } + } + segmentSpec.append("]"); + + os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); +} + /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, const AttributeClasses &attributeClasses, raw_ostream &os) { os << llvm::formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); + + // Sized segments. + if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { + emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); + } + if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { + emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); + } + emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); emitAttributeAccessors(op, attributeClasses, os);