Review notes for the patch below. This text precedes the first diff header, so it is not part of the patch.

Purpose:
* The patch factors the duplicated in-place analysis logic out of both bufferizableInPlaceAnalysis functions into one new static helper:
  bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo)
* The two functions affected are:
  * the ExtractSliceOp overload;
  * the generic overload that obtains `result` via getInplaceableOpResult(operand).

Helper behavior:
* It asserts getAliasingOpOperand(result) == &operand.
* It calls aliasInfo.bufferizeOutOfPlace(result) if either of these is true:
  * aliasInfo.wouldCreateWriteToNonWritableBuffer(result);
  * aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo).
* Otherwise it calls aliasInfo.bufferizeInPlace(result, operand).
* It always returns success().

Changes visible in the diff:
* ExtractSliceOp path:
  * It now passes getOpOperand(0) and getOpResult(0) to the helper.
  * It is therefore subject to the new getAliasingOpOperand assert.
  * NOTE(review): this assumes getAliasingOpOperand maps an extract_slice result to its operand #0. Confirm.
* ExtractSliceOp path debug output:
  * It changes from "Inplace analysis for extract_slice: <printOperationInfo>" to the generic "<- #N -> #M in <printValueInfo>" form.
  * The closing message becomes "Done inplace analysis for result #N".
* Generic path:
  * The assert that the result is not defined by an ExtractSliceOp is removed.
  * Presumably this is intentional, now that both paths share the helper.
* The removed TODO ("Atm, all inplace bufferizations yield equivalent tensors. Support more cases on a per-need basis.") is not carried into the helper. Consider keeping it next to the bufferizeInPlace call.
* NOTE(review): the helper's doc comment says it sets InPlaceSpec::True/False on the result, but the body calls aliasInfo.bufferizeInPlace/bufferizeOutOfPlace. Confirm those calls are what record the InPlaceSpec, or reword the comment.

Capture artifact:
* In this copy, the patch is collapsed onto three physical lines.
* The removed assert reads `isa(op)`. Its template argument (presumably ExtractSliceOp, per the assert message) appears to have been stripped.
* As captured, the patch will not apply. Regenerate it from the original commit rather than hand-editing it.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2237,6 +2237,37 @@ // Bufferization analyses. //===----------------------------------------------------------------------===// +/// Determine if `operand` can be bufferized in-place with `result`. If so, set +/// InPlaceSpec::True on the result. Otherwise, set InPlaceSpec::False on the +/// result. +static LogicalResult +bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result, + BufferizationAliasInfo &aliasInfo, + const DominanceInfo &domInfo) { + assert(getAliasingOpOperand(result) == &operand && + "operand and result do not match"); + + int64_t resultNumber = result.getResultNumber(); + (void)resultNumber; + LDBG('\n'); + LDBG("Inplace analysis for <- #" << resultNumber << " -> #" + << operand.getOperandNumber() << " in " + << printValueInfo(result) << '\n'); + + bool foundInterference = + aliasInfo.wouldCreateWriteToNonWritableBuffer(result) || + aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); + + if (foundInterference) + aliasInfo.bufferizeOutOfPlace(result); + else + aliasInfo.bufferizeInPlace(result, operand); + + LDBG("Done inplace analysis for result #" << resultNumber << '\n'); + + return success(); +} + /// /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. 
/// =========================================================== @@ -2255,27 +2286,9 @@ bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - LDBG('\n'); - LDBG("Inplace analysis for extract_slice: " - << printOperationInfo(extractSliceOp) << '\n'); - - OpResult r = extractSliceOp->getResult(0); - OpOperand &s = extractSliceOp->getOpOperand(0); - bool foundInterference = - /* If `extractSliceOp` were to be bufferized inplace, it cannot end up - aliasing a write into a non-writable buffer.*/ - aliasInfo.wouldCreateWriteToNonWritableBuffer(r) || - /* In any of extractSliceOp.result's aliases, can we find 2 such that we - hit an interfering write? */ - aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(r); - else - aliasInfo.bufferizeInPlace(r, s); - - LDBG("Done inplace analysis for extract_slice\n"); - - return success(); + return bufferizableInPlaceAnalysisImpl(extractSliceOp->getOpOperand(0), + extractSliceOp->getOpResult(0), + aliasInfo, domInfo); } /// Determine if `operand` can be bufferized in-place with one of the op's @@ -2288,33 +2301,7 @@ OpResult result = getInplaceableOpResult(operand); if (!result) return success(); - - Operation *op = result.getDefiningOp(); - assert(result && !isa(op) && - "expected OpResult not coming from a ExtractSliceOp"); - (void)op; - - int64_t resultNumber = result.getResultNumber(); - (void)resultNumber; - LDBG('\n'); - LDBG("Inplace analysis for <- #" << resultNumber << " -> #" - << operand.getOperandNumber() << " in " - << printValueInfo(result) << '\n'); - - bool foundInterference = - aliasInfo.wouldCreateWriteToNonWritableBuffer(result) || - aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); - - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(result); - else - // TODO: Atm, all inplace bufferizations yield equivalent tensors. 
Support - // more cases on a per-need basis. - aliasInfo.bufferizeInPlace(result, operand); - - LDBG("Done inplace analysis for result #" << resultNumber << '\n'); - - return success(); + return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); } /// Analyze the `ops` to determine which OpResults are inplaceable: