diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -87,8 +87,8 @@ result in a new allocation. It replaces all original uses of the target result with the newly allocated buffer, wrapped in a `bufferization.to_tensor` op. It returns a handle to the newly allocated - buffer. Furthermore, it returns a handle to the result of the `to_tensor` - op. + buffer. Furthermore, it returns a handle that is mapped to all newly created + ops. Only bufferizable ops are that bufferize to a memory write or have an aliasing OpOperand (and do not themselves bufferize to an allocation) are @@ -121,12 +121,13 @@ #### Return modes This operation consumes the `target` handle and produces the - `allocated_buffer` handle. It always succeeds. + `allocated_buffer` and `new_ops` handles. It always succeeds. }]; let arguments = (ins TransformHandleTypeInterface:$target, OptionalAttr:$memory_space); - let results = (outs Transform_AnyValue:$allocated_buffer); + let results = (outs Transform_AnyValue:$allocated_buffer, + Transform_AnyOpType:$new_ops); let assemblyFormat = "$target attr-dict `:` type($target)"; } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -444,7 +444,7 @@ /// struct can be used as a base to create listener chains, so that multiple /// listeners can be notified of IR changes. struct ForwardingListener : public RewriterBase::Listener { - ForwardingListener(Listener *listener) : listener(listener) {} + ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {} void notifyOperationInserted(Operation *op) override { listener->notifyOperationInserted(op); @@ -453,26 +453,32 @@ listener->notifyBlockCreated(block); } void notifyOperationModified(Operation *op) override { - listener->notifyOperationModified(op); + if (auto *rewriteListener = dyn_cast(listener)) + rewriteListener->notifyOperationModified(op); } void notifyOperationReplaced(Operation *op, Operation *newOp) override { - listener->notifyOperationReplaced(op, newOp); + if (auto *rewriteListener = dyn_cast(listener)) + rewriteListener->notifyOperationReplaced(op, newOp); } void notifyOperationReplaced(Operation *op, ValueRange replacement) override { - listener->notifyOperationReplaced(op, replacement); + if (auto *rewriteListener = dyn_cast(listener)) + rewriteListener->notifyOperationReplaced(op, replacement); } void notifyOperationRemoved(Operation *op) override { - listener->notifyOperationRemoved(op); + if (auto *rewriteListener = dyn_cast(listener)) + rewriteListener->notifyOperationRemoved(op); } LogicalResult notifyMatchFailure( Location loc, function_ref reasonCallback) override { - return listener->notifyMatchFailure(loc, reasonCallback); + if (auto *rewriteListener = dyn_cast(listener)) + return rewriteListener->notifyMatchFailure(loc, reasonCallback); + return failure(); } private: - Listener *listener; + OpBuilder::Listener *listener; }; /// Move the blocks that belong to "region" before the given position in diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -171,9 +171,43 @@ // BufferizeToAllocationOp //===----------------------------------------------------------------------===// +namespace { +class NewOpsListener : public RewriterBase::ForwardingListener { +public: + using RewriterBase::ForwardingListener::ForwardingListener; + + SmallVector getNewOps() const { + return SmallVector(newOps.begin(), newOps.end()); + } + +private: + void notifyOperationInserted(Operation *op) override { + ForwardingListener::notifyOperationInserted(op); + auto inserted = newOps.insert(op); + (void)inserted; + assert(inserted.second && "expected newly created op"); + } + + void notifyOperationRemoved(Operation *op) override { + ForwardingListener::notifyOperationRemoved(op); + op->walk([&](Operation *op) { newOps.erase(op); }); + } + + DenseSet newOps; +}; +} // namespace + DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { + // Attach listener to keep track of newly created ops. + OpBuilder::Listener *previousListener = rewriter.getListener(); + auto resetListener = + llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); }); + NewOpsListener newOpsListener(previousListener); + rewriter.setListener(&newOpsListener); + + // Bufferize ops. Attribute memorySpace = getMemorySpace().has_value() ? getMemorySpace().value() : Attribute(); SmallVector allocatedBuffers; @@ -187,7 +221,10 @@ } allocatedBuffers.push_back(buffer); } + + // Set results. results.setValues(cast(getAllocatedBuffer()), allocatedBuffers); + results.set(cast(getNewOps()), newOpsListener.getNewOps()); return DiagnosedSilenceableFailure::success(); } @@ -195,6 +232,7 @@ SmallVectorImpl &effects) { consumesHandle(getTarget(), effects); producesHandle(getAllocatedBuffer(), effects); + producesHandle(getNewOps(), effects); modifiesPayload(effects); } diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir --- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir +++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir @@ -54,7 +54,7 @@ padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 1] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - %buffer = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op + %buffer, %new_ops = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op %2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op } @@ -114,6 +114,6 @@ transform.structured.masked_vectorize %pad vector_sizes [10, 12] : !transform.any_op %vector_write = transform.structured.match ops{["vector.transfer_write"]} in %arg1 : (!transform.any_op) -> !transform.any_op %mask_op = transform.get_parent_op %vector_write {op_name = "vector.mask"} : (!transform.any_op) -> !transform.any_op - %buffer = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op + %buffer, %new_ops = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op %2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op } diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir --- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir @@ -32,7 +32,17 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op + + // Ensure that one linalg.fill was generated. + %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op + + // Ensure that one memref.tensor_store was generated. + %tensor_store = transform.select "memref.tensor_store" in %new : (!transform.any_op) -> !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %tensor_store : !transform.any_op } // ----- @@ -57,7 +67,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op // Make sure that One-Shot Bufferize can bufferize the rest. %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op } @@ -81,7 +91,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op // Make sure that One-Shot Bufferize can bufferize the rest. %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op } @@ -104,7 +114,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op // Make sure that One-Shot Bufferize can bufferize the rest. %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op } @@ -121,7 +131,7 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below{{failed to bufferize operation}} - %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op } // ----- @@ -142,5 +152,5 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op + %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op }