diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -417,7 +417,7 @@ private: friend LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, - BufferizationState &state); + BufferizationState &state, bool runAnalysis); friend LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -16,14 +16,24 @@ namespace linalg { namespace comprehensive_bufferize { +class BufferizationAliasInfo; struct BufferizationOptions; class BufferizationState; +/// Analyze `op` and its nested ops. Bufferization decisions are stored in +/// `state` and `aliasInfo`. +LogicalResult analyzeOp(Operation *op, const BufferizationOptions &options, + BufferizationState &state, + BufferizationAliasInfo &aliasInfo); + /// Bufferize the given operation. Reuses an existing BufferizationState object. +/// If `runAnalysis` is false, no analysis is run: OpOperands bufferize +/// in-place only if `state` already records that decision, else out-of-place. /// This function overload is for internal usage only. LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, - BufferizationState &state); + BufferizationState &state, + bool runAnalysis = true); /// Bufferize the given operation. 
LogicalResult diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -621,13 +621,10 @@ return success(); } -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( +LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp( Operation *op, const BufferizationOptions &options, - BufferizationState &state) { - - IRRewriter rewriter(op->getContext()); + BufferizationState &state, BufferizationAliasInfo &aliasInfo) { DominanceInfo domInfo(op); - BufferizationAliasInfo &aliasInfo = state.aliasInfo; if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); @@ -650,10 +647,20 @@ } // Annotate operations if we only want to report the analysis. - if (options.testAnalysisOnly) { + if (options.testAnalysisOnly) annotateOpsWithBufferizationMarkers(op, aliasInfo, state); - return success(); - } + + return success(); +} + +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + Operation *op, const BufferizationOptions &options, + BufferizationState &state, bool runAnalysis) { + BufferizationAliasInfo &aliasInfo = state.aliasInfo; + + if (runAnalysis) + if (failed(analyzeOp(op, options, state, aliasInfo))) + return failure(); // Bufferize the op and its nested ops. OwningRewritePatternList patterns(op->getContext()); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -844,6 +844,8 @@ // inplace. 
Therefore, we just bufferize funcOp as if none of its results were // inplaceable, detect which operands are cloned internally and decide what to // do at call sites. + + // Analyze ops. for (FuncOp funcOp : moduleState.orderedFuncOps) { // No body => no analysis. if (funcOp.body().empty()) @@ -856,8 +858,8 @@ // Gather equivalence info for CallOps. equivalenceAnalysis(funcOp, aliasInfo, moduleState); - // Analyze and bufferize funcOp. - if (failed(runComprehensiveBufferize(funcOp, *options, state))) + // Analyze funcOp. + if (failed(analyzeOp(funcOp, *options, state, aliasInfo))) return failure(); // Add annotations to function arguments. @@ -868,6 +870,18 @@ if (options->testAnalysisOnly) return success(); + // Bufferize function bodies. + for (FuncOp funcOp : moduleState.orderedFuncOps) { + // No body => nothing to bufferize. + if (funcOp.body().empty()) + continue; + + if (failed(runComprehensiveBufferize(funcOp, *options, state, + /*runAnalysis=*/false))) + return failure(); + } + + // Bufferize function boundaries. for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated.