diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -548,6 +548,10 @@ equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it)); equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it)); } + + // scf::IfOp always bufferizes in-place. + for (OpResult opResult : ifOp->getResults()) + setInPlaceOpResult(opResult, InPlaceSpec::True); } }); } @@ -734,10 +738,9 @@ auto bufferizableOp = dyn_cast(op); if (!bufferizableOp) return true; - if (isa(op)) - return true; return bufferizableOp.isMemoryWrite(value.cast()); }); + assert(result.size() == 1 && "expected exactly one result"); return result.front(); } @@ -1318,10 +1321,10 @@ } // If bufferizing out-of-place, allocate a new buffer. - bool needCopy = - getInPlace(result) != InPlaceSpec::True && !isa(op); + bool needCopy = getInPlace(result) != InPlaceSpec::True; if (needCopy) { - // Ops such as scf::IfOp can currently not bufferize out-of-place. + // Ops with multiple aliasing OpOperands cannot currently bufferize + // out-of-place. assert( aliasingOperands.size() == 1 && "ops with multiple aliasing OpOperands cannot bufferize out-of-place"); @@ -2724,15 +2727,47 @@ : public BufferizableOpInterface::ExternalModel { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { + // IfOps do not have tensor OpOperands. The yielded value can be any SSA + // value that is in scope. To allow for use-def chain traversal through + // IfOps in the analysis, both corresponding yield values from the then/else + // branches are considered to be aliasing with the result. auto ifOp = cast(op); - // Either one of the corresponding yield values from the then/else branches - // may alias with the result. 
size_t resultNum = std::distance(op->getOpResults().begin(), llvm::find(op->getOpResults(), opResult)); return {&ifOp.thenYield()->getOpOperand(resultNum), &ifOp.elseYield()->getOpOperand(resultNum)}; } + // TODO: For better bufferization results, this could return `true` only if + // there is a memory write in one (or both) of the branches. Since this is not + // allowed at the moment, we should never encounter scf.ifs that yield + // unmodified tensors. Such scf.yield ops could just fold away. + bool isMemoryWrite(Operation *op, OpResult opResult) const { + // IfOp results are always considered memory writes in the analysis. This + // design decision simplifies the analysis considerably. E.g., consider the + // following test case: + // + // %0 = "some_writing_op" : tensor + // %r = scf.if %c -> (tensor) { + // scf.yield %0 + // } else { + // %1 = "another_writing_op"(%0) : tensor + // scf.yield %1 + // } + // "some_reading_op"(%r) + // + // "another_writing_op" above should be able to bufferize in-place in the + // absence of another read of %0. However, if the scf.if op were not + // considered a "write", the analysis would detect the following conflict: + // + // * read = some_reading_op + // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) + // * conflictingWrite = %1 + // + // For more details, check the "scf.IfOp" section of the design document. + return true; + } + LogicalResult bufferize(Operation *op, OpBuilder &b, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo, @@ -2847,6 +2882,10 @@ return OpResult(); } + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + LogicalResult bufferize(Operation *op, OpBuilder &b, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,