diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -85,17 +85,83 @@
 class MatchOp:
     """Specialization for MatchOp class."""
 
+    @overload
     @classmethod
     def match_op_names(
-        MatchOp,
+        cls,
         target: Union[Operation, Value],
         names: Sequence[str],
+        *,
         loc=None,
         ip=None,
     ):
-        pdl_operation_type = pdl.OperationType.get()
-        return MatchOp(
-            pdl_operation_type,
+        ...
+
+    @overload
+    @classmethod
+    def match_op_names(
+        cls,
+        result_type: Type,
+        target: Union[Operation, Value],
+        names: Sequence[str],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @classmethod
+    def match_op_names(
+        cls,
+        result_type_or_target: Union[Type, Operation, Value],
+        target_or_names: Union[Operation, Value, Sequence[str]],
+        names_or_none: Optional[Sequence[str]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Create a MatchOp that matches payload ops by name.
+
+        Two positional calling conventions are supported:
+          * ``match_op_names(target, names)``: the result handle has type
+            ``!transform.any_op``;
+          * ``match_op_names(result_type, target, names)``: the result handle
+            has the given ``result_type``.
+
+        ``target`` and ``names`` must be passed positionally; ``loc`` and
+        ``ip`` are keyword-only. A single op name may be given as a ``str``.
+
+        Raises:
+          TypeError: if the arguments match neither calling convention.
+        """
+        if isinstance(result_type_or_target, Type):
+            if names_or_none is None:
+                raise TypeError(
+                    "match_op_names(result_type, target, names): "
+                    "'names' is required when a result type is given"
+                )
+            result_type = result_type_or_target
+            target = target_or_names
+            names = names_or_none
+        else:
+            # Without a leading result Type only (target, names) is valid; a
+            # third positional argument would otherwise be silently dropped
+            # (e.g. an old-style positional `loc`).
+            if names_or_none is not None:
+                raise TypeError(
+                    "match_op_names(target, names): unexpected third positional "
+                    "argument; pass a result Type first or loc/ip by keyword"
+                )
+            result_type = transform.AnyOpType.get()
+            target = result_type_or_target
+            names = target_or_names
+        # A str is itself a sequence of characters: treat it as a single op
+        # name instead of matching one op per character.
+        if isinstance(names, str):
+            names = [names]
+
+        return cls(
+            result_type,
             _get_op_result_or_value(target),
             ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
             loc=loc,
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -57,6 +57,51 @@
     # CHECK: iterator_interchange = [1, 0]
 
 
+@run
+def testMatchOpNames():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNames
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+def testMatchOpNamesTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(
+            transform.OperationType.get("test.dummy"),
+            sequence.bodyTarget,
+            ["test.dummy"],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNamesTyped
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
+
+
+@run
+def testMatchOpNamesSingleString():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy")
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMatchOpNamesSingleString
+    # CHECK: transform.structured.match ops
+    # CHECK-SAME: ["test.dummy"]
+
+
 @run
 def testMultitileSizes():
     sequence = transform.SequenceOp(