diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -93,6 +93,93 @@
         )
 
 
+class FuseIntoContainingOp:
+    """Specialization for FuseIntoContainingOp class.
+
+    Two call forms are accepted:
+
+    * ``FuseIntoContainingOp(fused_op_type, new_containing_op_type,
+      producer_op, containing_op)`` sets both result types explicitly.
+    * ``FuseIntoContainingOp(producer_op, containing_op)`` uses
+      ``transform.AnyOpType`` for both result types.
+    """
+
+    @overload
+    def __init__(
+        self,
+        fused_op_type: Type,
+        new_containing_op_type: Type,
+        producer_op: Union[Operation, OpView, Value],
+        containing_op: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        producer_op: Union[Operation, OpView, Value],
+        containing_op: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
+        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
+        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        # The call form is selected by whether the first argument is a Type.
+        if isinstance(fused_op_type_or_producer_op, Type):
+            if not isinstance(new_containing_op_type_or_containing_op, Type):
+                raise TypeError(
+                    "If 'fused_op_type_or_producer_op' is a type, then "
+                    "'new_containing_op_type_or_containing_op' is expected "
+                    "to be one as well."
+                )
+            # Without this check a missing operand would be forwarded to the
+            # generated builder as `None`.
+            if producer_op_or_none is None or containing_op_or_none is None:
+                raise TypeError(
+                    "If result types are provided, then 'producer_op_or_none' "
+                    "and 'containing_op_or_none' are required."
+                )
+            fused_op_type = fused_op_type_or_producer_op
+            new_containing_op_type = new_containing_op_type_or_containing_op
+            producer_op = producer_op_or_none
+            containing_op = containing_op_or_none
+        else:
+            # In the compact form the trailing operands must be absent;
+            # otherwise they would be silently ignored.
+            if producer_op_or_none is not None or containing_op_or_none is not None:
+                raise TypeError(
+                    "If result types are not provided, then only the producer "
+                    "and containing ops may be passed."
+                )
+            fused_op_type = transform.AnyOpType.get()
+            new_containing_op_type = transform.AnyOpType.get()
+            producer_op = fused_op_type_or_producer_op
+            containing_op = new_containing_op_type_or_containing_op
+
+        super().__init__(
+            fused_op_type,
+            new_containing_op_type,
+            producer_op,
+            containing_op,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class GeneralizeOp:
     """Specialization for GeneralizeOp class."""
 
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -30,6 +30,45 @@
     # CHECK: transform.structured.decompose
 
 
+@run
+def testFuseIntoContainingOpTypes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        containing = structured.MatchOp.match_op_names(
+            sequence.bodyTarget, ["test.dummy"]
+        )
+        structured.FuseIntoContainingOp(
+            transform.OperationType.get("test.dummy"),
+            transform.OperationType.get("test.dummy"),
+            fused,
+            containing,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
+    # CHECK: = transform.structured.fuse_into_containing_op
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
+
+
+@run
+def testFuseIntoContainingOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        fused = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
+        containing = structured.MatchOp.match_op_names(
+            sequence.bodyTarget, ["test.dummy"]
+        )
+        structured.FuseIntoContainingOp(fused, containing)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
+    # CHECK: = transform.structured.fuse_into_containing_op
+    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
 @run
 def testGeneralize():
     sequence = transform.SequenceOp(