diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -84,9 +84,10 @@ Operation *opToBufferize, DenseSet &usesRead, DenseSet &usesWrite, const DominanceInfo &domInfo) const; - /// Return true if bufferizing `opResult` inplace would create a write to a - /// non-writable buffer. - bool wouldCreateWriteToNonWritableBuffer(OpResult opResult) const; + /// Return true if bufferizing `opOperand` inplace with `opResult` would + /// create a write to a non-writable buffer. + bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, + OpResult opResult) const; /// Assume that result bufferizes in-place with one of the operation's /// operands. Return true if it is possible to find an inplace write W (resp. @@ -109,7 +110,7 @@ /// read(%0) /// ``` bool - wouldCreateReadAfterWriteInterference(OpResult result, + wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const; /// Return true if `v1` and `v2` bufferize to equivalent buffers. @@ -230,7 +231,7 @@ llvm::EquivalenceClasses equivalentInfo; }; -/// Analyze the `ops` to determine which OpResults are inplaceable: +/// Analyze the `ops` to determine which OpResults are inplaceable. LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -878,14 +878,10 @@ /// C interleaved between W and R (i.e. W -> C -> R where -> denotes /// dominance). bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( - OpResult result, const DominanceInfo &domInfo) const { - Optional maybeAliasingOperand = getAliasingOpOperand(result); - if (!maybeAliasingOperand) - return false; - + OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const { Operation *opToBufferize = result.getDefiningOp(); Value opResult = result; - Value opOperand = (*maybeAliasingOperand)->get(); + Value opOperand = operand.get(); LDBG("----Start wouldCreateReadAfterWriteInterference\n"); LDBG("--------consider all aliases to root read: " @@ -937,16 +933,16 @@ getAliasingReads(usesRead, opOperand); getAliasingInplaceWrites(usesWrite, opResult); // Additionally, `result` is not yet bufferized and we need to check for - // interferences as if it were bufferized inplace: add `maybeAliasingOperand` - // if it is a write. This handles the case: + // interferences as if it were bufferized inplace: add `operand` if it is a + // write. This handles the case: // // ``` // %0 = op_to_bufferize_maybe_inplace(%1) // %2 = some_alias(%1) // read(%2) // ``` - if (bufferizesToMemoryWrite(**maybeAliasingOperand)) - usesWrite.insert(*maybeAliasingOperand); + if (bufferizesToMemoryWrite(operand)) + usesWrite.insert(&operand); if (wouldCreateReadAfterWriteInterference(opToBufferize, usesRead, usesWrite, domInfo)) return true; @@ -972,26 +968,22 @@ usesWrite, domInfo); } -/// Return true if bufferizing `opResult` inplace would create a write to a -/// non-writable buffer. +/// Return true if bufferizing `opOperand` inplace with `opResult` would create +/// a write to a non-writable buffer. bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer( - OpResult opResult) const { - Optional maybeAliasingOperand = getAliasingOpOperand(opResult); - if (!maybeAliasingOperand || !*maybeAliasingOperand) - return false; - + OpOperand &opOperand, OpResult opResult) const { // Certain buffers are not writeable: // 1. A function bbArg that is not inplaceable or // 2. A constant op. bool nonWriteable = aliasesNonWritableBuffer(opResult) || - aliasesNonWritableBuffer((*maybeAliasingOperand)->get()); + aliasesNonWritableBuffer(opOperand.get()); if (!nonWriteable) return false; // This is a problem only if the buffer is written to via some alias. bool hasWrite = aliasesInPlaceWrite(opResult) || - aliasesInPlaceWrite((*maybeAliasingOperand)->get()) || - bufferizesToMemoryWrite(**maybeAliasingOperand); + aliasesInPlaceWrite(opOperand.get()) || + bufferizesToMemoryWrite(opOperand); if (!hasWrite) return false; @@ -2254,8 +2246,8 @@ << printValueInfo(result) << '\n'); bool foundInterference = - aliasInfo.wouldCreateWriteToNonWritableBuffer(result) || - aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); + aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) || + aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(result); @@ -2303,12 +2295,9 @@ return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); } -/// Analyze the `ops` to determine which OpResults are inplaceable: -/// 1. First, analyze InsertSliceOp greedily: we almost never want to -/// bufferize the tensor "inserted into" to become out-of-place. -/// 2. Walk the other ops in reverse. This is a good starter heuristic. -/// ExtractSliceOps are interleaved with other ops in traversal order. -/// +/// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in +/// reverse and bufferize ops greedily. This is a good starter heuristic. +/// ExtractSliceOps are interleaved with other ops in traversal order. LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { @@ -2323,13 +2312,12 @@ // to properly capture aliases. // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: // it is easier to detect clobbers of smaller slices before larger ones. - if (auto extractSliceOp = dyn_cast(op)) { + if (auto extractSliceOp = dyn_cast(op)) if (failed( bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) return failure(); - continue; - } } + return success(); }