NOTE(review): this copy of the patch has had its line breaks collapsed (each
file's hunks run together on a single line), so it cannot be applied with
git apply / patch as-is. Regenerate it from the source commit before use.

Summary. Text placed before the first file header is skipped by git apply.

_structured_transform_ops_ext.py, VectorizeOp constructor:
  * Adds three keyword-only flags (they follow the bare `*`), each defaulting
    to False: disable_multi_reduction_to_contract_patterns,
    disable_transfer_permutation_map_lowering_patterns and
    vectorize_nd_extract. Each is forwarded by keyword to super().__init__.
  * Narrows vectorize_padding from Union[bool, BoolAttr] to bool and forwards
    it unchanged.
  * Removes the conversion `if isinstance(vectorize_padding, bool):
    vectorize_padding = UnitAttr.get()`. That check matched False as well as
    True. As a result, the old code handed a UnitAttr to the builder on every
    call that used a bool, including the default of False.
  * Assumes the generated base builder accepts plain Python bools for these
    unit attributes and sets an attribute only when its value is True.
    TODO: confirm against the generated _structured_transform_ops_gen.py.
  * Callers that previously passed a BoolAttr no longer match the annotation.
    Verify how the builder treats such values.

transform_structured_ext.py:
  * testVectorize is renamed to testVectorizeAllAttrs. It sets all four flags
    to True and expects each attribute name via CHECK-SAME, in the order
    listed. Presumably this is the printer's attribute order; confirm.
  * New testVectorizeNoAttrs sets all four flags to False. It uses CHECK-NOT
    to assert that none of the attribute names is printed. These CHECK-NOT
    lines are the test's last directives, so their search range is presumably
    bounded by the next test's CHECK-LABEL; verify.

diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -783,16 +783,20 @@ self, target: Union[Operation, Value], *, - vectorize_padding: Union[bool, BoolAttr] = False, + disable_multi_reduction_to_contract_patterns: bool = False, + disable_transfer_permutation_map_lowering_patterns: bool = False, + vectorize_nd_extract: bool = False, + vectorize_padding: bool = False, loc=None, ip=None, ): pdl_operation_type = pdl.OperationType.get() - if isinstance(vectorize_padding, bool): - vectorize_padding = UnitAttr.get() super().__init__( pdl_operation_type, _get_op_result_or_value(target), + disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, + disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, + vectorize_nd_extract=vectorize_nd_extract, vectorize_padding=vectorize_padding, loc=loc, ip=ip, diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -560,17 +560,49 @@ @run -def testVectorize(): +def testVectorizeAllAttrs(): sequence = transform.SequenceOp( transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() ) with InsertionPoint(sequence.body): - structured.VectorizeOp(sequence.bodyTarget, vectorize_padding=True) + structured.VectorizeOp( + sequence.bodyTarget, + disable_multi_reduction_to_contract_patterns=True, + disable_transfer_permutation_map_lowering_patterns=True, + vectorize_nd_extract=True, + vectorize_padding=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testVectorizeAllAttrs + # CHECK: transform.sequence + # CHECK: = transform.structured.vectorize + # 
CHECK-SAME: disable_multi_reduction_to_contract_patterns + # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns + # CHECK-SAME: vectorize_nd_extract + # CHECK-SAME: vectorize_padding + + +@run +def testVectorizeNoAttrs(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() + ) + with InsertionPoint(sequence.body): + structured.VectorizeOp( + sequence.bodyTarget, + disable_multi_reduction_to_contract_patterns=False, + disable_transfer_permutation_map_lowering_patterns=False, + vectorize_nd_extract=False, + vectorize_padding=False, + ) transform.YieldOp() - # CHECK-LABEL: TEST: testVectorize + # CHECK-LABEL: TEST: testVectorizeNoAttrs # CHECK: transform.sequence # CHECK: = transform.structured.vectorize - # CHECK: {vectorize_padding} + # CHECK-NOT: disable_multi_reduction_to_contract_patterns + # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns + # CHECK-NOT: vectorize_nd_extract + # CHECK-NOT: vectorize_padding @run