diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -29,6 +29,49 @@
         )
 
 
+class ApplyPatternsOp:
+    """Specialization of the `apply_patterns` transform op.
+
+    The constructor creates the op and appends a block to its first region;
+    pattern ops are inserted into that block, which is exposed through the
+    `patterns` property.
+    """
+
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Build the op on `target` and append a block to its first region.
+
+        Args:
+          target: Operation, op view or value used as the op's only operand.
+          loc: Location forwarded to `build_generic`.
+          ip: Insertion point forwarded to `build_generic`.
+        """
+        super().__init__(
+            self.build_generic(
+                attributes={},
+                results=[],
+                operands=[_get_op_result_or_value(target)],
+                successors=None,
+                regions=None,
+                loc=loc,
+                ip=ip,
+            )
+        )
+        # Add the block that `patterns` exposes, so callers can insert
+        # pattern ops right after construction.
+        self.regions[0].blocks.append()
+
+    @property
+    def patterns(self) -> Block:
+        """Return the first block of the op's first region."""
+        return self.regions[0].blocks[0]
+
+
 class testGetParentOp:
 
     def __init__(
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -171,6 +171,39 @@
     # CHECK: = merge_handles %[[ARG1]]
 
 
+@run
+def testApplyPatternsOpCompact():
+    # Build apply_patterns on an any_op handle and fill its patterns block.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+            transform.ApplyCanonicalizationPatternsOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
+    # CHECK: apply_patterns to
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: !transform.any_op
+
+
+@run
+def testApplyPatternsOpWithType():
+    # Build apply_patterns on a handle typed as !transform.op<"test.dummy">.
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [],
+        transform.OperationType.get('test.dummy')
+    )
+    with InsertionPoint(sequence.body):
+        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
+            transform.ApplyCanonicalizationPatternsOp()
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyPatternsOpWithType
+    # CHECK: apply_patterns to
+    # CHECK: transform.apply_patterns.canonicalization
+    # CHECK: !transform.op<"test.dummy">
+
+
 @run
 def testReplicateOp():
     with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())