diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -355,6 +355,82 @@
   endif()
 endfunction()
 
+# Function: declare_mlir_dialect_extension_python_bindings
+# Helper to generate source groups for dialect extensions, including both
+# static source files and a TD_FILE to generate wrappers.
+#
+# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
+# When TD_FILE is given, the generated _${EXTENSION_NAME}_ops_gen.py is added
+# to a nested source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.ops_gen.
+#
+# Arguments:
+#   ROOT_DIR: Same as for declare_mlir_python_sources().
+#   ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
+#     for the subordinate source groups are derived from this. Required.
+#   TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
+#   DIALECT_NAME: Python name of the dialect being extended. Required when
+#     TD_FILE is given.
+#   EXTENSION_NAME: Python name of the dialect extension. Required.
+#   SOURCES: Same as declare_mlir_python_sources().
+#   SOURCES_GLOB: Same as declare_mlir_python_sources().
+#   DEPENDS: Additional dependency targets of the tablegen step.
+function(declare_mlir_dialect_extension_python_bindings)
+  cmake_parse_arguments(ARG
+    ""
+    "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
+    "SOURCES;SOURCES_GLOB;DEPENDS"
+    ${ARGN})
+  # Fail loudly on misspelled keywords instead of silently dropping them.
+  if(ARG_UNPARSED_ARGUMENTS)
+    message(FATAL_ERROR "declare_mlir_dialect_extension_python_bindings: "
+      "unexpected arguments: ${ARG_UNPARSED_ARGUMENTS}")
+  endif()
+  # Both are needed to derive the source group name below.
+  if("${ARG_ADD_TO_PARENT}" STREQUAL "" OR "${ARG_EXTENSION_NAME}" STREQUAL "")
+    message(FATAL_ERROR "declare_mlir_dialect_extension_python_bindings: "
+      "ADD_TO_PARENT and EXTENSION_NAME are required")
+  endif()
+  # DIALECT_NAME is forwarded as -bind-dialect; the generated file imports
+  # _Dialect from _${DIALECT_NAME}_ops_gen, so it cannot be empty.
+  if(ARG_TD_FILE AND "${ARG_DIALECT_NAME}" STREQUAL "")
+    message(FATAL_ERROR "declare_mlir_dialect_extension_python_bindings: "
+      "DIALECT_NAME is required when TD_FILE is given")
+  endif()
+
+  # Source files.
+  set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
+  declare_mlir_python_sources(${_extension_target}
+    ROOT_DIR "${ARG_ROOT_DIR}"
+    ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
+    SOURCES "${ARG_SOURCES}"
+    SOURCES_GLOB "${ARG_SOURCES_GLOB}"
+  )
+
+  # Tablegen: the wrapper is generated under CMAKE_CURRENT_BINARY_DIR,
+  # mirroring the directory of TD_FILE relative to ROOT_DIR.
+  if(ARG_TD_FILE)
+    set(tblgen_target "${_extension_target}.tablegen")
+    set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
+    get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
+    file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
+    set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
+    set(LLVM_TARGET_DEFINITIONS ${td_file})
+    mlir_tablegen("${output_filename}" -gen-python-op-bindings
+                  -bind-dialect=${ARG_DIALECT_NAME}
+                  -dialect-extension=${ARG_EXTENSION_NAME})
+    add_public_tablegen_target(${tblgen_target})
+    if(ARG_DEPENDS)
+      add_dependencies(${tblgen_target} ${ARG_DEPENDS})
+    endif()
+
+    declare_mlir_python_sources("${_extension_target}.ops_gen"
+      ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+      ADD_TO_PARENT "${_extension_target}"
+      SOURCES "${output_filename}"
+    )
+  endif()
+endfunction()
+
 # Function: mlir_python_setup_extension_rpath
 # Sets RPATH properties on a target, assuming that it is being output to
 # an _mlir_libs directory with all other libraries.
For static linkage,
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -116,6 +116,25 @@
   DIALECT_NAME linalg
   DEPENDS LinalgOdsGen)
 
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformOps.td
+  SOURCES
+    dialects/_transform_ops_ext.py
+    dialects/transform/__init__.py
+  DIALECT_NAME transform)
+
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/LinalgStructuredTransformOps.td
+  SOURCES
+    dialects/_structured_transform_ops_ext.py
+    dialects/transform/structured.py
+  DIALECT_NAME transform
+  EXTENSION_NAME structured_transform)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td
@@ -0,0 +1,20 @@
+//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the structured transform ops
+// provided by Linalg (and other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
+#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"
+
+#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/TransformOps.td b/mlir/python/mlir/dialects/TransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformOps.td
@@ -0,0 +1,20 @@
+//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the Transform dialect ops.
+// Ops contributed by Transform dialect extensions have separate entry points.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
+#define PYTHON_BINDINGS_TRANSFORM_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Transform/IR/TransformOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -0,0 +1,178 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import List, Optional, Sequence, Union
+
+IntOrAttrList = Sequence[Union[IntegerAttr, int]]
+OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
+
+
+def _get_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
+  """Creates an array attribute from its operand."""
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  return ArrayAttr.get(values)
+
+
+def _get_int_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
+) -> ArrayAttr:
+  """Creates an integer array attribute from its operand.
+
+  If the operand is already an array attribute, forwards it. Otherwise treats
+  the operand as a list of attributes or integers, possibly interspersed, to
+  create a new array attribute containing integer attributes. Expects the
+  thread-local MLIR context to have been set by the context manager.
+  """
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  attributes = []
+  for value in values:
+    if isinstance(value, IntegerAttr):
+      attributes.append(value)
+    else:
+      attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
+  return ArrayAttr.get(attributes)
+
+
+def _get_int_int_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
+                                                     IntOrAttrList]]]]
+) -> ArrayAttr:
+  """Creates an array attribute containing array attributes of integers.
+
+  If the operand is already an array attribute, forwards it. Otherwise treats
+  the operand as a list of array attributes or lists of integers/attributes,
+  possibly interspersed, to create a new array-of-array attribute. Expects the
+  thread-local MLIR context to have been set by the context manager.
+  """
+  if values is None:
+    return ArrayAttr.get([])
+  if isinstance(values, ArrayAttr):
+    return values
+
+  return ArrayAttr.get([_get_int_array_attr(value) for value in values])
+
+
+class InterchangeOp:
+  """Specialization for InterchangeOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               iterator_interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    interchange_attr = _get_int_array_attr(iterator_interchange)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        iterator_interchange=interchange_attr,
+        loc=loc,
+        ip=ip)
+
+
+class PadOp:
+  """Specialization for PadOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               padding_values: Optional[Union[ArrayAttr,
+                                              Sequence[Attribute]]] = None,
+               padding_dimensions: OptionalIntList = None,
+               pack_paddings: OptionalIntList = None,
+               hoist_paddings: OptionalIntList = None,
+               transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
+                   ArrayAttr, IntOrAttrList]]]] = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    padding_values_attr = _get_array_attr(padding_values)
+    padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
+    pack_paddings_attr = _get_int_array_attr(pack_paddings)
+    hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
+    transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        padding_values=padding_values_attr,
+        padding_dimensions=padding_dimensions_attr,
+        pack_paddings=pack_paddings_attr,
+        hoist_paddings=hoist_paddings_attr,
+        transpose_paddings=transpose_paddings_attr,
+        loc=loc,
+        ip=ip)
+
+
+class ScalarizeOp:
+  """Specialization for ScalarizeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    super().__init__(
+        pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+
+
+class TileOp:
+  """Specialization for TileOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               sizes: OptionalIntList = None,
+               interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    sizes_attr = _get_int_array_attr(sizes)
+    num_loops = sum(  # One loop result per non-zero tile size.
+        1 for v in self.__extract_values(sizes_attr) if v != 0)
+    super().__init__(
+        pdl_operation_type, [pdl_operation_type] * num_loops,
+        _get_op_result_or_value(target),
+        sizes=sizes_attr,
+        interchange=_get_int_array_attr(interchange) if interchange else None,
+        loc=loc,
+        ip=ip)
+
+  def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
+    """Returns the values of the IntegerAttrs in `attr`, or an empty list if
+    `attr` is None or empty."""
+    return [IntegerAttr(element).value for element in attr] if attr else []
+
+
+class VectorizeOp:
+  """Specialization for VectorizeOp class."""
+
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               vectorize_padding: Union[bool, BoolAttr] = False,
+               loc=None,
+               ip=None):
+    pdl_operation_type = pdl.OperationType.get()
+    if isinstance(vectorize_padding, bool):
+      vectorize_padding = BoolAttr.get(vectorize_padding)
+    super().__init__(
+        pdl_operation_type,
+        _get_op_result_or_value(target),
+        vectorize_padding=vectorize_padding,
+        loc=loc,
+        ip=ip)
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -0,0 +1,106 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
+  from ..dialects import pdl
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Sequence, Union
+
+
+def _get_symbol_ref_attr(value: Union[Attribute, str]):
+  """Returns `value` if it is already an attribute; otherwise wraps the symbol
+  name into a FlatSymbolRefAttr."""
+  return value if isinstance(value, Attribute) else FlatSymbolRefAttr.get(value)
+
+
+class GetClosestIsolatedParentOp:
+  """Specialization for GetClosestIsolatedParentOp class."""
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
+class PDLMatchOp:
+  """Specialization for PDLMatchOp class."""
+  def __init__(self,
+               target: Union[Operation, Value],
+               pattern_name: Union[Attribute, str],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        _get_symbol_ref_attr(pattern_name),
+        loc=loc,
+        ip=ip)
+
+
+class SequenceOp:
+  """Specialization for SequenceOp class."""
+  @overload
+  def __init__(self, resultsOrRoot: Sequence[Type],
+               optionalRoot: Optional[Union[Operation, Value]]):
+    ...
+
+  @overload
+  def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]],
+               optionalRoot: None):
+    ...
+
+  def __init__(self, resultsOrRoot=None, optionalRoot=None):
+    if isinstance(resultsOrRoot, Sequence):  # Called as (results, root).
+      results, root = resultsOrRoot, optionalRoot
+    else:  # Called as (root) or ().
+      results, root = [], resultsOrRoot
+    root = _get_op_result_or_value(root) if root is not None else None
+    super().__init__(results_=results, root=root)
+    self.regions[0].blocks.append(pdl.OperationType.get())
+
+  @property
+  def body(self) -> Block:
+    return self.regions[0].blocks[0]
+
+  @property
+  def bodyTarget(self) -> Value:
+    return self.body.arguments[0]
+
+
+class WithPDLPatternsOp:
+  """Specialization for WithPDLPatternsOp class."""
+  def __init__(self,
+               target: Optional[Union[Operation, Value]] = None,
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(
+        root=_get_op_result_or_value(target) if target is not None else None,
+        loc=loc,
+        ip=ip)
+    self.regions[0].blocks.append(pdl.OperationType.get())
+
+  @property
+  def body(self) -> Block:
+    return self.regions[0].blocks[0]
+
+  @property
+  def bodyTarget(self) -> Value:
+    return self.body.arguments[0]
+
+
+class YieldOp:
+  """Specialization for YieldOp class."""
+  def __init__(self,
+               operands: Union[Operation, Sequence[Value]] = [],
+               *,
+               loc=None,
+               ip=None):
+    super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._structured_transform_ops_gen import *
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/transform.py
@@ -0,0 +1,84 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+
+
+def run(f):
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      print("\nTEST:", f.__name__)
+      f()
+    print(module)
+  return f
+
+
+@run
+def testSequenceOp():
+  sequence = transform.SequenceOp([pdl.OperationType.get()])
+  with InsertionPoint(sequence.body):
+    transform.YieldOp([sequence.bodyTarget])
+  # CHECK-LABEL: TEST: testSequenceOp
+  # CHECK: = transform.sequence {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   yield %[[ARG0]] : !pdl.operation
+  # CHECK: } : !pdl.operation
+
+
+@run
+def testNestedSequenceOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    nested = transform.SequenceOp(sequence.bodyTarget)
+    with InsertionPoint(nested.body):
+      doubly_nested = transform.SequenceOp([pdl.OperationType.get()],
+                                           nested.bodyTarget)
+      with InsertionPoint(doubly_nested.body):
+        transform.YieldOp([doubly_nested.bodyTarget])
+      transform.YieldOp()
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testNestedSequenceOp
+  # CHECK: transform.sequence {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   sequence %[[ARG0]] {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:     = sequence %[[ARG1]] {
+  # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !pdl.operation):
+  # CHECK:       yield %[[ARG2]] : !pdl.operation
+  # CHECK:     } : !pdl.operation
+  # CHECK:   }
+  # CHECK: }
+
+
+@run
+def testTransformPDLOps():
+  withPdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(withPdl.body):
+    sequence = transform.SequenceOp([pdl.OperationType.get()],
+                                    withPdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      match = transform.PDLMatchOp(sequence.bodyTarget, "pdl_matcher")
+      transform.YieldOp(match)
+  # CHECK-LABEL: TEST: testTransformPDLOps
+  # CHECK: transform.with_pdl_patterns {
+  # CHECK: ^{{.*}}(%[[ARG0:.+]]: !pdl.operation):
+  # CHECK:   = sequence %[[ARG0]] {
+  # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
+  # CHECK:     yield %[[RES]] : !pdl.operation
+  # CHECK:   } : !pdl.operation
+  # CHECK: }
+
+
+@run
+def testGetClosestIsolatedParentOp():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    transform.GetClosestIsolatedParentOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testGetClosestIsolatedParentOp
+  # CHECK: transform.sequence
+  # CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
+  # CHECK:   = get_closest_isolated_parent %[[ARG1]]
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -0,0 +1,118 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects import pdl
+from mlir.dialects.transform import structured
+
+
+def run(f):
+  with Context(), Location.unknown():
+    module = Module.create()
+    with InsertionPoint(module.body):
+      print("\nTEST:", f.__name__)
+      f()
+    print(module)
+  return f
+
+
+@run
+def testInterchange():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.InterchangeOp(
+        sequence.bodyTarget,
+        iterator_interchange=[
+            IntegerAttr.get(IntegerType.get_signless(64), 1), 0
+        ])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testInterchange
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.interchange
+  # CHECK: iterator_interchange = [1, 0]
+
+
+@run
+def testPad():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.PadOp(
+        sequence.bodyTarget,
+        padding_values=[FloatAttr.get_f32(42.0)],
+        padding_dimensions=[1],
+        transpose_paddings=[[1, 0]])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testPad
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.pad
+  # CHECK-DAG: padding_values = [4.200000e+01 : f32]
+  # CHECK-DAG: padding_dimensions = [1]
+  # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]]
+  # CHECK-DAG: hoist_paddings = []
+  # CHECK-DAG: pack_paddings = []
+
+
+@run
+def testScalarize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.ScalarizeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testScalarize
+  # CHECK: transform.structured.scalarize
+
+
+@run
+def testTileCompact():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileCompact
+  # CHECK: transform.sequence
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
+  # CHECK-DAG: interchange = [0, 1]
+  # CHECK-DAG: sizes = [4, 8]
+
+
+@run
+def testTileAttributes():
+  sequence = transform.SequenceOp()
+  attr = ArrayAttr.get(
+      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]])
+  ichange = ArrayAttr.get(
+      [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]])
+  with InsertionPoint(sequence.body):
+    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileAttributes
+  # CHECK: transform.sequence
+  # CHECK: structured.tile
+  # CHECK-DAG: interchange = [0, 1]
+  # CHECK-DAG: sizes = [4, 8]
+
+
+@run
+def testTileZero():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.TileOp(
+        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileZero
+  # CHECK: transform.sequence
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
+  # CHECK-DAG: interchange = [0, 1, 2, 3]
+  # CHECK-DAG: sizes = [4, 0, 2, 0]
+
+
+@run
+def testVectorize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testVectorize
+  # CHECK: transform.sequence
+  # CHECK: = transform.structured.vectorize
+  # CHECK: vectorize_padding = true
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -50,6 +50,10 @@
 
 )Py";
 
+constexpr const char *dialectExtensionTemplate = R"Py(
+from ._{0}_ops_gen import _Dialect
+)Py";
+
 /// Template for operation class:
 ///   {0} is the Python class name;
 ///   {1} is the operation name.
@@ -270,6 +274,14 @@
                   llvm::cl::desc("The dialect to run the generator for"),
                   llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
 
+/// Name of the dialect extension to generate bindings for. When set, the
+/// generated module imports `_Dialect` from the `-bind-dialect` dialect's
+/// `_<dialect>_ops_gen` module instead of emitting a dialect class, and this
+/// name replaces the dialect name in the file header.
+static llvm::cl::opt<std::string> clDialectExtensionName(
+    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
+    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
 /// Checks whether `str` is a Python keyword.
@@ -1014,8 +1026,16 @@
   AttributeClasses attributeClasses;
   constructAttributeMapping(records, attributeClasses);
 
-  os << llvm::formatv(fileHeader, clDialectName.getValue());
-  os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
+  bool isExtension = !clDialectExtensionName.empty();
+  os << llvm::formatv(fileHeader, isExtension
+                                      ? clDialectExtensionName.getValue()
+                                      : clDialectName.getValue());
+  // Extensions import the _Dialect class of the extended dialect instead of
+  // emitting dialectClassTemplate a second time.
+  if (isExtension)
+    os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
+  else
+    os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
 
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -825,6 +825,74 @@
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Transform dialect and extensions.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "TransformOpsPyTdFiles",
+    srcs = [
+        "//mlir:include/mlir/Bindings/Python/Attributes.td",
+    ],
+    deps = [
+        "//mlir:OpBaseTdFiles",
+        "//mlir:TransformDialectTdFiles",
+    ],
+)
+
+gentbl_filegroup(
+    name = "TransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+            ],
+            "mlir/dialects/_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/TransformOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+    ],
+)
+
+gentbl_filegroup(
+    name = "StructuredTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=structured_transform",
+            ],
+            "mlir/dialects/_structured_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/LinalgStructuredTransformOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+        "//mlir:LinalgTransformOpsTdFiles",
+    ],
+)
+
+filegroup(
+    name = "TransformOpsPyFiles",
+    srcs = [
+        "mlir/dialects/_structured_transform_ops_ext.py",
+        "mlir/dialects/_transform_ops_ext.py",
+        ":StructuredTransformOpsPyGen",
+        ":TransformOpsPyGen",
+    ],
+)
+
+filegroup(
+    name = "TransformOpsPackagePyFiles",
+    srcs = glob(["mlir/dialects/transform/*.py"]),
+)
+
 ##---------------------------------------------------------------------------##
 # Vector dialect.
 ##---------------------------------------------------------------------------##