diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -107,28 +107,60 @@ return (dynamic_values, packed_values, static_values) -def _get_int_int_array_attr( +def _get_value_or_attribute_value( + value_or_attr: Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: Union[Sequence[any], ArrayAttr] +) -> Sequence[any]: + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: + if values is None: + return ArrayAttr.get([]) + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: - """Creates an array attribute containing array attributes of integers. + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, potentially interpserced, to - create a new array-of-array attribute. Expects the thread-local MLIR context - to have been set by the context manager. + The input has to be a collection of collection of integers, where any + Python Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. 
Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. """ if values is None: return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - if isinstance(values, list): - values = [ - ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value] - ) - for value in values - ] + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. return ArrayAttr.get(values) @@ -361,44 +393,55 @@ class PadOp: - """Specialization for PadOp class.""" + """Specialization for PadOp class.""" - def __init__( - self, - target: Union[Operation, Value], - *, - padding_values: Optional[ - Optional[Union[ArrayAttr, Sequence[Attribute]]] - ] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[ - Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] - ] = None, - loc=None, - ip=None, - ): - if transpose_paddings is None: - transpose_paddings = [] - if pack_paddings is None: - pack_paddings = [] - if padding_dimensions is None: - padding_dimensions = [] - if padding_values is None: - padding_values = [] - pdl_operation_type = pdl.OperationType.get() - transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) - super().__init__( - pdl_operation_type, - pdl_operation_type, - _get_op_result_or_value(target), - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings_attr, - loc=loc, - ip=ip, - ) + def __init__( + 
self, + target: Union[Operation, OpView, Value], + *, + padding_values: Optional[ + Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]] + ] = None, + padding_dimensions: OptionalIntList = None, + pad_to_multiple_of: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + copy_back_op: Optional[Union[str, StringAttr]] = None, + loc=None, + ip=None, + ): + if padding_values is None: + padding_values = [] + if padding_dimensions is None: + padding_dimensions = [] + if pad_to_multiple_of is None: + pad_to_multiple_of = [] + if pack_paddings is None: + pack_paddings = [] + if transpose_paddings is None: + transpose_paddings = [] + + padding_dimensions = _get_int_array_attr(padding_dimensions) + pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of) + pack_paddings = _get_int_array_attr(pack_paddings) + transpose_paddings = _get_int_array_array_attr(transpose_paddings) + + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, + pdl_operation_type, + target, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pad_to_multiple_of=pad_to_multiple_of, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings, + copy_back_op=copy_back_op, + loc=loc, + ip=ip, + ) class ScalarizeOp: diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -247,17 +247,22 @@ structured.PadOp( sequence.bodyTarget, padding_values=[FloatAttr.get_f32(42.0)], - padding_dimensions=[1], - transpose_paddings=[[1, 0]], + padding_dimensions=Attribute.parse("[1]"), + pad_to_multiple_of=[128], + pack_paddings=[0], + transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")], + copy_back_op="linalg.copy", ) 
transform.YieldOp() # CHECK-LABEL: TEST: testPad # CHECK: transform.sequence # CHECK: transform.structured.pad - # CHECK-DAG: padding_values = [4.200000e+01 : f32] + # CHECK-DAG: copy_back_op = "linalg.copy" + # CHECK-DAG: pack_paddings = [0] + # CHECK-DAG: pad_to_multiple_of = [128] # CHECK-DAG: padding_dimensions = [1] - # CHECK-DAG: transpose_paddings = {{\[}}[1, 0]] - # (pack_paddings has default values) + # CHECK-DAG: padding_values = [4.200000e+01 : f32] + # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]] @run