diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -337,6 +337,9 @@
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
 
+  /// Return `true` if the given tensor has undefined contents.
+  virtual bool hasUndefinedContents(OpOperand *opOperand) const = 0;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   ///
@@ -410,6 +413,9 @@
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
 
+  /// Return `true` if the given tensor has undefined contents.
+  bool hasUndefinedContents(OpOperand *opOperand) const override;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   bool isTensorYielded(Value tensor) const override;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -166,10 +166,17 @@
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
 
+  /// Return `true` if the given tensor has undefined contents.
+  bool hasUndefinedContents(OpOperand *opOperand) const override;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   bool isTensorYielded(Value tensor) const override;
 
+  /// Find all tensor values in the given operation that have undefined contents
+  /// and store them in `undefinedTensorUses`.
+  void gatherUndefinedTensorUses(Operation *op);
+
   /// Find all tensors that are yielded/returned from a block and store them in
   /// `yieldedTensors`. Also include all aliasing tensors in the same block.
   void gatherYieldedTensors(Operation *op);
@@ -182,6 +189,9 @@
   /// A set of all tensors (and maybe aliasing tensors) that yielded from a
   /// block.
   DenseSet<Value> yieldedTensors;
+
+  /// A set of uses of tensors that have undefined contents.
+  DenseSet<OpOperand *> undefinedTensorUses;
 };
 
 /// Analyze `op` and its nested ops. Bufferization decisions are stored in
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -303,18 +303,8 @@
       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
   if (failed(resultBuffer))
     return failure();
-  // Do not copy if the last preceding writes of `operand` are ops that do
-  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
-  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-  // use-def chain, it returns that value, regardless of whether it is a
-  // memory write or not.
-  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
-  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
-        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
-                                              analysisState);
-        return true;
-      }))
+  // Do not copy the buffer if its contents are undefined.
+  if (analysisState.hasUndefinedContents(&opOperand))
     return resultBuffer;
   // Do not copy if the copied data is never read.
   if (!aliasingOpResults.empty() &&
@@ -407,6 +397,12 @@
   return false;
 }
 
+/// Return `true` if the given tensor has undefined contents.
+bool AlwaysCopyAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
+  // There is no analysis, so the conservative answer is "false".
+  return false;
+}
+
 /// Return true if the given tensor (or an aliasing tensor) is yielded from
 /// the containing block. Also include all aliasing tensors in the same block.
 bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -249,6 +249,43 @@
   });
 }
 
+void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
+  op->walk([&](Operation *op) {
+    // Skip unknown ops.
+    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
+    if (!bufferizableOp)
+      return WalkResult::skip();
+
+    // Check all tensor OpResults.
+    for (OpResult opResult : op->getOpResults()) {
+      if (!opResult.getType().isa<TensorType>())
+        continue;
+
+      // If there is no preceding memory write, the tensor contents are
+      // undefined.
+      // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+      // use-def chain, it returns that value, regardless of whether it is a
+      // memory write or not.
+      SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
+      bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
+        if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
+          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
+                                              *this);
+        return true;
+      });
+      if (isUndefined)
+        for (OpOperand &use : opResult.getUses())
+          undefinedTensorUses.insert(&use);
+    }
+
+    return WalkResult::advance();
+  });
+}
+
+bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
+  return undefinedTensorUses.contains(opOperand);
+}
+
 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
   return yieldedTensors.contains(tensor);
 }
@@ -915,8 +952,9 @@
         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
   }
 
-  // Gather all yielded tensors.
+  // Gather some extra analysis data.
   state.gatherYieldedTensors(op);
+  state.gatherUndefinedTensorUses(op);
 
   // Analysis verification: After setting up alias/equivalence sets, each op
   // can check for expected invariants/limitations and fail the analysis if
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -359,3 +359,19 @@
   return %3 : tensor
 }
 
+// -----
+
+// CHECK-LABEL: func @do_not_copy_init_tensors(
+func.func @do_not_copy_init_tensors(%f1: f32, %f2: f32, %idx: index)
+  -> (tensor<5xf32>, tensor<5xf32>)
+{
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK-NOT: copy
+  // CHECK: memref.store
+  // CHECK: memref.store
+  %0 = linalg.init_tensor [5] : tensor<5xf32>
+  %1 = tensor.insert %f1 into %0[%idx] : tensor<5xf32>
+  %2 = tensor.insert %f2 into %0[%idx] : tensor<5xf32>
+  return %1, %2 : tensor<5xf32>, tensor<5xf32>
+}