diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -356,6 +356,10 @@
   /// an alias. Return false if the op is not bufferizable.
   bool bufferizesToAliasOnly(OpOperand &opOperand) const;
 
+  /// Return true if a copy can always be avoided when allocating a new tensor
+  /// for the given OpOperand.
+  bool canOmitTensorCopy(OpOperand &opOperand) const;
+
   /// Return true if the given value is read by an op that bufferizes to a
   /// memory read. Also takes into account ops that create an alias but do not
   /// read by themselves (e.g., ExtractSliceOp).
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -53,13 +54,30 @@
       continue;
     if (state.isInPlace(opOperand))
      continue;
+
+    // Is the result yielded from a block?
     SmallVector<OpResult> aliasingOpResults =
         state.getAliasingOpResult(opOperand);
     bool escape = llvm::any_of(
         aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
-    Value copy = rewriter.create<AllocTensorOp>(
-        op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
-    rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
+    // Create alloc_tensor op.
+    Value alloc;
+    if (state.canOmitTensorCopy(opOperand)) {
+      // No copy needed: Just allocate.
+      SmallVector<Value> dynamicSizes;
+      for (int64_t i = 0; i < tensorType.getRank(); ++i)
+        if (tensorType.isDynamicDim(i))
+          dynamicSizes.push_back(rewriter.create<tensor::DimOp>(
+              op->getLoc(), opOperand.get(), i));
+      alloc = rewriter.create<AllocTensorOp>(
+          op->getLoc(), tensorType, dynamicSizes, /*copy=*/Value(), escape);
+    } else {
+      // Allocate and copy.
+      alloc = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
+                                             /*dynamicSizes=*/ValueRange(),
+                                             opOperand.get(), escape);
+    }
+    rewriter.updateRootInPlace(op, [&]() { opOperand.set(alloc); });
   }
   return success();
 }
@@ -270,6 +288,26 @@
   fn(*this);
 }
 
+bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
+  // Do not copy if the tensor has undefined contents.
+  if (hasUndefinedContents(&opOperand))
+    return true;
+
+  // Do not copy if the buffer of the tensor is entirely overwritten.
+  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
+    return true;
+
+  // Do not copy if the tensor is never read.
+  SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
+  if (!bufferizesToMemoryRead(opOperand) &&
+      llvm::none_of(aliasingOpResults,
+                    [&](OpResult opResult) { return isValueRead(opResult); }))
+    return true;
+
+  // Default: Cannot omit the copy.
+  return false;
+}
+
 // bufferization.to_memref is not allowed to change the rank.
 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 #ifndef NDEBUG
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -312,6 +312,12 @@
         loc, resultType, dynamicDims, /*copy=*/Value(),
         /*escape=*/isYielded);
 
+    // Shortcut: Step 3 can be omitted when the tensor contents are not read.
+    if (state.canOmitTensorCopy(extractSliceOp->getOpOperand(0))) {
+      rewriter.replaceOp(op, alloc);
+      return success();
+    }
+
     // Step 3: Create InsertSliceOp.
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> sizes;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/tensor-copy-insertion.mlir b/mlir/test/Dialect/Bufferization/Transforms/tensor-copy-insertion.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/tensor-copy-insertion.mlir
@@ -25,3 +25,54 @@
   %0 = bufferization.alloc_tensor() : tensor<5xf32>
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @do_not_copy_undefined_tensor
+func.func @do_not_copy_undefined_tensor(%f: f32, %idx: index)
+  -> (tensor<5xf32>, tensor<5xf32>)
+{
+  // CHECK: bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
+  // The second alloc_tensor should not have a copy operand.
+  // CHECK: bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
+  %0 = bufferization.alloc_tensor() : tensor<5xf32>
+  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+  return %0, %1 : tensor<5xf32>, tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_copy_when_overwritten
+func.func @do_not_copy_when_overwritten(%t: tensor<5xf32>, %f: f32)
+  -> (tensor<5xf32>, tensor<5xf32>)
+{
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc]] : tensor<5xf32>)
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    outs(%t : tensor<5xf32>) {
+      ^bb0(%arg0 : f32) :
+        linalg.yield %f : f32
+    } -> tensor<5xf32>
+  return %t, %r : tensor<5xf32>, tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_copy_when_result_not_read
+func.func @do_not_copy_when_result_not_read(%t: tensor<5xf32>, %f: f32)
+  -> (tensor<3xf32>)
+{
+  %0 = tensor.extract_slice %t[0][3][1] : tensor<5xf32> to tensor<3xf32>
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() {escape = false} : tensor<3xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc]] : tensor<3xf32>)
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    outs(%0 : tensor<3xf32>) {
+      ^bb0(%arg0 : f32) :
+        linalg.yield %f : f32
+    } -> tensor<3xf32>
+  return %r : tensor<3xf32>
+}
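
Reviewer note (illustration only, not part of the patch): the observable effect
on the @do_not_copy_when_overwritten test above is that tensor copy insertion
no longer materializes a copy for the linalg.generic out operand, because the
operand is written in full and never read. A rough before/after sketch of the
IR, assuming alloc_tensor's copy operand prints with the "copy(...)" assembly
syntax:

  // Before this change: allocate and copy %t into the new buffer.
  %alloc = bufferization.alloc_tensor() copy(%t) {escape = false} : tensor<5xf32>

  // After this change: allocate only; canOmitTensorCopy detects that the
  // linalg.generic overwrites the entire buffer, so the copy is dropped.
  %alloc = bufferization.alloc_tensor() {escape = false} : tensor<5xf32>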