diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -494,7 +494,8 @@ static OpResult getInplaceableOpResult(VectorTransferOpInterface op, OpOperand &opOperand) { if (opOperand.get() != op.source() || - !op.source().getType().isa<TensorType>()) + !op.source().getType().isa<TensorType>() || + isa<vector::TransferReadOp>(op)) return OpResult(); return op->getResult(0); } @@ -2423,13 +2424,17 @@ return success(); } -/// Analyze the (opOperand, result) pair to determine whether the result can -/// be bufferized inPlace. If successful, InPlaceSpec::True is set for -/// `result`. Otherwise, InPlaceSpec::False is set for `result`. +/// Determine if `operand` can be bufferized in-place with one of the op's +/// results. If so, set InPlaceSpec::True on the result. Otherwise, set +/// InPlaceSpec::False on the result. static LogicalResult -bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result, +bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { + OpResult result = getInplaceableOpResult(operand); + if (!result) + return success(); + Operation *op = result.getDefiningOp(); assert(result && !isa<ExtractSliceOp>(op) && "expected OpResult not coming from a ExtractSliceOp"); @@ -2454,7 +2459,6 @@ else LDBG("->bufferizes to writeable inplace buffer\n"); - assert(result == getInplaceableOpResult(operand)); bool foundInterference = wouldCreateAliasingWriteToNonWriteableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); @@ -2498,13 +2502,10 @@ // Walk ops in reverse for better interference analysis. 
for (Operation *op : reverse(ops)) { - for (OpOperand &opOperand : op->getOpOperands()) { - if (OpResult result = getInplaceableOpResult(opOperand)) - if (result.getType().isa<TensorType>() && - failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, - domInfo))) - return failure(); - } + for (OpOperand &opOperand : op->getOpOperands()) + if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) + return failure(); + // Special logic to analyze ExtractSliceOp. // Note that ExtractSliceOp analysis needs to be interleaved with other ops // to properly capture aliases.