diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -84,6 +84,62 @@
     return ArrayAttr.get(values)
 
 
+class BufferizeToAllocationOp:
+    """Specialization for BufferizeToAllocationOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        # `Union[...]` rather than `int | str | ...`: PEP 604 unions in
+        # annotations fail at import time on Python < 3.10.
+        memory_space: Optional[Union[int, str, Attribute]] = None,
+        memcpy_op: Optional[str] = None,
+        alloc_op: Optional[str] = None,
+        bufferize_destination_only: Optional[bool] = None,
+        loc=None,
+        ip=None,
+    ):
+        """Builds a `transform.structured.bufferize_to_allocation` op.
+
+        Args:
+          target: The transform handle (op, op view or value) to operate on.
+          memory_space: Memory space for the allocation. An `int` or `str` is
+            converted with `Attribute.parse`; an `Attribute` is passed
+            through unchanged.
+          memcpy_op: Name of the copy op to use, e.g. "memref.copy".
+          alloc_op: Name of the allocation op to use, e.g. "memref.alloca".
+          bufferize_destination_only: Forwarded as the op's
+            `bufferize_destination_only` attribute.
+          loc: Optional location for the created op.
+          ip: Optional insertion point for the created op.
+        """
+        # No other types are allowed, so hard-code those here.
+        allocated_buffer_type = transform.AnyValueType.get()
+        new_ops_type = transform.AnyOpType.get()
+
+        # Normalize `memory_space` to an `Attribute`. An int is stringified
+        # first so that it parses as an integer attribute (e.g. `3`).
+        if isinstance(memory_space, int):
+            memory_space = str(memory_space)
+        if isinstance(memory_space, str):
+            memory_space = Attribute.parse(memory_space)
+
+        super().__init__(
+            allocated_buffer_type,
+            new_ops_type,
+            target,
+            memory_space=memory_space,
+            memcpy_op=memcpy_op,
+            alloc_op=alloc_op,
+            bufferize_destination_only=bufferize_destination_only,
+            # Forward these explicitly; otherwise the caller's location and
+            # insertion point would be silently dropped.
+            loc=loc,
+            ip=ip,
+        )
+
+
 class DecomposeOp:
     """Specialization for DecomposeOp class."""
 
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -18,6 +18,42 @@
     return f
 
 
+@run
+def testBufferizeToAllocationOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.BufferizeToAllocationOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.bufferize_to_allocation
+
+
+@run
+def testBufferizeToAllocationOpArgs():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()
+    )
+    with InsertionPoint(sequence.body):
+        structured.BufferizeToAllocationOp(
+            sequence.bodyTarget,
+            memory_space=3,
+            memcpy_op="memref.copy",
+            alloc_op="memref.alloca",
+            bufferize_destination_only=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
+    # CHECK: transform.sequence
+    # CHECK: transform.structured.bufferize_to_allocation
+    # CHECK-SAME: alloc_op = "memref.alloca"
+    # CHECK-SAME: bufferize_destination_only
+    # CHECK-SAME: memcpy_op = "memref.copy"
+    # CHECK-SAME: memory_space = 3
+
+
 @run
 def testDecompose():
     sequence = transform.SequenceOp(