diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -16,14 +16,21 @@ namespace linalg { namespace comprehensive_bufferize { +class BufferizationAliasInfo; struct BufferizationOptions; class BufferizationState; +/// Analyze `op` and its nested ops. Bufferization decisions are stored in +/// `state`. +LogicalResult analyzeOp(Operation *op, BufferizationState &state); + /// Bufferize the given operation. Reuses an existing BufferizationState object. +/// If `runAnalysis` is set to false, all OpOperands bufferize out-of-place. /// This function overload is for internal usage only. LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, - BufferizationState &state); + BufferizationState &state, + bool runAnalysis = true); /// Bufferize the given operation. 
LogicalResult diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -618,13 +618,12 @@ return success(); } -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - Operation *op, const BufferizationOptions &options, - BufferizationState &state) { - - IRRewriter rewriter(op->getContext()); +LogicalResult +mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op, + BufferizationState &state) { DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); + const BufferizationOptions &options = state.getOptions(); if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); @@ -647,10 +646,18 @@ } // Annotate operations if we only want to report the analysis. - if (options.testAnalysisOnly) { + if (options.testAnalysisOnly) annotateOpsWithBufferizationMarkers(op, aliasInfo, state); - return success(); - } + + return success(); +} + +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + Operation *op, const BufferizationOptions &options, + BufferizationState &state, bool runAnalysis) { + if (runAnalysis) + if (failed(analyzeOp(op, state))) + return failure(); // Bufferize the op and its nested ops. OwningRewritePatternList patterns(op->getContext()); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -842,6 +842,8 @@ // inplace. 
Therefore, we just bufferize funcOp as if none of its results were // inplaceable, detect which operands are cloned internally and decide what to // do at call sites. + + // Analyze ops. for (FuncOp funcOp : moduleState.orderedFuncOps) { // No body => no analysis. if (funcOp.body().empty()) @@ -854,8 +856,8 @@ // Gather equivalence info for CallOps. equivalenceAnalysis(funcOp, aliasInfo, moduleState); - // Analyze and bufferize funcOp. - if (failed(runComprehensiveBufferize(funcOp, *options, state))) + // Analyze funcOp. + if (failed(analyzeOp(funcOp, state))) return failure(); // Add annotations to function arguments. @@ -866,6 +868,18 @@ if (options->testAnalysisOnly) return success(); + // Bufferize function bodies. + for (FuncOp funcOp : moduleState.orderedFuncOps) { + // No body => nothing to bufferize. + if (funcOp.body().empty()) + continue; + + if (failed(runComprehensiveBufferize(funcOp, *options, state, + /*runAnalysis=*/false))) + return failure(); + } + + // Bufferize function boundaries. for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated.