diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -15,8 +15,22 @@
 class ModuleOp;
 
 namespace bufferization {
+struct BufferizationState;
+class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
+/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
+/// `state`. The options of `state` must be `OneShotBufferizationOptions` with
+/// `bufferizeFunctionBoundaries` enabled.
+LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state);
+
+/// Bufferize `moduleOp` and its nested ops that implement
+/// `BufferizableOpInterface`. Whether buffer copies are needed or not is
+/// queried from `analysisState`, which is expected to have been populated by a
+/// prior call to `analyzeModuleOp` on the same `moduleOp`.
+LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
+                                const OneShotAnalysisState &analysisState);
+
 /// Run One-Shot Module Bufferization on the given module. Performs a simple
 /// function call analysis to determine which function arguments are
 /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -380,15 +380,15 @@
   funcOp.setType(newFuncType);
 }
 
-LogicalResult mlir::bufferization::runOneShotModuleBufferize(
-    ModuleOp moduleOp, OneShotBufferizationOptions options) {
+LogicalResult
+mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+                                     OneShotAnalysisState &state) {
+  const auto &options =
+      static_cast<const OneShotBufferizationOptions &>(state.getOptions());
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
-  IRRewriter rewriter(moduleOp.getContext());
-  OneShotAnalysisState analysisState(moduleOp, options);
-  BufferizationState bufferizationState(analysisState);
-  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
-  BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();
+  FuncAnalysisState &funcState = getFuncAnalysisState(state);
+  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 
   // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;
@@ -412,12 +412,12 @@
     equivalenceAnalysis(funcOp, aliasInfo, funcState);
 
     // Analyze funcOp.
-    if (failed(analyzeOp(funcOp, analysisState)))
+    if (failed(analyzeOp(funcOp, state)))
       return failure();
 
     // Run some extra function analyses.
-    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, analysisState)) ||
-        failed(funcOpBbArgReadWriteAnalysis(funcOp, analysisState)))
+    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) ||
+        failed(funcOpBbArgReadWriteAnalysis(funcOp, state)))
       return failure();
 
     // Mark op as fully analyzed.
@@ -425,11 +425,29 @@
 
     // Add annotations to function arguments.
     if (options.testAnalysisOnly)
-      annotateOpsWithBufferizationMarkers(funcOp, analysisState);
+      annotateOpsWithBufferizationMarkers(funcOp, state);
   }
 
-  if (options.testAnalysisOnly)
-    return success();
+  return success();
+}
+
+LogicalResult mlir::bufferization::bufferizeModuleOp(
+    ModuleOp moduleOp, const OneShotAnalysisState &analysisState) {
+  const auto &options = static_cast<const OneShotBufferizationOptions &>(
+      analysisState.getOptions());
+  assert(options.bufferizeFunctionBoundaries &&
+         "expected that function boundary bufferization is activated");
+  IRRewriter rewriter(moduleOp.getContext());
+  BufferizationState bufferizationState(analysisState);
+
+  // Functions to bufferize, ordered by getFuncOpsOrderedByCalls.
+  SmallVector<func::FuncOp> orderedFuncOps;
+
+  // A mapping of FuncOps to their callers.
+  FuncCallerMap callerMap;
+
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+    return failure();
 
   // Bufferize functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
@@ -466,3 +484,16 @@
 
   return success();
 }
+
+LogicalResult mlir::bufferization::runOneShotModuleBufferize(
+    ModuleOp moduleOp, OneShotBufferizationOptions options) {
+  assert(options.bufferizeFunctionBoundaries &&
+         "expected that function boundary bufferization is activated");
+  OneShotAnalysisState analysisState(moduleOp, options);
+  if (failed(analyzeModuleOp(moduleOp, analysisState)))
+    return failure();
+  // In test mode, only the analysis (which annotates the IR) is run.
+  if (options.testAnalysisOnly)
+    return success();
+  return bufferizeModuleOp(moduleOp, analysisState);
+}