diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -93,27 +93,49 @@
 /// Extra analysis state that is required for bufferization of function
 /// boundaries.
 struct ModuleAnalysisState : public DialectAnalysisState {
+  /// A set of block argument indices.
+  using BbArgIndexSet = DenseSet;
+
+  /// A mapping of indices to indices.
+  using IndexMapping = DenseMap;
+
   /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
   /// indices.
-  DenseMap> equivalentFuncArgs;
+  DenseMap equivalentFuncArgs;
 
-  /// A set of all read BlockArguments of FuncOps.
-  // Note: BlockArgument knows about its owner, so we do not need to store
-  // FuncOps here.
-  DenseSet readBbArgs;
+  /// A mapping of FuncOps to the indices of their read BlockArguments.
+  DenseMap readBbArgs;
 
-  /// A set of all written-to BlockArguments of FuncOps.
-  DenseSet writtenBbArgs;
+  /// A mapping of FuncOps to the indices of their written-to BlockArguments.
+  DenseMap writtenBbArgs;
 
   /// Keep track of which FuncOps are fully analyzed or currently being
   /// analyzed.
   DenseMap analyzedFuncOps;
 
-  // A list of functions in the order in which they are analyzed + bufferized.
+  /// A list of functions in the order in which they are analyzed + bufferized.
   SmallVector orderedFuncOps;
 
-  // A mapping of FuncOps to their callers.
+  /// A mapping of FuncOps to their callers.
   DenseMap> callerMap;
+
+  /// This function is called right before analyzing the given FuncOp. It
+  /// initializes the data structures for the FuncOp in this state object and
+  /// asserts that none of them was initialized for this FuncOp before.
+  void startFunctionAnalysis(FuncOp funcOp) {
+    analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
+    auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
+    auto createdRead = readBbArgs.try_emplace(funcOp, BbArgIndexSet());
+    auto createdWritten = writtenBbArgs.try_emplace(funcOp, BbArgIndexSet());
+    // The casts keep release builds (where assert compiles away) free of
+    // unused-variable warnings.
+    (void)createdEquiv;
+    (void)createdRead;
+    (void)createdWritten;
+    assert(createdEquiv.second && "equivalence info exists already");
+    assert(createdRead.second && "bbarg access info exists already");
+    assert(createdWritten.second && "bbarg access info exists already");
+  }
 };
 
 } // namespace
@@ -267,8 +289,8 @@
   // read + written.
   if (funcOp.getBody().empty()) {
     for (BlockArgument bbArg : funcOp.getArguments()) {
-      moduleState.readBbArgs.insert(bbArg);
-      moduleState.writtenBbArgs.insert(bbArg);
+      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
     }
 
     return success();
@@ -282,9 +304,9 @@
     if (state.getOptions().testAnalysisOnly)
       annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten);
     if (isRead)
-      moduleState.readBbArgs.insert(bbArg);
+      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
     if (isWritten)
-      moduleState.writtenBbArgs.insert(bbArg);
+      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
   }
 
   return success();
@@ -704,8 +726,10 @@
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    return moduleState.readBbArgs.contains(
-        funcOp.getArgument(opOperand.getOperandNumber()));
+    // Use find() rather than lookup(): lookup() returns the set by value.
+    auto it = moduleState.readBbArgs.find(funcOp);
+    return it != moduleState.readBbArgs.end() &&
+           it->second.contains(opOperand.getOperandNumber());
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -719,8 +743,10 @@
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    return moduleState.writtenBbArgs.contains(
-        funcOp.getArgument(opOperand.getOperandNumber()));
+    // Use find() rather than lookup(): lookup() returns the set by value.
+    auto it = moduleState.writtenBbArgs.find(funcOp);
+    return it != moduleState.writtenBbArgs.end() &&
+           it->second.contains(opOperand.getOperandNumber());
   }
 
   SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand,
@@ -1010,7 +1036,7 @@
       continue;
 
     // Now analyzing function.
-    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
+    moduleState.startFunctionAnalysis(funcOp);
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, analysisState)))