diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -24,15 +24,11 @@ /// `state`. LogicalResult analyzeOp(Operation *op, BufferizationState &state); -/// Bufferize the given operation. Reuses an existing BufferizationState object. -/// If `runAnalysis` is set to false, all OpOperands bufferize out-of-place. -/// This function overload is for internal usage only. -LogicalResult runComprehensiveBufferize(Operation *op, - const BufferizationOptions &options, - BufferizationState &state, - bool runAnalysis = true); - -/// Bufferize the given operation. +/// Bufferize `op` and its nested ops. Bufferization decisions are stored in +/// `state`. +LogicalResult bufferizeOp(Operation *op, BufferizationState &state); + +/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization LogicalResult runComprehensiveBufferize(Operation *op, std::unique_ptr options); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -22,8 +22,9 @@ struct BufferizationOptions; -/// Bufferize the given module. This bufferizations performs a simple function -/// call analysis to determine which function arguments are inplaceable. +/// Run Module Bufferization on the given module. Performs a simple function +/// call analysis to determine which function arguments are inplaceable. Then +/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization. LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, std::unique_ptr options); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -555,12 +555,6 @@ }); } -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - Operation *op, std::unique_ptr options) { - BufferizationState state(op, *options); - return runComprehensiveBufferize(op, *options, state); -} - /// Rewrite pattern that bufferizes bufferizable ops. struct BufferizationPattern : public OpInterfaceRewritePattern { @@ -652,18 +646,26 @@ return success(); } -LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - Operation *op, const BufferizationOptions &options, - BufferizationState &state, bool runAnalysis) { - if (runAnalysis) - if (failed(analyzeOp(op, state))) - return failure(); - +LogicalResult +mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op, + BufferizationState &state) { // Bufferize the op and its nested ops. OwningRewritePatternList patterns(op->getContext()); patterns.add(op->getContext(), state); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return failure(); - return checkBufferizationResult(op, options); + return checkBufferizationResult(op, state.getOptions()); +} + +LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( + Operation *op, std::unique_ptr options) { + BufferizationState state(op, *options); + if (failed(analyzeOp(op, state))) + return failure(); + if (options->testAnalysisOnly) + return success(); + if (failed(bufferizeOp(op, state))) + return failure(); + return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -886,8 +886,7 @@ if (funcOp.body().empty()) continue; - if (failed(runComprehensiveBufferize(funcOp, *options, state, - /*runAnalysis=*/false))) + if (failed(bufferizeOp(funcOp, state))) return failure(); } diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -117,6 +117,9 @@ if (failed(runComprehensiveBufferize(op, std::move(options)))) return; + if (testAnalysisOnly) + return; + OpPassManager cleanupPipeline("builtin.func"); cleanupPipeline.addPass(createCanonicalizerPass()); cleanupPipeline.addPass(createCSEPass());