diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -187,6 +187,66 @@
     )
 
 
+class MapCopyToThreadsOp:
+    """Specialization for MapCopyToThreadsOp class."""
+
+    @overload
+    def __init__(
+        self,
+        forall_op_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, OpView, Value],
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
+        tiled_op_type_or_none: Optional[Type] = None,
+        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        total_num_threads: Union[int, IntegerAttr],
+        desired_bit_alignment: Union[int, IntegerAttr],
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(forall_op_type_or_target, Type):
+            forall_op_type = forall_op_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none
+        else:
+            forall_op_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = forall_op_type_or_target
+
+        super().__init__(
+            forall_op_type,
+            tiled_op_type,
+            target,
+            total_num_threads=total_num_threads,
+            desired_bit_alignment=desired_bit_alignment,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MatchOp:
     """Specialization for MatchOp class."""

diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -97,6 +97,44 @@
     # CHECK: iterator_interchange = [1, 0]
 
 
+@run
+def testMapCopyToThreadsOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MapCopyToThreadsOp(
+            sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
+    # CHECK: = transform.structured.gpu.map_copy_to_threads
+    # CHECK-SAME: total_num_threads = 32
+    # CHECK-SAME: desired_bit_alignment = 128
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMapCopyToThreadsOpTypes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MapCopyToThreadsOp(
+            transform.OperationType.get("test.opA"),
+            transform.OperationType.get("test.opB"),
+            sequence.bodyTarget,
+            total_num_threads=32,
+            desired_bit_alignment=128,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
+    # CHECK: = transform.structured.gpu.map_copy_to_threads
+    # CHECK-SAME: total_num_threads = 32
+    # CHECK-SAME: desired_bit_alignment = 128
+    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
+
+
 @run
 def testMatchOpNamesString():
     sequence = transform.SequenceOp(
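
Not part of the patch: a minimal standalone sketch of how the new specialization could be driven outside the test harness, mirroring testMapCopyToThreadsOpCompact above. The mixed __init__ dispatches on the type of its first positional argument: when it is a transform Type, the caller supplies both result types explicitly; otherwise both results default to transform.AnyOpType. The Context/Module boilerplate below stands in for the test file's run decorator and is an assumption about the surrounding setup.

    # Illustrative usage sketch only; it mirrors the tests added in this patch
    # rather than introducing new behavior.
    from mlir import ir
    from mlir.dialects import transform
    from mlir.dialects.transform import structured

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            sequence = transform.SequenceOp(
                transform.FailurePropagationMode.PROPAGATE,
                [],
                transform.AnyOpType.get(),
            )
            with ir.InsertionPoint(sequence.body):
                # Compact form: both results default to !transform.any_op.
                structured.MapCopyToThreadsOp(
                    sequence.bodyTarget,
                    total_num_threads=32,
                    desired_bit_alignment=128,
                )
                transform.YieldOp()
        # Prints the constructed transform IR, including
        # transform.structured.gpu.map_copy_to_threads.
        print(module)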