diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -544,6 +544,10 @@ equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it)); equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it)); } + + // scf::IfOp always bufferizes in-place. + for (OpResult opResult : ifOp->getResults()) + setInPlaceOpResult(opResult, InPlaceSpec::True); } }); } @@ -716,20 +720,19 @@ return true; auto bufferizableOp = dyn_cast(op); if (!bufferizableOp) + // Unknown op: Assume that it is writing. return true; - if (isa(op)) - return true; - - SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(value.cast()); - assert(opOperands.size() <= 1 && - "op with multiple aliasing OpOperands not expected"); - - if (opOperands.empty()) + // Check if any of the aliasing OpOperands is writing. + if (llvm::any_of( + bufferizableOp.getAliasingOpOperand(value.cast()), + [](OpOperand *operand) { + return bufferizesToMemoryWrite(*operand); + })) return true; - - return bufferizesToMemoryWrite(*opOperands.front()); + // Not a write, continue iterating... + return false; }); + assert(result.size() == 1 && "expected exactly one result"); return result.front(); } @@ -1259,10 +1262,10 @@ } // If bufferizing out-of-place, allocate a new buffer. - bool needCopy = - getInPlace(result) != InPlaceSpec::True && !isa(op); + bool needCopy = getInPlace(result) != InPlaceSpec::True; if (needCopy) { - // Ops such as scf::IfOp can currently not bufferize out-of-place. + // Ops with multiple aliasing operands can currently not bufferize + // out-of-place. 
assert( aliasingOperands.size() == 1 && "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); @@ -2769,9 +2772,11 @@ : public BufferizableOpInterface::ExternalModel { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { + // IfOps do not have tensor OpOperands. The yielded value can be any SSA + // value that is in scope. To allow for use-def chain traversal through + // IfOps in the analysis, both corresponding yield values from the then/else + // branches are considered to be aliasing with the result. auto ifOp = cast(op); - // Either one of the corresponding yield values from the then/else branches - // may alias with the result. size_t resultNum = std::distance(op->getOpResults().begin(), llvm::find(op->getOpResults(), opResult)); return {&ifOp.thenYield()->getOpOperand(resultNum), @@ -2885,13 +2890,21 @@ } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; + // scf::YieldOp usually does not bufferize to a memory write. However, there + // is a special rule for IfOps. The aliasing OpOperand of an IfOp is the + // YieldOp's OpOperand. For analysis purposes, IfOps are considered to be + // writing ops, so this method should return `true`. + return isa(op->getParentOp()); } OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + LogicalResult bufferize(Operation *op, OpBuilder &b, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,