diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -153,6 +153,15 @@ DIALECT_NAME transform EXTENSION_NAME bufferization_transform) +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") +mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.bufferization_transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform + SOURCES "dialects/_bufferization_transform_enum_gen.py") + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -8,6 +8,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +from enum import Enum from typing import Optional, overload, Union @@ -65,16 +66,31 @@ allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, create_deallocs: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - print_conflicts: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, loc=None, ip=None ): ...
@overload - def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): ... def __init__( @@ -86,9 +102,10 @@ allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, create_deallocs: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - print_conflicts: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, loc=None, ip=None ): @@ -106,9 +123,10 @@ allow_unknown_ops=allow_unknown_ops, bufferize_function_boundaries=bufferize_function_boundaries, create_deallocs=create_deallocs, - test_analysis_only=test_analysis_only, - print_conflicts=print_conflicts, + function_boundary_type_conversion=function_boundary_type_conversion, memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, loc=loc, ip=ip, ) diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py --- a/mlir/python/mlir/dialects/transform/bufferization.py +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -2,4 +2,5 @@ # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from .._bufferization_transform_enum_gen import * from .._bufferization_transform_ops_gen import * diff --git a/mlir/test/python/dialects/transform_bufferization_ext.py b/mlir/test/python/dialects/transform_bufferization_ext.py --- a/mlir/test/python/dialects/transform_bufferization_ext.py +++ b/mlir/test/python/dialects/transform_bufferization_ext.py @@ -88,17 +88,21 @@ allow_return_allocs=True, allow_unknown_ops=True, bufferize_function_boundaries=True, - create_deallocs=True, - test_analysis_only=True, + create_deallocs=False, + function_boundary_type_conversion=bufferization.LayoutMapOption.IDENTITY_LAYOUT_MAP, + memcpy_op="linalg.copy", print_conflicts=True, - memcpy_op="memref.copy", + test_analysis_only=True, ) transform.YieldOp() # CHECK-LABEL: TEST: testOneShotBufferizeOpAttributes # CHECK: = transform.bufferization.one_shot_bufferize + # CHECK-SAME: layout{IdentityLayoutMap} # CHECK-SAME: allow_return_allocs = true # CHECK-SAME: allow_unknown_ops = true # CHECK-SAME: bufferize_function_boundaries = true + # CHECK-SAME: create_deallocs = false + # CHECK-SAME: memcpy_op = "linalg.copy" # CHECK-SAME: print_conflicts = true # CHECK-SAME: test_analysis_only = true # CHECK-SAME: (!transform.any_op) -> !transform.any_op diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -775,6 +775,13 @@ gentbl_filegroup( name = "BufferizationTransformOpsPyGen", tbl_outs = [ + ( + [ + "-gen-python-enum-bindings", + "-bind-dialect=transform", + ], + "mlir/dialects/_bufferization_transform_enum_gen.py", + ), ( [ "-gen-python-op-bindings",