diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -89,6 +89,25 @@
           });
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the given OpResult must bufferize in-place with its
+          corresponding aliasing OpOperand. Alias sets and inplace attributes
+          will be set up accordingly before making any other bufferization
+          decisions. This method will never be called on OpResults that do not
+          have a tensor type.
+
+          Note: This method must not return `true` if the given OpResult does
+          not have an aliasing OpOperand.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"mustBufferizeInPlace",
+        /*args=*/(ins "OpResult":$opResult),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the OpResult that aliases with a given OpOperand when
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -538,18 +538,20 @@
             createAliasInfoEntry(bbArg);
   });
 
-  // The return value of an scf::IfOp aliases with both yield values.
-  rootOp->walk([&](scf::IfOp ifOp) {
-    if (ifOp->getNumResults() > 0) {
-      for (auto it : llvm::zip(ifOp.thenYield().results(),
-                               ifOp.elseYield().results(), ifOp.results())) {
-        aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
-        aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
-      }
-
-      // scf::IfOp always bufferizes in-place.
-      for (OpResult opResult : ifOp->getResults())
-        setInPlaceOpResult(opResult, InPlaceSpec::True);
+  // Set up alias sets for tensor OpResults that must bufferize in-place. This
+  // should be done before making any other bufferization decisions.
+  rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
+    for (OpResult opResult : bufferizableOp->getOpResults()) {
+      if (opResult.getType().isa<TensorType>())
+        if (bufferizableOp.mustBufferizeInPlace(opResult)) {
+          SmallVector<OpOperand *> operands =
+              bufferizableOp.getAliasingOpOperand(opResult);
+          assert(!operands.empty() &&
+                 "expected that OpResult has aliasing OpOperand");
+          for (OpOperand *operand : operands)
+            aliasInfo.unionSets(operand->get(), opResult);
+          setInPlaceOpResult(opResult, InPlaceSpec::True);
+        }
     }
   });
 }
@@ -951,9 +953,14 @@
 /// * However, adding an alias {%0, %t} would mean that the second
 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
 ///   would no longer be reading the result of %1.
+///
+/// If `checkConsistencyOnly` is true, the write of `operand` itself is not
+/// considered, so a reported conflict means that the current inplace
+/// bufferization decisions are already inconsistent.
 bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
-    const BufferizationAliasInfo &aliasInfo) {
+    const BufferizationAliasInfo &aliasInfo,
+    bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
   SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
   assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@@ -986,7 +993,7 @@
   getAliasingReads(usesRead, result);
   getAliasingInplaceWrites(usesWrite, operand.get());
   getAliasingInplaceWrites(usesWrite, result);
-  if (bufferizesToMemoryWrite(operand))
+  if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
@@ -2229,6 +2236,27 @@
   });
 }
 
+#ifndef NDEBUG
+/// Assert that the current bufferization decisions are consistent. Only
+/// compiled when assertions are enabled, matching the guard at the call site.
+static void checkAliasInfoConsistency(FuncOp funcOp,
+                                      const DominanceInfo &domInfo,
+                                      const BufferizationAliasInfo &aliasInfo) {
+  funcOp.walk([&](Operation *op) {
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      for (OpOperand &opOperand : op->getOpOperands())
+        if (opOperand.get().getType().isa<TensorType>())
+          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
+            // If this assertion fails, there is probably an inconsistent
+            // combination of "mustBufferizeInPlace" decisions.
+            assert(!wouldCreateReadAfterWriteInterference(
+                       opOperand, opResult, domInfo, aliasInfo,
+                       /*checkConsistencyOnly=*/true) &&
+                   "found read after write conflict before running analysis");
+  });
+}
+#endif // NDEBUG
+
 LogicalResult
 mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
                                         const BufferizationOptions &options) {
@@ -2260,6 +2288,10 @@
         if (bbArg.getType().isa<TensorType>())
           setInPlaceFuncArgument(bbArg);
 
+#ifndef NDEBUG
+    checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
+#endif // NDEBUG
+
     // If the analysis fails, just return.
     if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
                                          options.analysisFuzzerSeed)))
@@ -2778,6 +2810,12 @@
     return true;
   }
 
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+    // IfOp results always bufferize in-place. IfOps have no tensor OpOperands,
+    // so they are mostly ignored by the analysis once alias sets are set up.
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,