diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -9,7 +9,7 @@
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e

-from typing import List, Optional, Sequence, Union, overload
+from typing import List, Optional, Sequence, Tuple, Union, overload

 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -319,6 +319,136 @@
   return [element for element in attr]


+
+class TileToForallOp:
+  """Specialization for TileToForallOp class."""
+  _MixedValues = Union[Sequence[Union[int, IntegerAttr, Operation, Value]],
+                       ArrayAttr, Operation, Value, OpView]
+
+  @overload
+  def __init__(
+      self,
+      tiled_op_type: Type,
+      loops_type: Type,
+      target: Union[Operation, Value, OpView],
+      num_threads: Optional[_MixedValues] = None,
+      *,
+      tile_sizes: _MixedValues = None,
+      mapping=None,
+      loc=None,
+      ip=None,
+  ):
+    ...
+
+  @overload
+  def __init__(
+      self,
+      loops_type: Type,
+      target: Union[Operation, Value, OpView],
+      num_threads: Optional[_MixedValues] = None,
+      *,
+      tile_sizes: _MixedValues = None,
+      mapping=None,
+      loc=None,
+      ip=None,
+  ):
+    ...
+
+  @overload
+  def __init__(
+      self,
+      target: Union[Operation, Value, OpView],
+      num_threads: Optional[_MixedValues] = None,
+      *,
+      tile_sizes: _MixedValues = None,
+      mapping=None,
+      loc=None,
+      ip=None,
+  ):
+    ...
+
+  def __init__(
+      self,
+      arg0: Union[Type,  # tiled_op_type/loops_type
+                  Union[Operation, Value, OpView]],  # target
+      arg1: Union[Type,  # loops_type
+                  Union[Operation, Value, OpView],  # target
+                  Optional[_MixedValues]] = None,  # num_threads
+      arg2: Optional[Union[
+          Union[Operation, Value, OpView],  # target
+          Optional[_MixedValues]]] = None,  # num_threads
+      arg3: Optional[_MixedValues] = None,  # num_threads
+      *,
+      tile_sizes: _MixedValues = None,
+      mapping=None,
+      loc=None,
+      ip=None,
+  ):
+    # `Type` arguments in the front are optional: add default values to front.
+    if isinstance(arg1, Type):
+      # First overload: both result types provided.
+      args = [arg0, arg1, arg2, arg3]
+    elif isinstance(arg0, Type):
+      # Second overload: tiled_op_type missing, defaults to the target's type.
+      target = arg1
+      tiled_op_type = target.type
+      args = [tiled_op_type, arg0, arg1, arg2]
+    else:
+      # Third overload: both types missing.
+      loops_type = transform.OperationType.get('scf.forall')
+      target = arg0
+      tiled_op_type = target.type
+      args = [tiled_op_type, loops_type, arg0, arg1]
+    (tiled_op_type, loops_type, target, num_threads) = args
+
+    # Unpack mixed num_threads.
+    (dynamic_num_threads,
+     packed_num_threads,
+     num_threads_attr) = self._dispatch_mixed_values(num_threads)
+
+    # Unpack mixed tile_sizes.
+    (dynamic_tile_sizes,
+     packed_tile_sizes,
+     tile_sizes_attr) = self._dispatch_mixed_values(tile_sizes)
+
+    super().__init__(
+        loops_type,
+        tiled_op_type,
+        target=target,
+        tile_sizes=dynamic_tile_sizes,
+        packed_tile_sizes=packed_tile_sizes,
+        static_tile_sizes=tile_sizes_attr,
+        num_threads=dynamic_num_threads,
+        packed_num_threads=packed_num_threads,
+        static_num_threads=num_threads_attr,
+        mapping=mapping,
+        loc=loc,
+        ip=ip,
+    )
+
+  @classmethod
+  def _dispatch_mixed_values(cls, values: _MixedValues) \
+      -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+    # Split a mixed list into dynamic SSA values, a single packed handle, or a
+    # static dense array attribute, depending on what was provided.
+    dynamic_values = []
+    packed_values = None
+    attr_value = None
+    if isinstance(values, ArrayAttr):
+      attr_value = values
+    elif isinstance(values, (Operation, Value, OpView)):
+      packed_values = values
+    else:
+      static_values = []
+      for size in values or []:
+        if isinstance(size, int):
+          static_values.append(size)
+        else:
+          static_values.append(ShapedType.get_dynamic_size())
+          dynamic_values.append(_get_op_result_or_value(size))
+      attr_value = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, attr_value)
+
+
 class VectorizeOp:
   """Specialization for VectorizeOp class."""

diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -223,6 +223,114 @@
 # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">


+@run
+def testTileToForallCompact():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [],
+      transform.OperationType.get('linalg.matmul')
+  )
+  with InsertionPoint(sequence.body):
+    structured.TileToForallOp(sequence.bodyTarget, [2, 3, 4])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallCompact
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+  # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
+
+
+@run
+def testTileToForallLoopsType():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    structured.TileToForallOp(
+        transform.AnyOpType.get(),  # loops_type
+        sequence.bodyTarget, [2, 3, 4],
+    )
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallLoopsType
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+  # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testTileToForallLoopsAndTileOpTypes():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    structured.TileToForallOp(
+        transform.OperationType.get('linalg.matmul'),  # tiled_op_type
+        transform.AnyOpType.get(),  # loops_type
+        sequence.bodyTarget, [2, 3, 4],
+    )
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+  # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"linalg.matmul">)
+
+
+@run
+def testTileToForallTileSizes():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallTileSizes
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
+
+
+@run
+def testTileToForallMixedDynamic():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    sz = structured.MatchOp.match_op_names(sequence.bodyTarget, ['test.dummy'])
+    structured.TileToForallOp(sequence.bodyTarget, [sz, 3, 4])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallMixedDynamic
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads [%{{.*}} : !pdl.operation, 3, 4]
+
+
+@run
+def testTileToForallPackedDynamic():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    sz = structured.MatchOp.match_op_names(sequence.bodyTarget, ['test.dummy'])
+    structured.TileToForallOp(sequence.bodyTarget, sz)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallPackedDynamic
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: num_threads *(%0 : !pdl.operation)
+
+
+@run
+def testTileToForallMapping():
+  sequence = transform.SequenceOp(
+      transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+  )
+  with InsertionPoint(sequence.body):
+    mapping = Attribute.parse('[ #gpu.thread<y>, #gpu.thread<x> ]')
+    structured.TileToForallOp(
+        sequence.bodyTarget, [2, 3],
+        mapping=mapping
+    )
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileToForallMapping
+  # CHECK: = transform.structured.tile_to_forall_op
+  # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
+
+
 @run
 def testVectorize():
   sequence = transform.SequenceOp(