diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2369,6 +2369,9 @@
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
+/// Determine if `operand` can be bufferized in-place with one of the op's
+/// results. If so, set InPlaceSpec::True on the result. Otherwise, set
+/// InPlaceSpec::False on the result.
 ///
 /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace.
 /// ===========================================================
@@ -2384,57 +2387,24 @@
 /// An analysis is required to ensure inplace bufferization would not result in
 /// RaW dependence violations.
 static LogicalResult
-bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
-                            BufferizationAliasInfo &aliasInfo,
-                            const DominanceInfo &domInfo) {
-  LDBG('\n');
-  LDBG("Inplace analysis for extract_slice: "
-       << printOperationInfo(extractSliceOp) << '\n');
-
-  // If `extractSliceOp` were to be bufferized inplace, it cannot end up
-  // aliasing a write into a non-writeable buffer.
-  bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
-      aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
-
-  if (wouldCreateAliasingWriteToNonWriteableBuffer)
-    LDBG("->the corresponding buffer is not writeable\n");
-  else
-    LDBG("->bufferizes to writeable inplace buffer\n");
-
-  // In any of extractSliceOp.result's aliases, can we find 2 such that we hit
-  // an interfering write?
-  OpResult r = extractSliceOp->getResult(0);
-  OpOperand &s = extractSliceOp->getOpOperand(0);
-  bool foundInterference =
-      wouldCreateAliasingWriteToNonWriteableBuffer ||
-      aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo);
-  if (foundInterference)
-    aliasInfo.bufferizeOutOfPlace(r);
-  else
-    aliasInfo.bufferizeInPlace(r, s);
-
-  LDBG("Done inplace analysis for extract_slice\n");
-
-  return success();
-}
-
-/// Determine if `operand` can be bufferized in-place with one of the op's
-/// results. If so, set InPlaceSpec::True on the result. Otherwise, set
-/// InPlaceSpec::False on the result.
-static LogicalResult
 bufferizableInPlaceAnalysis(OpOperand &operand,
                             BufferizationAliasInfo &aliasInfo,
                             const DominanceInfo &domInfo) {
-  OpResult result = getInplaceableOpResult(operand);
+  // ExtractSliceOp is special-cased: its source operand (#0) is analyzed
+  // against result #0 instead of going through getInplaceableOpResult.
+  Operation *op = operand.getOwner();
+  bool isExtractSliceOp = false;
+  OpResult result;
+  if (isa<ExtractSliceOp>(op) && (operand.getOperandNumber() == 0)) {
+    result = op->getOpResult(0);
+    isExtractSliceOp = true;
+  } else {
+    result = getInplaceableOpResult(operand);
+  }
+
   if (!result)
     return success();
 
-  Operation *op = result.getDefiningOp();
-  assert(result && !isa<ExtractSliceOp>(op) &&
-         "expected OpResult not coming from a ExtractSliceOp");
-  (void)op;
-
   int64_t resultNumber = result.getResultNumber();
   (void)resultNumber;
   LDBG('\n');
@@ -2448,7 +2418,11 @@
   // 2. a constant op.
   // to be considered for inplace bufferization
   bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesNonWriteableBuffer(operand);
+      aliasInfo.aliasesNonWriteableBuffer(operand) &&
+      // For an ExtractSliceOp, this only matters if the result also aliases
+      // an inplace write.
+      (!isExtractSliceOp || aliasInfo.aliasesInPlaceWrite(result));
+
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
     LDBG("->the corresponding buffer is not writeable\n");
   else
@@ -2461,10 +2435,12 @@
   if (foundInterference)
     aliasInfo.bufferizeOutOfPlace(result);
   else
-    // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support
-    // more cases on a per-need basis.
+    // An ExtractSliceOp result is recorded with BufferRelation::None (it is not
+    // equivalent to its source); all other inplace bufferizations are Equivalent.
     aliasInfo.bufferizeInPlace(
-        result, operand, BufferizationAliasInfo::BufferRelation::Equivalent);
+        result, operand,
+        isExtractSliceOp ? BufferizationAliasInfo::BufferRelation::None
+                         : BufferizationAliasInfo::BufferRelation::Equivalent);
 
   LDBG("Done inplace analysis for result #" << resultNumber << '\n');
 
@@ -2496,24 +2472,11 @@
   });
 
   // Walk ops in reverse for better interference analysis.
-  for (Operation *op : reverse(ops)) {
+  for (Operation *op : reverse(ops))
     for (OpOperand &opOperand : op->getOpOperands())
       if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo)))
         return failure();
 
-    // Special logic to analyze ExtractSliceOp.
-    // Note that ExtractSliceOp analysis needs to be interleaved with other ops
-    // to properly capture aliases.
-    // Walk ExtractSliceOps in reverse for better clobbering analysis behavior:
-    // it is easier to detect clobbers of smaller slices before larger ones.
-    if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op)) {
-      if (failed(
-              bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
-        return failure();
-      continue;
-    }
-  }
-
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
 
   return success();