diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -571,107 +571,84 @@
 
 
 class TileOp:
     """Specialization for TileOp class."""
 
     @overload
     def __init__(
         self,
         loop_types: Union[Type, List[Type]],
         target: Union[Operation, Value],
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
         ...
 
     @overload
     def __init__(
         self,
         target: Union[Operation, Value, OpView],
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
         ...
 
     def __init__(
         self,
         loop_types_or_target: Union[Type, List[Type], Operation, Value],
         target_or_none: Optional[Union[Operation, Value, OpView]] = None,
         *,
-        sizes: Optional[
-            Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
-        ] = None,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         interchange: OptionalIntList = None,
-        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
         if interchange is None:
             interchange = []
         if sizes is None:
             sizes = []
 
-        static_sizes = []
-        dynamic_sizes = []
-        if isinstance(sizes, ArrayAttr):
-            sizes_attr = sizes
-        else:
-            for size in sizes:
-                if isinstance(size, int):
-                    static_sizes.append(size)
-                else:
-                    static_sizes.append(ShapedType.get_dynamic_size())
-                    dynamic_sizes.append(_get_op_result_or_value(size))
-            sizes_attr = DenseI64ArrayAttr.get(static_sizes)
+        # Sizes may mix static integers, dynamic values, and scalable entries
+        # written as one-element lists (e.g. `sizes=[4, [2]]`).
+        (
+            dynamic_sizes,
+            static_sizes,
+            scalable_sizes,
+        ) = _dispatch_dynamic_index_list(sizes)
 
-        num_loops = sum(
-            v if v == 0 else 1 for v in self.__extract_values(sizes_attr)
-        )
-        if scalable_sizes is None:
-            scalable_sizes = [False] * len(self.__extract_values(sizes_attr))
+        num_loops = sum(v if v == 0 else 1 for v in static_sizes)
 
         if isinstance(loop_types_or_target, (Operation, Value, OpView)):
             loop_types = [transform.AnyOpType.get()] * num_loops
             target = loop_types_or_target
             assert target_or_none is None, "Cannot construct TileOp with two targets."
         else:
             loop_types = (
                 ([loop_types_or_target] * num_loops)
                 if isinstance(loop_types_or_target, Type)
                 else loop_types_or_target
             )
             target = target_or_none
 
         target = _get_op_result_or_value(target)
 
         super().__init__(
             target.type,
             loop_types,
             target,
             dynamic_sizes=dynamic_sizes,
-            static_sizes=sizes_attr,
+            static_sizes=static_sizes,
             interchange=interchange,
             scalable_sizes=scalable_sizes,
             loc=loc,
             ip=ip,
         )
 
-    def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
-        if not attr:
-            return []
-        return [element for element in attr]
-
 
 class TileToForallOp:
     """Specialization for TileToForallOp class."""
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -486,6 +486,22 @@
     # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
 
 
+@run
+def testTileScalable():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.TileOp(
+            sequence.bodyTarget,
+            sizes=[4, [2]],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testTileScalable
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
+
+
 @run
 def testTileToForallCompact():
     sequence = transform.SequenceOp(