diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -207,6 +207,7 @@
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TensorTransformOps.td
   SOURCES
+    dialects/_tensor_transform_ops_ext.py
     dialects/transform/tensor.py
   DIALECT_NAME transform
   EXTENSION_NAME tensor_transform)
diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
@@ -0,0 +1,74 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+    from ..ir import *
+    from ..dialects import transform
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+class MakeLoopIndependentOp:
+    """Specialization for MakeLoopIndependentOp class.
+
+    The constructor accepts either `(transformed_type, target, num_loops)` or
+    the compact form `(target, num_loops)`, in which case the result type is
+    `transform.AnyOpType`.
+    """
+
+    @overload
+    def __init__(
+        self,
+        transformed_type: Type,
+        target: Union[Operation, OpView, Value],
+        num_loops: Union[int, IntegerAttr],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        num_loops: Union[int, IntegerAttr],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        transformed_type_or_target: Union[Type, Operation, OpView, Value],
+        target_or_num_loops: Optional[
+            Union[int, IntegerAttr, Operation, OpView, Value]
+        ] = None,
+        num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(transformed_type_or_target, Type):
+            # Typed form: (transformed_type, target, num_loops).
+            transformed_type = transformed_type_or_target
+            target = target_or_num_loops
+            num_loops = num_loops_or_none
+        else:
+            # Compact form: (target, num_loops); the result type defaults to
+            # transform.AnyOpType.
+            transformed_type = transform.AnyOpType.get()
+            target = transformed_type_or_target
+            num_loops = target_or_num_loops
+
+        super().__init__(
+            transformed_type,
+            target,
+            num_loops,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/python/dialects/transform_tensor_ext.py b/mlir/test/python/dialects/transform_tensor_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/transform_tensor_ext.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import tensor
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.PROPAGATE,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                f(sequence.bodyTarget)
+                transform.YieldOp()
+        print(module)
+    return f
+
+
+@run
+def testMakeLoopIndependentOpCompact(target):
+    tensor.MakeLoopIndependentOp(target, 4)
+    # CHECK-LABEL: TEST: testMakeLoopIndependentOpCompact
+    # CHECK: = transform.tensor.make_loop_independent
+    # CHECK-SAME: num_loops = 4 : i64
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMakeLoopIndependentOpTyped(target):
+    tensor.MakeLoopIndependentOp(transform.OperationType.get("test.dummy"), target, 4)
+    # CHECK-LABEL: TEST: testMakeLoopIndependentOpTyped
+    # CHECK: = transform.tensor.make_loop_independent
+    # CHECK-SAME: num_loops = 4 : i64
+    # CHECK-SAME: (!transform.any_op) -> 
!transform.op<"test.dummy">
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -975,6 +975,7 @@
         "mlir/dialects/_loop_transform_ops_ext.py",
         "mlir/dialects/_memref_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
+        "mlir/dialects/_tensor_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         "mlir/dialects/_transform_pdl_extension_ops_ext.py",
         ":BufferizationTransformOpsPyGen",