diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -516,14 +516,6 @@ return op->getResult(0); } -/// Return the OpResult that may bufferize into the same buffer as `opOperand` -/// when the op is bufferized inplace. -/// Return null if no such result exists. -static OpResult getInplaceableOpResult(tensor::CastOp op, - OpOperand &opOperand) { - return op->getResult(0); -} - /// Return the OpResult that may bufferize into the same buffer as `opOperand` /// when the op is bufferized inplace. /// The inplace analysis uses this information along with interfering read @@ -534,16 +526,16 @@ // clang-format off // Ops that perform destructive updates on operand(s) to produce // result(s). - .Case( [&](auto op) { return getInplaceableOpResult(op, opOperand); }) - // ExtractSliceOp is special, when bufferized inplace it just returns an - // alias to its operand. Its result is never inplaceable on its operand. - .Case([&](ExtractSliceOp op) { return OpResult(); }) + // Some ops just return an alias to an operand when bufferized inplace. + // Such OpResults are never inplaceable on an OpOperand. + .Case( + [] (auto op) { return OpResult(); }) // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. 
@@ -572,9 +564,9 @@ if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp())) return SmallVector(); TypeSwitch(result.getDefiningOp()) - .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); }) - .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); }) .Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); }) + .Case( + [&](auto op) { r.push_back(&op->getOpOperand(0)); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. // TODO: update this. @@ -606,7 +598,15 @@ /// If the an ExtractSliceOp is bufferized in-place, the source operand will /// alias with the result. static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) { - if (op.source() == opOperand.get()) + if (&op->getOpOperand(0) == &opOperand) + return op->getResult(0); + return OpResult(); +} + +/// If a tensor::CastOp is bufferized in-place, the source operand will +/// alias with the result. +static OpResult getAliasingOpResult(tensor::CastOp op, OpOperand &opOperand) { + if (&op->getOpOperand(0) == &opOperand) return op->getResult(0); return OpResult(); } @@ -616,11 +616,11 @@ /// TODO: in the future this may need to evolve towards a list of OpResult. static OpResult getAliasingOpResult(OpOperand &opOperand) { return TypeSwitch(opOperand.getOwner()) - // ExtractSliceOp is different: its result is not inplaceable on op.source - // but when bufferized inplace, the result is an aliasing subregion of - // op.source. - .Case( - [&](ExtractSliceOp op) { return getAliasingOpResult(op, opOperand); }) + // Some ops are different: their result is not inplaceable on an OpOperand + // but when bufferized inplace, their result is aliasing (a subregion of) + // an OpOperand. + .Case( + [&](auto op) { return getAliasingOpResult(op, opOperand); }) // All other ops, return the result of `getInplaceableOpResult`.
.Default( [&](Operation *op) { return getInplaceableOpResult(opOperand); }); @@ -639,11 +639,9 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); - // Skip over all ExtractSliceOps. These do not read by themselves but just - // add a new alias. - if (auto extractSliceOp = - dyn_cast(uMaybeReading->getOwner())) - for (OpOperand &use : extractSliceOp.result().getUses()) + // Skip over all ops that create an alias but do not read. + if (isa(uMaybeReading->getOwner())) + for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; @@ -658,9 +656,9 @@ // it. Conservatively return true. if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner())) return true; - // ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses + // Some ops alone do not bufferize to a memory read, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (isa(opOperand.getOwner())) return false; // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its // matching bbArg may. @@ -690,9 +688,9 @@ // These terminators are not writes. if (isa(opOperand.getOwner())) return false; - // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses + // Some ops alone do not bufferize to a memory write, but one of their uses // may. - if (isa(opOperand.getOwner())) + if (isa(opOperand.getOwner())) return false; // CallOpInterface alone doesn't bufferize to a memory write, one of the uses // of the matching bbArg may. It is the responsibility of the caller to @@ -2320,27 +2318,28 @@ return success(); } +/// This analysis function is used for ops where the first OpOperand aliases +/// with the first OpResult, without creating a read or write. There are a few +/// ops besides ExtractSliceOp that have such semantics. /// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. 
-/// =========================================================== +/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: /// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This +/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This /// cannot change the flow of information for either the source or the /// result buffers. /// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. +/// When bufferized inplace, an ExtractSliceOp does not by itself create any +/// read or write from memory. Instead, it has the effect of merging the alias +/// sets of the source and the result buffers. /// /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. static LogicalResult -bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - return bufferizableInPlaceAnalysisImpl(extractSliceOp->getOpOperand(0), - extractSliceOp->getOpResult(0), - aliasInfo, domInfo); +bufferizableInPlaceAnalysisAliasOnlyOp(Operation *op, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { + return bufferizableInPlaceAnalysisImpl( + op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo); } /// Determine if `operand` can be bufferized in-place with one of the op's @@ -2377,14 +2376,11 @@ if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze ExtractSliceOp. - // Note that ExtractSliceOp analysis needs to be interleaved with other ops - // to properly capture aliases. - // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: - // it is easier to detect clobbers of smaller slices before larger ones. 
- if (auto extractSliceOp = dyn_cast(op)) + // Special logic to analyze ops whose OpResults are not inplaceable on an + // OpOperand but may create an alias. + if (isa(op)) if (failed( - bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) + bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo))) return failure(); } @@ -3053,7 +3049,8 @@ aliasInfo.createAliasInfoEntry(extractOp.result()); // Run analysis on the ExtractSliceOp. - if (failed(bufferizableInPlaceAnalysis(extractOp, aliasInfo, domInfo))) + if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(extractOp, aliasInfo, + domInfo))) return WalkResult::interrupt(); // Advance to the next operation.