diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -9,7 +9,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Union, overload
+from typing import List, Optional, Sequence, Tuple, Union, overload
 
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -17,6 +17,50 @@
 BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
 OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
 
+MixedValues = Union[Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
+                    ArrayAttr, Operation, Value, OpView]
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+# - `dynamic_values`: a list of `Value`s, potentially from op results;
+# - `packed_values`: a value handle, potentially from an op result, associated
+#   to one or more payload operations of integer type;
+# - `static_values`: a `DenseI64ArrayAttr` with static values, from Python
+#   `int`s, from `IntegerAttr`s, or from the elements of an `ArrayAttr`.
+# If the input is in the form for `packed_values`, only that result is set and
+# the other two are empty (`static_values` is `None`). Otherwise, the input can
+# be a mix of the other two forms, and for each dynamic value, the placeholder
+# `ShapedType.get_dynamic_size()` is added to the `static_values`.
+def _dispatch_mixed_values(
+    values: MixedValues,
+) -> Tuple[List[Value], Optional[Union[Operation, Value, OpView]],
+           Optional[DenseI64ArrayAttr]]:
+    dynamic_values = []
+    packed_values = None
+    static_values = None
+    if isinstance(values, ArrayAttr):
+        # Produce the same `DenseI64ArrayAttr` as the sequence branch below
+        # instead of forwarding the generic `ArrayAttr` (elements: `IntegerAttr`).
+        static_values = DenseI64ArrayAttr.get(
+            [IntegerAttr(element).value for element in values])
+    elif isinstance(values, (Operation, Value, OpView)):
+        packed_values = values
+    else:
+        static_values = []
+        for size in values if values is not None else []:
+            if isinstance(size, int):
+                static_values.append(size)
+            elif isinstance(size, IntegerAttr):
+                # A static value wrapped in an attribute, not a dynamic handle.
+                static_values.append(size.value)
+            else:
+                # Dynamic position: placeholder in the static list, value aside.
+                static_values.append(ShapedType.get_dynamic_size())
+                dynamic_values.append(_get_op_result_or_value(size))
+        static_values = DenseI64ArrayAttr.get(static_values)
+
+    return (dynamic_values, packed_values, static_values)
+
 
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
@@ -319,6 +363,110 @@
     return [element for element in attr]
 
 
+
+class TileToForallOp:
+    """Specialization for TileToForallOp class."""
+
+    @overload
+    def __init__(
+        self,
+        loops_type: Type,
+        tiled_op_type: Type,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loops_type_or_target: Union[
+            Type,  # loops_type
+            Union[Operation, Value, OpView]],  # target
+        tiled_op_type_or_none: Optional[Type] = None,
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        """Builds the op from `(loops_type, tiled_op_type, target)` or `target`.
+
+        When the type arguments are omitted, both result types default to
+        `transform.AnyOpType`. `num_threads` and `tile_sizes` are unpacked by
+        `_dispatch_mixed_values` and may each be a sequence mixing static
+        integers with dynamic handles, an `ArrayAttr`, a single packed handle,
+        or `None` (empty). `mapping`, `loc` and `ip` are forwarded unchanged.
+
+        Raises:
+          TypeError: if the positional arguments match neither overload.
+        """
+        # `Type` arguments in the front are optional: add default values to front.
+        if isinstance(loops_type_or_target, Type):
+            # First overload: type arguments provided.
+            if not isinstance(tiled_op_type_or_none, Type):
+                raise TypeError("If 'loops_type_or_target' is a type, then "
+                                "'tiled_op_type_or_none' is expected to be one as well.")
+            if target_or_none is None:
+                raise TypeError("If 'loops_type_or_target' is a type, then "
+                                "'target_or_none' is expected to be the target handle.")
+            loops_type = loops_type_or_target
+            tiled_op_type = tiled_op_type_or_none
+            target = target_or_none
+        else:
+            # Second overload: type arguments missing, so the only positional
+            # argument is the target; extra ones are an error, not ignored.
+            if tiled_op_type_or_none is not None or target_or_none is not None:
+                raise TypeError("If 'loops_type_or_target' is the target, then no "
+                                "further positional arguments are expected.")
+            loops_type = transform.AnyOpType.get()
+            tiled_op_type = transform.AnyOpType.get()
+            target = loops_type_or_target
+
+        # Unpack mixed num_threads.
+        (dynamic_num_threads,
+         packed_num_threads,
+         num_threads_attr) = _dispatch_mixed_values(num_threads)
+
+        # Unpack mixed tile_sizes.
+        (dynamic_tile_sizes,
+         packed_tile_sizes,
+         tile_sizes_attr) = _dispatch_mixed_values(tile_sizes)
+
+        super().__init__(
+            loops_type,
+            tiled_op_type,
+            target=target,
+            tile_sizes=dynamic_tile_sizes,
+            packed_tile_sizes=packed_tile_sizes,
+            static_tile_sizes=tile_sizes_attr,
+            num_threads=dynamic_num_threads,
+            packed_num_threads=packed_num_threads,
+            static_num_threads=num_threads_attr,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+
 
 class VectorizeOp:
     """Specialization for VectorizeOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -223,6 +223,97 @@
     # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
 
 
+@run
+def testTileToForallCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [],
+        transform.OperationType.get('linalg.matmul')
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileToForallOp(sequence.bodyTarget, num_threads=[2, 3, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileToForallCompact
+    # CHECK: = transform.structured.tile_to_forall_op
+    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
+    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
+
+ +@run +def testTileToForallLoopsAndTileOpTypes(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.TileToForallOp( + transform.OperationType.get('scf.forall'), # loops_type + transform.OperationType.get('linalg.matmul'), # tiled_op_type + sequence.bodyTarget, num_threads=[2, 3, 4], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes + # CHECK: = transform.structured.tile_to_forall_op + # CHECK-SAME: num_threads [2, 3, 4] tile_sizes [] + # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">) + + +@run +def testTileToForallTileSizes(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.TileToForallOp(sequence.bodyTarget, tile_sizes=[2, 3, 4]) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileToForallTileSizes + # CHECK: = transform.structured.tile_to_forall_op + # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4] + + +@run +def testTileToForallMixedDynamic(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + sz = structured.MatchOp.match_op_names(sequence.bodyTarget, ['test.dummy']) + structured.TileToForallOp(sequence.bodyTarget, num_threads=[sz, 3, 4]) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileToForallMixedDynamic + # CHECK: = transform.structured.tile_to_forall_op + # CHECK-SAME: num_threads [%{{.*}} : !pdl.operation, 3, 4] + + +@run +def testTileToForallMPackedDynamic(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + n = structured.MatchOp.match_op_names(sequence.bodyTarget, ['test.dummy']) + 
structured.TileToForallOp(sequence.bodyTarget, num_threads=n) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileToForallMPackedDynamic + # CHECK: = transform.structured.tile_to_forall_op + # CHECK-SAME: num_threads *(%0 : !pdl.operation) + + +@run +def testTileToForallMapping(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + mapping = Attribute.parse('[ #gpu.thread, #gpu.thread ]') + structured.TileToForallOp( + sequence.bodyTarget, num_threads=[2, 3], + mapping=mapping + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testTileToForallMapping + # CHECK: = transform.structured.tile_to_forall_op + # CHECK-SAME: mapping = [#gpu.thread, #gpu.thread] + + @run def testVectorize(): sequence = transform.SequenceOp(