diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -139,6 +139,7 @@
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/BufferizationTransformOps.td
   SOURCES
+    dialects/_bufferization_transform_ops_ext.py
     dialects/transform/bufferization.py
   DIALECT_NAME transform
   EXTENSION_NAME bufferization_transform)
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
@@ -0,0 +1,163 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+    from ..ir import *
+    from ..dialects import transform
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Callable, Optional, overload, Tuple, Union
+
+
+def _get_transformed_type_and_target(
+    transformed_type_or_target: Union[Type, Operation, OpView, Value],
+    target_or_none: Optional[Union[Operation, OpView, Value]],
+    get_default_type: Callable[[], Type],
+) -> Tuple[Type, Union[Operation, OpView, Value]]:
+    """Disambiguates the `(transformed_type, target)` and `(target)` call forms.
+
+    If the first argument is a `Type`, it is used as the result type and
+    `target_or_none` must be provided. Otherwise, the first argument is the
+    target, no second positional argument may be given, and the result type
+    is obtained by calling `get_default_type` (only in that case).
+
+    Returns:
+      A `(transformed_type, target)` tuple.
+
+    Raises:
+      TypeError: if the arguments match neither form.
+    """
+    if isinstance(transformed_type_or_target, Type):
+        if target_or_none is None:
+            raise TypeError("a target must be provided when a result type is given")
+        return transformed_type_or_target, target_or_none
+    if target_or_none is not None:
+        raise TypeError(
+            "expected a Type as the first argument when a target is also given"
+        )
+    return get_default_type(), transformed_type_or_target
+
+
+class EmptyTensorToAllocTensorOp:
+    """Specialization for EmptyTensorToAllocTensorOp class."""
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Union[Type, Operation, OpView, Value],
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        """Builds the op from `(transformed_type, target)` or just `(target)`.
+
+        In the compact `(target)` form, the result type defaults to
+        `!transform.op<"bufferization.alloc_tensor">`.
+        """
+        transformed_type, target = _get_transformed_type_and_target(
+            transformed_type_or_target,
+            target_or_none,
+            lambda: transform.OperationType.get("bufferization.alloc_tensor"),
+        )
+
+        super().__init__(
+            transformed_type,
+            target,
+            loc=loc,
+            ip=ip,
+        )
+
+
+class OneShotBufferizeOp:
+    """Specialization for OneShotBufferizeOp class."""
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        allow_return_allocs: Optional[bool] = None,
+        allow_unknown_ops: Optional[bool] = None,
+        bufferize_function_boundaries: Optional[bool] = None,
+        create_deallocs: Optional[bool] = None,
+        test_analysis_only: Optional[bool] = None,
+        print_conflicts: Optional[bool] = None,
+        memcpy_op: Optional[str] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        allow_return_allocs: Optional[bool] = None,
+        allow_unknown_ops: Optional[bool] = None,
+        bufferize_function_boundaries: Optional[bool] = None,
+        create_deallocs: Optional[bool] = None,
+        test_analysis_only: Optional[bool] = None,
+        print_conflicts: Optional[bool] = None,
+        memcpy_op: Optional[str] = None,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Union[Type, Operation, OpView, Value],
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        allow_return_allocs: Optional[bool] = None,
+        allow_unknown_ops: Optional[bool] = None,
+        bufferize_function_boundaries: Optional[bool] = None,
+        create_deallocs: Optional[bool] = None,
+        test_analysis_only: Optional[bool] = None,
+        print_conflicts: Optional[bool] = None,
+        memcpy_op: Optional[str] = None,
+        loc=None,
+        ip=None
+    ):
+        """Builds the op from `(transformed_type, target)` or just `(target)`.
+
+        In the compact `(target)` form, the result type defaults to
+        `!transform.any_op`. The keyword options are forwarded unchanged to
+        the generated builder.
+        """
+        transformed_type, target = _get_transformed_type_and_target(
+            transformed_type_or_target, target_or_none, transform.AnyOpType.get
+        )
+
+        super().__init__(
+            transformed_type,
+            target,
+            allow_return_allocs=allow_return_allocs,
+            allow_unknown_ops=allow_unknown_ops,
+            bufferize_function_boundaries=bufferize_function_boundaries,
+            create_deallocs=create_deallocs,
+            test_analysis_only=test_analysis_only,
+            print_conflicts=print_conflicts,
+            memcpy_op=memcpy_op,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/python/dialects/transform_bufferization_ext.py b/mlir/test/python/dialects/transform_bufferization_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/transform_bufferization_ext.py
@@ -0,0 +1,114 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import bufferization
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+@run
+def testEmptyTensorToAllocTensorOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("tensor.empty"),
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.EmptyTensorToAllocTensorOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpCompact
+    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
+    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+
+
+@run
+def testEmptyTensorToAllocTensorOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("tensor.empty"),
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.EmptyTensorToAllocTensorOp(
+            transform.OperationType.get("bufferization.alloc_tensor"),
+            sequence.bodyTarget,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpTyped
+    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
+    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+
+
+@run
+def testOneShotBufferizeOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpCompact
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testOneShotBufferizeOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(
+            transform.OperationType.get("test.dummy"),
+            sequence.bodyTarget,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpTyped
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
+@run
+def testOneShotBufferizeOpAttributes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        bufferization.OneShotBufferizeOp(
+            sequence.bodyTarget,
+            allow_return_allocs=True,
+            allow_unknown_ops=True,
+            bufferize_function_boundaries=True,
+            create_deallocs=True,
+            test_analysis_only=True,
+            print_conflicts=True,
+            memcpy_op="memref.copy",
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpAttributes
+    # CHECK: = transform.bufferization.one_shot_bufferize
+    # CHECK-SAME: allow_return_allocs = true
+    # CHECK-SAME: allow_unknown_ops = true
+    # CHECK-SAME: bufferize_function_boundaries = true
+    # CHECK-SAME: print_conflicts = true
+    # CHECK-SAME: test_analysis_only = true
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testOneShotBufferizeOpMissingTarget():
+    try:
+        bufferization.OneShotBufferizeOp(transform.AnyOpType.get())
+    except TypeError as e:
+        print(e)
+    # CHECK-LABEL: TEST: testOneShotBufferizeOpMissingTarget
+    # CHECK: a target must be provided when a result type is given
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -753,6 +753,25 @@
     ],
 )
 
+gentbl_filegroup(
+    name = "BufferizationTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=bufferization_transform",
+            ],
+            "mlir/dialects/_bufferization_transform_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/BufferizationTransformOps.td",
+    deps = [
+        "//mlir:BufferizationTransformOpsTdFiles",
+    ],
+)
+
 gentbl_filegroup(
     name = "GPUTransformOpsPyGen",
     tbl_outs = [
@@ -776,7 +795,6 @@
     ],
 )
 
-
 gentbl_filegroup(
     name = "StructuredTransformOpsPyGen",
     tbl_outs = [
@@ -849,11 +867,13 @@
 filegroup(
     name = "TransformOpsPyFiles",
     srcs = [
+        "mlir/dialects/_bufferization_transform_ops_ext.py",
         "mlir/dialects/_gpu_transform_ops_ext.py",
         "mlir/dialects/_loop_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         "mlir/dialects/_transform_pdl_extension_ops_ext.py",
+        ":BufferizationTransformOpsPyGen",
         ":GPUTransformOpsPyGen",
         ":LoopTransformOpsPyGen",
         ":PDLTransformOpsPyGen",