diff --git a/mlir/examples/python/.style.yapf b/mlir/examples/python/.style.yapf new file mode 100644 --- /dev/null +++ b/mlir/examples/python/.style.yapf @@ -0,0 +1,4 @@ +[style] + based_on_style = google + column_limit = 80 + indent_width = 2 diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py --- a/mlir/examples/python/linalg_matmul.py +++ b/mlir/examples/python/linalg_matmul.py @@ -15,59 +15,69 @@ # TODO: This should be in the core API. def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: - """Creates a |func| op. + """Creates a |func| op. TODO: This should really be in the MLIR API. Returns: (operation, entry_block) """ - attrs = { - "type": TypeAttr.get(func_type), - "sym_name": StringAttr.get(name), - } - op = Operation.create("func", regions=1, attributes=attrs) - body_region = op.regions[0] - entry_block = body_region.blocks.append(*func_type.inputs) - return op, entry_block + attrs = { + "type": TypeAttr.get(func_type), + "sym_name": StringAttr.get(name), + } + op = Operation.create("func", regions=1, attributes=attrs) + body_region = op.regions[0] + entry_block = body_region.blocks.append(*func_type.inputs) + return op, entry_block -# TODO: Generate customs builder vs patching one in. -def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None): - super(linalg.MatmulOp, self).__init__( - self._ods_build_default(operands=[[lhs, rhs], [result]], - results=[], - loc=loc, - ip=ip)) +def build_matmul_buffers_func(func_name, m, k, n, dtype): + lhs_type = MemRefType.get(dtype, [m, k]) + rhs_type = MemRefType.get(dtype, [k, n]) + result_type = MemRefType.get(dtype, [m, n]) + # TODO: There should be a one-liner for this.
+ func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) + _, entry = FuncOp(func_name, func_type) + lhs, rhs, result = entry.arguments + with InsertionPoint(entry): + op = linalg.MatmulOp([lhs, rhs], [result]) # TODO: Implement support for SingleBlockImplicitTerminator - block = self.regions[0].blocks.append() + block = op.regions[0].blocks.append() with InsertionPoint(block): linalg.YieldOp(values=[]) -linalg.MatmulOp.__init__ = PatchMatmulOpInit + std.ReturnOp([]) -def build_matmul_func(func_name, m, k, n, dtype): - lhs_type = MemRefType.get(dtype, [m, k]) - rhs_type = MemRefType.get(dtype, [k, n]) - result_type = MemRefType.get(dtype, [m, n]) - # TODO: There should be a one-liner for this. - func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) - _, entry = FuncOp(func_name, func_type) - lhs, rhs, result = entry.arguments - with InsertionPoint(entry): - linalg.MatmulOp(lhs, rhs, result) - std.ReturnOp([]) +def build_matmul_tensors_func(func_name, m, k, n, dtype): + # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes + # from each other. + lhs_type = RankedTensorType.get([m, k], dtype) + rhs_type = RankedTensorType.get([k, n], dtype) + result_type = RankedTensorType.get([m, n], dtype) + # TODO: There should be a one-liner for this. + func_type = FunctionType.get([lhs_type, rhs_type], [result_type]) + _, entry = FuncOp(func_name, func_type) + lhs, rhs = entry.arguments + with InsertionPoint(entry): + op = linalg.MatmulOp([lhs, rhs], results=[result_type]) + # TODO: Implement support for SingleBlockImplicitTerminator + block = op.regions[0].blocks.append() + with InsertionPoint(block): + linalg.YieldOp(values=[]) + std.ReturnOp([op.result]) def run(): - with Context() as c, Location.unknown(): - module = Module.create() - # TODO: This at_block_terminator vs default construct distinction feels - # wrong and is error-prone.
- with InsertionPoint.at_block_terminator(module.body): - build_matmul_func('main', 18, 32, 96, F32Type.get()) + with Context(), Location.unknown(): + module = Module.create() + # TODO: This at_block_terminator vs default construct distinction feels + # wrong and is error-prone. + with InsertionPoint.at_block_terminator(module.body): + build_matmul_buffers_func('main_buffers', 18, 32, 96, F32Type.get()) + build_matmul_tensors_func('main_tensors', 18, 32, 96, F32Type.get()) - print(module) - print(module.operation.get_asm(print_generic_op_form=True)) + print(module) -if __name__ == '__main__': run() +if __name__ == '__main__': + run() diff --git a/mlir/lib/Bindings/Python/.style.yapf b/mlir/lib/Bindings/Python/.style.yapf new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/.style.yapf @@ -0,0 +1,4 @@ +[style] + based_on_style = google + column_limit = 80 + indent_width = 2 diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -10,6 +10,7 @@ mlir/_dlloader.py mlir/ir.py mlir/dialects/__init__.py + mlir/dialects/_linalg.py mlir/ir.py mlir/passmanager.py mlir/transforms/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py --- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py @@ -5,7 +5,68 @@ # Re-export the parent _cext so that every level of the API can get it locally. from .. import _cext -def _segmented_accessor(elements, raw_segments, idx): +__all__ = [ + "equally_sized_accessor", + "extend_opview_class", + "get_default_loc_context", + "segmented_accessor", +] + + +def extend_opview_class(ext_module): + """Decorator to extend an OpView class from an extension module.
+ + Extension modules can expose various entry-points: + def select_opview_mixin(parent_opview_cls): + If defined, allows an appropriate mixin class to be selected dynamically + based on the parent OpView class. Should return NotImplemented if a + decision is not made. + + Stand-alone class with the same name as a parent OpView class (e.g. + "ReturnOp"); used as the mixin when select_opview_mixin is not defined. + + Args: + ext_module: A module from which to locate extensions. Can be None if not + available. + + Returns: + A decorator that takes an OpView subclass and further extends it as + needed. + """ + + def class_decorator(parent_opview_cls: type): + if ext_module is None: + return parent_opview_cls + mixin_cls = NotImplemented + try: + select_mixin = getattr(ext_module, "select_opview_mixin") + except AttributeError: + # Fall back to a class named like the parent OpView, if any. + try: + mixin_cls = getattr(ext_module, parent_opview_cls.__name__) + except AttributeError: + pass + else: + mixin_cls = select_mixin(parent_opview_cls) + if mixin_cls is NotImplemented or mixin_cls is None: + return parent_opview_cls + + # Have a mixin_cls. Create an appropriate subclass. + try: + + class LocalOpView(mixin_cls, parent_opview_cls): + pass + except TypeError as e: + raise TypeError( + f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e + LocalOpView.__name__ = parent_opview_cls.__name__ + LocalOpView.__qualname__ = parent_opview_cls.__qualname__ + return LocalOpView + + return class_decorator + + +def segmented_accessor(elements, raw_segments, idx): """ Returns a slice of elements corresponding to the idx-th segment. @@ -20,8 +81,8 @@ return elements[start:end] -def _equally_sized_accessor(elements, n_variadic, n_preceding_simple, - n_preceding_variadic): +def equally_sized_accessor(elements, n_variadic, n_preceding_simple, + n_preceding_variadic): """ Returns a starting position and a number of elements per variadic group assuming equally-sized groups and the given numbers of preceding groups.
@@ -42,7 +103,8 @@ start = n_preceding_simple + n_preceding_variadic * elements_per_group return start, elements_per_group -def _get_default_loc_context(location = None): + +def get_default_loc_context(location=None): """ Returns a context in which the defaulted location is created. If the location is None, takes the current location from the stack, raises ValueError if there diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg.py @@ -0,0 +1,39 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +class StructuredOpMixin: + """All structured ops use the same mixin class.""" + + def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): + """Builds the op via `self._ods_build_default`. + + `inputs` and `outputs` become the two operand groups and `results` is + forwarded as-is. Raises ValueError if both `outputs` and `results` are + non-empty. + """ + if outputs and results: + raise ValueError( + "Structured ops must have outputs or results, but not both.") + super().__init__( + self._ods_build_default(operands=[list(inputs), + list(outputs)], + results=list(results), + loc=loc, + ip=ip)) + + +def select_opview_mixin(parent_opview_cls): + """Returns StructuredOpMixin for OpViews that look like structured ops. + + Returns NotImplemented otherwise, signalling that no mixin was selected. + """ + # TODO: This shouldn't be a heuristic: we should have a way to annotate + # the OpView to note that it is a structured op. + if ("__init__" not in parent_opview_cls.__dict__ and + hasattr(parent_opview_cls, "inputs") and + hasattr(parent_opview_cls, "outputs") and + hasattr(parent_opview_cls, "result_tensors")): + return StructuredOpMixin + return NotImplemented diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -23,12 +23,19 @@ using namespace mlir::tblgen; /// File header and includes. +/// {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from . import _cext as _ods_cext -from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context +from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context _ods_ir = _ods_cext.ir + +try: + from . import _{0} as _ods_ext_module +except ImportError: + _ods_ext_module = None + )Py"; /// Template for dialect class: @@ -46,6 +53,7 @@ /// {1} is the operation name. constexpr const char *opClassTemplate = R"Py( @_ods_cext.register_operation(_Dialect) +@_ods_extend_opview_class(_ods_ext_module) class {0}(_ods_ir.OpView): OPERATION_NAME = "{1}" )Py"; @@ -706,7 +714,7 @@ AttributeClasses attributeClasses; constructAttributeMapping(records, attributeClasses); - os << fileHeader; + os << llvm::formatv(fileHeader, clDialectName.getValue()); os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec);