diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -132,6 +132,15 @@ dialects/_memref_ops_ext.py DIALECT_NAME memref) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MLProgramOps.td + SOURCES + dialects/ml_program.py + dialects/_ml_program_ops_ext.py + DIALECT_NAME ml_program) + declare_mlir_python_sources( MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/MLProgramOps.td b/mlir/python/mlir/dialects/MLProgramOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/MLProgramOps.td @@ -0,0 +1,15 @@ +//===-- MLProgramOps.td - Entry point for MLProgramOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MLPROGRAM_OPS +#define PYTHON_BINDINGS_MLPROGRAM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/MLProgram/IR/MLProgramOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -0,0 +1,116 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from typing import Union + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from ._ml_program_ops_gen import * + + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +class FuncOp: + """Specialization for the func op class.""" + + def __init__(self, + name, + type, + *, + visibility=None, + body_builder=None, + loc=None, + ip=None): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = StringAttr.get( + str(visibility)) if visibility is not None else None + super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError('External function does not have a body') + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. + Returns the newly created block + """ + if not self.is_external: + raise IndexError('The function already has an entry block!') + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/ml_program.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._ml_program_ops_gen import * diff --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/ml_program.py @@ -0,0 +1,28 @@ +# RUN: %PYTHON %s | FileCheck %s +# This is just a smoke test that the dialect is functional. + +from mlir.ir import * +from mlir.dialects import ml_program + + +def constructAndPrintInModule(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f() + print(module) + return f + + +# CHECK-LABEL: testFuncOp +@constructAndPrintInModule +def testFuncOp(): + # CHECK: ml_program.func @foobar(%arg0: si32) -> si32 + f = ml_program.FuncOp( + name="foobar", + type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)])) + block = f.add_entry_block() + with InsertionPoint(block): + # CHECK: ml_program.return + ml_program.ReturnOp([block.arguments[0]]) diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -378,6 +378,49 @@ ], ) +##---------------------------------------------------------------------------## +# MLProgram dialect. +##---------------------------------------------------------------------------## + +td_library( + name = "MLProgramOpsPyTdFiles", + srcs = [ + "//mlir:include/mlir/Bindings/Python/Attributes.td", + ], + includes = ["../include"], + deps = [ + "//mlir:MLProgramOpsTdFiles", + "//mlir:OpBaseTdFiles", + ], +) + +gentbl_filegroup( + name = "MLProgramOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=ml_program", + ], + "mlir/dialects/_ml_program_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/MLProgramOps.td", + deps = [ + ":MLProgramOpsPyTdFiles", + ], +) + +filegroup( + name = "MLProgramOpsPyFiles", + srcs = [ + "mlir/dialects/_ml_program_ops_ext.py", + "mlir/dialects/ml_program.py", + ":MLProgramOpsPyGen", + ], +) + ##---------------------------------------------------------------------------## # PDL dialect. ##---------------------------------------------------------------------------##