diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -89,6 +89,25 @@
           });
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the given OpResult must bufferize in-place with its
+          corresponding aliasing OpOperand. Alias sets and in-place attributes
+          will be set up accordingly before making any other bufferization
+          decisions. This method will never be called on OpResults that do not
+          have a tensor type.
+
+          Note: This method must not return `true` if the given OpResult does
+          not have an aliasing OpOperand.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"mustBufferizeInPlace",
+        /*args=*/(ins "OpResult":$opResult),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the OpResult that aliases with a given OpOperand when
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -538,18 +538,21 @@
             createAliasInfoEntry(bbArg);
   });
 
-  // The return value of an scf::IfOp aliases with both yield values.
-  rootOp->walk([&](scf::IfOp ifOp) {
-    if (ifOp->getNumResults() > 0) {
-      for (auto it : llvm::zip(ifOp.thenYield().results(),
-                               ifOp.elseYield().results(), ifOp.results())) {
-        aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
-        aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
-      }
-
-      // scf::IfOp always bufferizes in-place.
-      for (OpResult opResult : ifOp->getResults())
-        setInPlaceOpResult(opResult, InPlaceSpec::True);
+  // Set up alias sets for OpResults that must bufferize in-place. This should
+  // be done before making any other bufferization decisions.
+  rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
+    for (OpResult opResult : bufferizableOp->getOpResults()) {
+      // mustBufferizeInPlace is only specified for tensor OpResults.
+      if (!opResult.getType().isa<TensorType>() ||
+          !bufferizableOp.mustBufferizeInPlace(opResult))
+        continue;
+      SmallVector<OpOperand *> operands =
+          bufferizableOp.getAliasingOpOperand(opResult);
+      assert(!operands.empty() &&
+             "expected that OpResult has aliasing OpOperand");
+      for (OpOperand *operand : operands)
+        aliasInfo.unionSets(operand->get(), opResult);
+      setInPlaceOpResult(opResult, InPlaceSpec::True);
     }
   });
 }
@@ -2753,6 +2756,12 @@
     return true;
   }
 
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+    // IfOp results always bufferize in-place. Since they have no OpOperands,
+    // they are mostly ignored by the analysis once alias sets are set up.
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,