diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -423,12 +423,11 @@ private: friend LogicalResult runComprehensiveBufferize(Operation *op, const BufferizationOptions &options, - BufferizationState &state, - const PostAnalysisStepList &extraSteps); + BufferizationState &state); friend LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, - const BufferizationOptions &options); + std::unique_ptr options); /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runComprehensiveBufferize` may access this object. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -18,17 +18,17 @@ struct BufferizationOptions; class BufferizationState; -struct PostAnalysisStep; /// Bufferize the given operation. Reuses an existing BufferizationState object. -LogicalResult runComprehensiveBufferize( - Operation *op, const BufferizationOptions &options, - BufferizationState &state, - const std::vector> &extraSteps); +/// This function overload is for internal usage only. +LogicalResult runComprehensiveBufferize(Operation *op, + const BufferizationOptions &options, + BufferizationState &state); /// Bufferize the given operation. -LogicalResult runComprehensiveBufferize(Operation *op, - const BufferizationOptions &options); +LogicalResult +runComprehensiveBufferize(Operation *op, + std::unique_ptr options); } // namespace comprehensive_bufferize } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H +#include + namespace mlir { class DialectRegistry; @@ -22,8 +24,9 @@ /// Bufferize the given module. This bufferizations performs a simple function /// call analysis to determine which function arguments are inplaceable. -LogicalResult runComprehensiveBufferize(ModuleOp moduleOp, - const BufferizationOptions &options); +LogicalResult +runComprehensiveBufferize(ModuleOp moduleOp, + std::unique_ptr options); namespace std_ext { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -650,15 +650,14 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - Operation *op, const BufferizationOptions &options) { - BufferizationState state(op, options); - PostAnalysisStepList extraSteps; - return runComprehensiveBufferize(op, options, state, extraSteps); + Operation *op, std::unique_ptr options) { + BufferizationState state(op, *options); + return runComprehensiveBufferize(op, *options, state); } LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( Operation *op, const BufferizationOptions &options, - BufferizationState &state, const PostAnalysisStepList &extraSteps) { + BufferizationState &state) { DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.aliasInfo; @@ -672,23 +671,16 @@ return failure(); equivalenceAnalysis(op, aliasInfo, state); - auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { - for (const std::unique_ptr &step : steps) { - SmallVector newOps; - if (failed(step->run(op, state, aliasInfo, newOps))) - return failure(); - // Analyze ops that were created by the PostAnalysisStep. - if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) - return failure(); - equivalenceAnalysis(newOps, aliasInfo, state); - } - return success(); - }; - - if (failed(runPostAnalysisSteps(extraSteps))) - return failure(); - if (failed(runPostAnalysisSteps(options.postAnalysisSteps))) - return failure(); + for (const std::unique_ptr &step : + options.postAnalysisSteps) { + SmallVector newOps; + if (failed(step->run(op, state, aliasInfo, newOps))) + return failure(); + // Analyze ops that were created by the PostAnalysisStep. + if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) + return failure(); + equivalenceAnalysis(newOps, aliasInfo, state); + } // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -724,8 +724,8 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( - ModuleOp moduleOp, const BufferizationOptions &options) { - BufferizationState state(moduleOp, options); + ModuleOp moduleOp, std::unique_ptr options) { + BufferizationState state(moduleOp, *options); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.aliasInfo; @@ -743,24 +743,23 @@ if (funcOp.body().empty()) continue; - // Register extra post analysis steps. These cannot be stored in `options` - // because `options` is immutable. - PostAnalysisStepList extraSteps; - extraSteps.emplace_back(std::make_unique()); + // Collect bbArg/return value information after the analysis. + options->postAnalysisSteps.emplace_back( + std::make_unique()); // Gather equivalence info for CallOps. equivalenceAnalysis(funcOp, aliasInfo, moduleState); // Analyze and bufferize funcOp. - if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps))) + if (failed(runComprehensiveBufferize(funcOp, *options, state))) return failure(); // Add annotations to function arguments. - if (options.testAnalysisOnly) + if (options->testAnalysisOnly) annotateOpsWithBufferizationMarkers(funcOp, state); } - if (options.testAnalysisOnly) + if (options->testAnalysisOnly) return success(); for (FuncOp funcOp : moduleState.orderedFuncOps) { @@ -769,7 +768,7 @@ if (failed(bufferizeFuncOpBoundary(funcOp, state))) return failure(); - if (!options.allowReturnMemref && + if (!options->allowReturnMemref && llvm::any_of(funcOp.getType().getResults(), [](Type t) { return t.isa(); })) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -73,42 +73,42 @@ } void LinalgComprehensiveModuleBufferize::runOnOperation() { - BufferizationOptions options; + auto options = std::make_unique(); if (useAlloca) { - options.allocationFns->allocationFn = allocationFnUsingAlloca; - options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc, - Value v) {}; + options->allocationFns->allocationFn = allocationFnUsingAlloca; + options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc, + Value v) {}; } // TODO: Change to memref::CopyOp (default memCpyFn). - options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from, - Value to) { + options->allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from, + Value to) { b.create(loc, from, to); }; - options.allowReturnMemref = allowReturnMemref; - options.allowUnknownOps = allowUnknownOps; - options.analysisFuzzerSeed = analysisFuzzerSeed; - options.testAnalysisOnly = testAnalysisOnly; - options.printConflicts = printConflicts; + options->allowReturnMemref = allowReturnMemref; + options->allowUnknownOps = allowUnknownOps; + options->analysisFuzzerSeed = analysisFuzzerSeed; + options->testAnalysisOnly = testAnalysisOnly; + options->printConflicts = printConflicts; // Enable InitTensorOp elimination. - options.addPostAnalysisStep< + options->addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); // TODO: Find a way to enable this step automatically when bufferizing tensor // dialect ops. - options.addPostAnalysisStep(); + options->addPostAnalysisStep(); if (!allowReturnMemref) - options.addPostAnalysisStep(); + options->addPostAnalysisStep(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); - if (failed(runComprehensiveBufferize(moduleOp, options))) { + if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) { signalPassFailure(); return; } - if (options.testAnalysisOnly) + if (testAnalysisOnly) return; OpPassManager cleanupPipeline("builtin.module"); diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -94,30 +94,30 @@ } // namespace void TestComprehensiveFunctionBufferize::runOnFunction() { - BufferizationOptions options; + auto options = std::make_unique(); // Enable InitTensorOp elimination. - options.addPostAnalysisStep< + options->addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); // TODO: Find a way to enable this step automatically when bufferizing // tensor dialect ops. - options.addPostAnalysisStep(); + options->addPostAnalysisStep(); if (!allowReturnMemref) - options.addPostAnalysisStep(); + options->addPostAnalysisStep(); - options.allowReturnMemref = allowReturnMemref; - options.allowUnknownOps = allowUnknownOps; - options.testAnalysisOnly = testAnalysisOnly; - options.analysisFuzzerSeed = analysisFuzzerSeed; + options->allowReturnMemref = allowReturnMemref; + options->allowUnknownOps = allowUnknownOps; + options->testAnalysisOnly = testAnalysisOnly; + options->analysisFuzzerSeed = analysisFuzzerSeed; if (dialectFilter.hasValue()) { - options.dialectFilter.emplace(); + options->dialectFilter.emplace(); for (const std::string &dialectNamespace : dialectFilter) - options.dialectFilter->insert(dialectNamespace); + options->dialectFilter->insert(dialectNamespace); } Operation *op = getFunction().getOperation(); - if (failed(runComprehensiveBufferize(op, options))) + if (failed(runComprehensiveBufferize(op, std::move(options)))) return; OpPassManager cleanupPipeline("builtin.func");