diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -85,8 +85,9 @@
   let description = [{
     This transform materializes an allocation for the targeted tensor value. It
     replaces all original uses of the target with the newly allocated buffer,
-    wrapped in a `bufferization.to_tensor` op. It returns a handle to the result
-    of the `to_tensor` op.
+    wrapped in a `bufferization.to_tensor` op. It returns a handle to the newly
+    allocated buffer. Furthermore, it returns a handle to the result of the
+    `to_tensor` op.
 
     Example:
     ```
@@ -116,13 +117,14 @@
 
     #### Return modes
 
-    This operation consumes the `target` handle and produces the `transformed`
-    handle. It always succeeds.
+    This operation consumes the `target` handle and produces the `replacement`
+    and `allocated_buffer` handles. It always succeeds.
   }];
 
   let arguments = (ins Transform_AnyValue:$target,
                        OptionalAttr<AnyAttr>:$memory_space);
-  let results = (outs Transform_AnyValue:$transformed);
+  let results = (outs Transform_AnyValue:$allocated_buffer,
+                      Transform_AnyValue:$replacement);
 
   let assemblyFormat = "$target attr-dict";
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -321,10 +321,12 @@
 ///   memref.tensor_store %t, %subview
 ///   %0 = bufferization.to_tensor %alloc restrict writable
 ///
-/// In addition to rewriting the IR as shown above, the result of the
-/// bufferization.to_tensor op is returned.
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
 Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Value *replacement = nullptr);
 
 /// Materialize a buffer allocation for the given tensor value. E.g.:
 ///
@@ -334,8 +336,13 @@
 ///
 /// In case `value` is a tensor.pad result, the corresponding overload is used
 /// internally to produce a better bufferization.
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. Furthermore, the result of the
+/// bufferization.to_tensor op is optionally returned via `replacement`.
 Value bufferizeToAllocation(RewriterBase &rewriter, Value value,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Value *replacement = nullptr);
 
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -174,18 +174,25 @@
     transform::TransformResults &results, transform::TransformState &state) {
   Attribute memorySpace =
       getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
-  auto transformed = llvm::to_vector(
-      llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
-        return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
-      }));
-  results.setValues(cast<OpResult>(getTransformed()), transformed);
+  SmallVector<Value> replacements;
+  SmallVector<Value> allocatedBuffers;
+  for (Value value : state.getPayloadValues(getTarget())) {
+    Value replacement;
+    Value buffer = linalg::bufferizeToAllocation(rewriter, value, memorySpace,
+                                                 &replacement);
+    replacements.push_back(replacement);
+    allocatedBuffers.push_back(buffer);
+  }
+  results.setValues(cast<OpResult>(getReplacement()), replacements);
+  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
   return DiagnosedSilenceableFailure::success();
 }
 
 void transform::BufferizeToAllocationOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
-  producesHandle(getTransformed(), effects);
+  producesHandle(getReplacement(), effects);
+  producesHandle(getAllocatedBuffer(), effects);
   modifiesPayload(effects);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -170,7 +170,7 @@
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace, Value *replacement) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
   Location loc = padOp.getLoc();
@@ -198,7 +198,10 @@
   Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
       loc, alloc, /*restrict=*/true, /*writable=*/true);
   rewriter.replaceOp(padOp, toTensorOp);
-  return toTensorOp;
+
+  if (replacement)
+    *replacement = toTensorOp;
+  return alloc;
 }
 
 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
@@ -329,10 +332,10 @@
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace, Value *replacement) {
   // Call specialized overload for certain ops.
   if (auto padOp = value.getDefiningOp<PadOp>())
-    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+    return bufferizeToAllocation(rewriter, padOp, memorySpace, replacement);
 
   // Collect all uses.
   SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -362,7 +365,9 @@
                                [&]() { use->set(toTensorOp); });
   }
 
-  return toTensorOp;
+  if (replacement)
+    *replacement = toTensorOp;
+  return alloc;
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -33,7 +33,7 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1
+  %2, %3 = transform.structured.bufferize_to_allocation %1
 }
 
 // -----
@@ -59,9 +59,9 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1
+  %2, %3 = transform.structured.bufferize_to_allocation %1
   // Make sure that One-Shot Bufferize can bufferize the rest.
-  %3 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -85,7 +85,7 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }
 
 // -----
@@ -106,9 +106,9 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
   // Make sure that One-Shot Bufferize can bufferize the rest.
-  %3 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
 
 // -----
@@ -128,7 +128,7 @@
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["dummy.some_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1 = transform.get_result %0[0] : (!transform.any_op) -> !transform.any_value
-  %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
+  %2, %3 = transform.structured.bufferize_to_allocation %1 {memory_space = 4}
 }
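
For context, a minimal caller-side sketch of the updated C++ entry point; the wrapper name `allocateAndReplace` is hypothetical and simply mirrors the loop in `BufferizeToAllocationOp::apply` above:

```c++
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include <utility>

using namespace mlir;

/// Hypothetical helper (not part of this patch): bufferize `value` into a new
/// allocation and return both results exposed by the updated API.
static std::pair<Value, Value> allocateAndReplace(RewriterBase &rewriter,
                                                  Value value,
                                                  Attribute memorySpace = {}) {
  Value replacement;
  // The return value is now the newly allocated buffer; the result of the
  // bufferization.to_tensor op is reported via the optional out-parameter.
  Value buffer =
      linalg::bufferizeToAllocation(rewriter, value, memorySpace, &replacement);
  return {buffer, replacement};
}
```

Transform-script users see the same change as the two-result form `%2, %3 = transform.structured.bufferize_to_allocation %1` exercised in the updated tests.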