diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -86,6 +86,9 @@ using namespace comprehensive_bufferize; using namespace mlir::bufferization; +/// A mapping of FuncOps to their callers. +using FuncCallerMap = DenseMap>; + namespace { /// The state of analysis of a FuncOp. enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed }; @@ -128,12 +131,6 @@ /// analyzed. DenseMap analyzedFuncOps; - /// A list of functions in the order in which they are analyzed + bufferized. - SmallVector orderedFuncOps; - - /// A mapping of FuncOps to their callers. - DenseMap> callerMap; - /// This function is called right before analyzing the given FuncOp. It /// initializes the data structures for the FuncOp in this state object. void startFunctionAnalysis(FuncOp funcOp) { @@ -570,7 +567,7 @@ static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl &orderedFuncOps, - DenseMap> &callerMap) { + FuncCallerMap &callerMap) { // For each FuncOp, the set of functions called by it (i.e. the union of // symbols of all nested CallOpInterfaceOp). DenseMap> calledBy; @@ -619,9 +616,8 @@ return success(); } -static void -foreachCaller(const DenseMap> &callerMap, - FuncOp callee, llvm::function_ref doit) { +static void foreachCaller(const FuncCallerMap &callerMap, FuncOp callee, + llvm::function_ref doit) { auto itCallers = callerMap.find(callee); if (itCallers == callerMap.end()) return; @@ -1069,8 +1065,13 @@ FuncAnalysisState &funcState = getFuncAnalysisState(analysisState); BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); - if (failed(getFuncOpsOrderedByCalls(moduleOp, funcState.orderedFuncOps, - funcState.callerMap))) + // A list of functions in the order in which they are analyzed + bufferized. + SmallVector orderedFuncOps; + + // A mapping of FuncOps to their callers. + FuncCallerMap callerMap; + + if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); // Collect bbArg/return value information after the analysis. @@ -1078,7 +1079,7 @@ options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); // Analyze ops. - for (FuncOp funcOp : funcState.orderedFuncOps) { + for (FuncOp funcOp : orderedFuncOps) { // No body => no analysis. if (funcOp.getBody().empty()) continue; @@ -1105,7 +1106,7 @@ return success(); // Bufferize functions. - for (FuncOp funcOp : funcState.orderedFuncOps) { + for (FuncOp funcOp : orderedFuncOps) { // No body => no analysis. if (!funcOp.getBody().empty()) if (failed(bufferizeOp(funcOp, bufferizationState))) @@ -1118,7 +1119,7 @@ } // Check result. - for (FuncOp funcOp : funcState.orderedFuncOps) { + for (FuncOp funcOp : orderedFuncOps) { if (!options.allowReturnAllocs && llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) { return t.isa();