diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -494,7 +494,8 @@ static OpResult getInplaceableOpResult(VectorTransferOpInterface op, OpOperand &opOperand) { if (opOperand.get() != op.source() || - !op.source().getType().isa()) + !op.source().getType().isa() || + isa(op)) return OpResult(); return op->getResult(0); } @@ -2253,13 +2254,17 @@ return success(); } -/// Analyze the (opOperand, result) pair to determine whether the result can -/// be bufferized inPlace. If successful, InPlaceSpec::True is set for -/// `result`. Otherwise, InPlaceSpec::False is set for `result`. +/// Determine if `operand` can be bufferized in-place with one of the op's +/// results. If so, set InPlaceSpec::True on the result. Otherwise, set +/// InPlaceSpec::False on the result. static LogicalResult -bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result, +bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { + OpResult result = getInplaceableOpResult(operand); + if (!result) + return success(); + Operation *op = result.getDefiningOp(); assert(result && !isa(op) && "expected OpResult not coming from a ExtractSliceOp"); @@ -2284,7 +2289,6 @@ else LDBG("->bufferizes to writable inplace buffer\n"); - assert(result == getInplaceableOpResult(operand)); bool foundInterference = wouldCreateAliasingWriteToNonWritableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); @@ -2313,22 +2317,21 @@ const DominanceInfo &domInfo) { // Walk ops in reverse for better interference analysis. for (Operation *op : reverse(ops)) { - for (OpOperand &opOperand : op->getOpOperands()) { - if (OpResult result = getInplaceableOpResult(opOperand)) - if (result.getType().isa() && - failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, - domInfo))) - op->emitWarning() << "Inplace analysis treated conservatively"; - } + for (OpOperand &opOperand : op->getOpOperands()) + if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) + return failure(); + // Special logic to analyze ExtractSliceOp. // Note that ExtractSliceOp analysis needs to be interleaved with other ops // to properly capture aliases. // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: // it is easier to detect clobbers of smaller slices before larger ones. - if (auto extractSliceOp = dyn_cast(op)) + if (auto extractSliceOp = dyn_cast(op)) { if (failed( bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) - op->emitWarning() << "Inplace analysis treated conservatively"; + return failure(); + continue; + } } return success(); }