diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -195,7 +195,7 @@ def match_op_names( cls, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -208,7 +208,7 @@ cls, result_type: Type, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -219,8 +219,8 @@ def match_op_names( cls, result_type_or_target: Union[Type, Operation, Value], - target_or_names: Union[Operation, Value, Sequence[str]], - names_or_none: Optional[Sequence[str]] = None, + target_or_names: Union[Operation, Value, Sequence[str], str], + names_or_none: Optional[Union[Sequence[str], str]] = None, *, loc=None, ip=None, @@ -234,6 +234,9 @@ target = result_type_or_target names = target_or_names + if isinstance(names, str): + names = [names] + return cls( result_type, _get_op_result_or_value(target), diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -97,14 +97,28 @@ @run -def testMatchOpNames(): +def testMatchOpNamesString(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy") + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchOpNamesString + # CHECK: transform.structured.match ops + # CHECK-SAME: ["test.dummy"] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +def testMatchOpNamesList(): sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) transform.YieldOp() - # CHECK-LABEL: TEST: testMatchOpNames + # CHECK-LABEL: TEST: testMatchOpNamesList # CHECK: transform.structured.match ops # CHECK-SAME: ["test.dummy"] # CHECK-SAME: (!transform.any_op) -> !transform.any_op