# Review notes (text before the first "diff --git" header is ignored by patch tools):
# Purpose: make optional array attributes of the structured transform PadOp truly
# optional (omitted from the op) instead of materializing empty ArrayAttrs, and split
# the PadOp test into a no-argument case and an all-arguments case.
# NOTE(review): this patch text has had its newlines collapsed; it cannot be applied
# with `git apply` as-is and must be regenerated from the source tree.
# NOTE(review): _get_int_array_attr now returns None for None input, but its signature
# (see the @@ -125 hunk header) is still annotated `-> ArrayAttr`; it should become
# `-> Optional[ArrayAttr]`.
# NOTE(review): the @@ -148 hunk keeps the docstring line "If the input is None, an
# empty ArrayAttr is returned." while the code now returns None; update that docstring
# (and the function's return annotation, not visible here) in the same patch.
# NOTE(review): PadOp.__init__ no longer defaults padding_values to []; the code that
# converts padding_values is not visible in this diff. Confirm that a None value is
# passed through as "attribute absent" rather than reaching e.g. ArrayAttr.get(None).
# NOTE(review): the padding_values type hint drops bool/int/float; presumably only
# Attribute elements were ever accepted at runtime, but verify that no caller relied on
# passing Python scalars.
# NOTE(review): testPadOpArgs now uses StringAttr; confirm it is imported in the test
# module (the import section is not part of this diff).
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -125,7 +125,7 @@ def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: if values is None: - return ArrayAttr.get([]) + return None # Turn into a Python list of Python ints. values = _get_value_list(values) @@ -148,7 +148,7 @@ If the input is None, an empty ArrayAttr is returned. """ if values is None: - return ArrayAttr.get([]) + return None # Make sure the outer level is a list. values = _get_value_list(values) @@ -493,9 +493,7 @@ self, target: Union[Operation, OpView, Value], *, - padding_values: Optional[ - Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]] - ] = None, + padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, pad_to_multiple_of: OptionalIntList = None, pack_paddings: OptionalIntList = None, @@ -506,17 +504,6 @@ loc=None, ip=None, ): - if padding_values is None: - padding_values = [] - if padding_dimensions is None: - padding_dimensions = [] - if pad_to_multiple_of is None: - pad_to_multiple_of = [] - if pack_paddings is None: - pack_paddings = [] - if transpose_paddings is None: - transpose_paddings = [] - padding_dimensions = _get_int_array_attr(padding_dimensions) pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of) pack_paddings = _get_int_array_attr(pack_paddings) diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -314,14 +314,33 @@ @run -def testPad(): +def testPadOpNoArgs(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], 
pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.PadOp(sequence.bodyTarget) + transform.YieldOp() + # CHECK-LABEL: TEST: testPadOpNoArgs + # CHECK: transform.sequence + # CHECK: transform.structured.pad + # CHECK-NOT: copy_back_op + # CHECK-NOT: pack_paddings + # CHECK-NOT: pad_to_multiple_of + # CHECK-NOT: padding_dimensions + # CHECK-NOT: padding_values + # CHECK-NOT: transpose_paddings + + +@run +def testPadOpArgs(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): structured.PadOp( sequence.bodyTarget, - padding_values=[FloatAttr.get_f32(42.0)], + padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")], padding_dimensions=Attribute.parse("[1]"), pad_to_multiple_of=[128], pack_paddings=[0], @@ -329,7 +348,7 @@ copy_back_op="linalg.copy", ) transform.YieldOp() - # CHECK-LABEL: TEST: testPad + # CHECK-LABEL: TEST: testPadOpArgs # CHECK: transform.sequence # CHECK: transform.structured.pad # CHECK-DAG: copy_back_op = "linalg.copy"