diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -11,19 +11,89 @@
 from typing import List, Optional, Sequence, Tuple, Union, overload
 
 
+StaticIntLike = Union[int, IntegerAttr]
+ValueLike = Union[Operation, OpView, Value]
+MixedInt = Union[StaticIntLike, ValueLike]
+
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
 BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
 OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
 
-MixedValues = Union[
-    Sequence[Union[int, IntegerAttr, Operation, Value, OpView]],
-    ArrayAttr,
-    Operation,
-    Value,
-    OpView,
-]
+MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+    indices: Union[DynamicIndexList, ArrayAttr],
+) -> Tuple[List[ValueLike], List[int], List[bool]]:
+    """Dispatches a list of indices to the appropriate form.
+
+    This is similar to the custom `DynamicIndexList` directive upstream:
+    provided indices may be in the form of dynamic SSA values or static values,
+    and they may be scalable (i.e., as a singleton list) or not. This function
+    dispatches each index into its respective form. It also extracts the SSA
+    values and static indices from various similar structures, respectively.
+
+    Returns:
+      A tuple `(dynamic_indices, static_indices, scalable_indices)`:
+      `dynamic_indices` holds the SSA values in order of appearance;
+      `static_indices` holds one entry per index, namely its static value or
+      `ShapedType.get_dynamic_size()` if the index is dynamic; and
+      `scalable_indices` flags the indices that were given as singleton lists.
+
+    Raises:
+      TypeError: if an index is not an int, an `IntegerAttr`, an SSA value,
+        or a singleton sequence holding one of those.
+    """
+    dynamic_indices = []
+    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+    scalable_indices = [False] * len(indices)
+
+    # ArrayAttr: Extract index values.
+    if isinstance(indices, ArrayAttr):
+        indices = list(indices)
+
+    def process_nonscalable_index(i, index):
+        """Processes any form of non-scalable index.
+
+        Returns False if the given index is not a recognized non-scalable
+        form (e.g., because it is scalable) and thus remains unprocessed;
+        True otherwise.
+        """
+        if isinstance(index, int):
+            static_indices[i] = index
+        elif isinstance(index, IntegerAttr):
+            static_indices[i] = index.value  # pytype: disable=attribute-error
+        elif isinstance(index, (Operation, Value, OpView)):
+            dynamic_indices.append(index)
+        else:
+            return False
+        return True
+
+    # Process one index at a time.
+    for i, index in enumerate(indices):
+        if process_nonscalable_index(i, index):
+            continue
+        # If it wasn't processed, it must be a scalable index, which is
+        # provided as a Sequence of one value, so extract and process that.
+        # Malformed input raises instead of using `assert`, which would be
+        # stripped under `python -O` and silently yield inconsistent lists.
+        try:
+            (scalable_index,) = index
+            is_valid = process_nonscalable_index(i, scalable_index)
+        except (TypeError, ValueError):
+            is_valid = False
+        if not is_valid:
+            raise TypeError(
+                f"Index #{i} must be an int, IntegerAttr, SSA value, or a "
+                f"singleton sequence of one of those; got {index!r}."
+            )
+        scalable_indices[i] = True
+
+    return dynamic_indices, static_indices, scalable_indices
 
 
 # Dispatches `MixedValues` that all represents integers in various forms into
@@ -281,6 +351,60 @@
         )
 
 
+class MaskedVectorizeOp:
+    """Specialization for MaskedVectorizeOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        vector_sizes: Union[DynamicIndexList, ArrayAttr],
+        *,
+        vectorize_nd_extract: Optional[bool] = None,
+        scalable_sizes: OptionalBoolList = None,
+        static_vector_sizes: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        """Builds a `transform.structured.masked_vectorize` op on `target`.
+
+        `vector_sizes` may be an `ArrayAttr` or a list whose entries are
+        static sizes (int or IntegerAttr), dynamic sizes (SSA values), or
+        scalable sizes given as a singleton list of either, e.g.
+        `[16, [sz], [4]]`; it is then split into its dynamic, static, and
+        scalable parts. Alternatively, `scalable_sizes` and
+        `static_vector_sizes` may both be given explicitly, in which case
+        `vector_sizes` is forwarded unchanged as the dynamic sizes.
+
+        Raises:
+          TypeError: if exactly one of `scalable_sizes` and
+            `static_vector_sizes` is given, or if an entry of `vector_sizes`
+            has an unsupported form.
+        """
+        if scalable_sizes is None and static_vector_sizes is None:
+            (
+                dynamic_vector_sizes,
+                static_vector_sizes,
+                scalable_sizes,
+            ) = _dispatch_dynamic_index_list(vector_sizes)
+        elif scalable_sizes is None or static_vector_sizes is None:
+            raise TypeError(
+                "'scalable_sizes' and 'static_vector_sizes' must either both "
+                "be given explicitly or both be given as part of 'vector_sizes'."
+            )
+        else:
+            dynamic_vector_sizes = vector_sizes
+
+        super().__init__(
+            target,
+            vector_sizes=dynamic_vector_sizes,
+            static_vector_sizes=static_vector_sizes,
+            scalable_sizes=scalable_sizes,
+            vectorize_nd_extract=vectorize_nd_extract,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MatchOp:
     """Specialization for MatchOp class."""
 
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -199,6 +199,85 @@
     # CHECK-SAME: (!transform.any_op) -> !transform.any_op
 
 
+@run
+def testMaskedVectorizeStatic():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, 4])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeStatic
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME: vector_sizes [16, 4]
+
+
+@run
+def testMaskedVectorizeArray():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sizes = Attribute.parse("[16, 4]")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, sizes)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeArray
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME: vector_sizes [16, 4]
+
+
+@run
+def testMaskedVectorizeMixed():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+        sz2 = Attribute.parse("4")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [sz1, sz2])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeMixed
+    # CHECK: transform.sequence
+    # CHECK: %[[V0:.*]] = transform.structured.match
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME: vector_sizes [%[[V0]] : !transform.any_op, 4]
+
+
+@run
+def testMaskedVectorizeScalable():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        sz1 = structured.MatchOp.match_op_names(sequence.bodyTarget, ["arith.constant"])
+        sz2 = Attribute.parse("4")
+        structured.MaskedVectorizeOp(sequence.bodyTarget, [16, [sz1], [sz2], [8]])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeScalable
+    # CHECK: transform.sequence
+    # CHECK-DAG: %[[V0:.*]] = transform.structured.match
+    # CHECK-DAG: transform.structured.masked_vectorize
+    # CHECK-SAME: vector_sizes [16, [%[[V0]] : !transform.any_op], [4], [8]]
+
+
+@run
+def testMaskedVectorizeArgs():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MaskedVectorizeOp(
+            sequence.bodyTarget, [16, 4], vectorize_nd_extract=True
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMaskedVectorizeArgs
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.masked_vectorize
+    # CHECK-SAME: vectorize_nd_extract
+
+
 @run
 def testMatchOpNamesTyped():
     sequence = transform.SequenceOp(