diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -85,17 +85,52 @@ class MatchOp: """Specialization for MatchOp class.""" + @overload @classmethod def match_op_names( - MatchOp, + cls, target: Union[Operation, Value], names: Sequence[str], + *, loc=None, ip=None, ): - pdl_operation_type = pdl.OperationType.get() - return MatchOp( - pdl_operation_type, + ... + + @overload + @classmethod + def match_op_names( + cls, + result_type: Type, + target: Union[Operation, Value], + names: Sequence[str], + *, + loc=None, + ip=None, + ): + ... + + @classmethod + def match_op_names( + cls, + result_type_or_target: Union[Type, Operation, Value], + target_or_names: Union[Operation, Value, Sequence[str]], + names_or_none: Optional[Sequence[str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_names + names = names_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names + + return cls( + result_type, _get_op_result_or_value(target), ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc,