diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -318,6 +318,11 @@
   /// an alias. Return false if the op is not bufferizable.
   bool bufferizesToAliasOnly(OpOperand &opOperand) const;
 
+  /// Compute all OpOperands that read `value` and add them to `result`.
+  /// Also takes into account ops that create an alias but do not read by
+  /// themselves (e.g., ExtractSliceOp).
+  void getValueReads(DenseSet<OpOperand *> &result, Value value) const;
+
   /// Return true if the given value is read by an op that bufferizes to a
   /// memory read. Also takes into account ops that create an alias but do not
   /// read by themselves (e.g., ExtractSliceOp).
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -251,11 +251,11 @@
   return false;
 }
 
-/// Return true if the given value is read by an op that bufferizes to a memory
-/// read. Also takes into account ops that create an alias but do not read by
+/// Compute all OpOperands that read `value` and add them to `result`.
+/// Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
-    Value value) const {
+void mlir::linalg::comprehensive_bufferize::BufferizationState::getValueReads(
+    DenseSet<OpOperand *> &result, Value value) const {
   assert(value.getType().isa<TensorType>() && "expected TensorType");
   SmallVector<OpOperand *> workingSet;
   for (OpOperand &use : value.getUses())
@@ -268,10 +268,19 @@
     for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
       workingSet.push_back(&use);
     if (bufferizesToMemoryRead(*uMaybeReading))
-      return true;
+      result.insert(uMaybeReading);
   }
+}
 
-  return false;
+/// Return true if the given value is read by an op that bufferizes to a memory
+/// read. Also takes into account ops that create an alias but do not read by
+/// themselves (e.g., ExtractSliceOp).
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
+    Value value) const {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
+  DenseSet<OpOperand *> readingOperands;
+  getValueReads(readingOperands, value);
+  return !readingOperands.empty();
 }
 
 // Starting from `value`, follow the use-def chain in reverse, always selecting
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -337,10 +337,22 @@
   // Helper function to iterate on aliases of `root` and capture the reads.
   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
     aliasInfo.applyOnAliases(root, [&](Value alias) {
-      for (auto &use : alias.getUses())
-        // Read to a value that aliases root.
-        if (state.bufferizesToMemoryRead(use))
+      for (auto &use : alias.getUses()) {
+        if (state.bufferizesToMemoryRead(use)) {
+          // Read to a value that aliases root.
           res.insert(&use);
+          continue;
+        }
+        if (state.bufferizesToMemoryWrite(use))
+          // Tensor not read but written. The buffer is completely overwritten.
+          continue;
+
+        OpResult opResult = state.getAliasingOpResult(use);
+        if (!opResult)
+          continue;
+        // Result may be read.
+        state.getValueReads(res, opResult);
+      }
     });
   };
 