diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "llvm/ADT/EquivalenceClasses.h"
+#include <string>
 
 namespace mlir {
 namespace bufferization {
@@ -33,6 +34,11 @@
   /// The heuristic controls the order in which ops are traversed during the
   /// analysis.
   AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp;
+
+  /// Symbol names of the functions that should not be analyzed.
+  /// copyBeforeWrite will be set to true when bufferizing them. This is a
+  /// non-owning reference: the referenced strings must outlive these options.
+  llvm::ArrayRef<std::string> noAnalysisFuncFilter;
 };
 
 /// The BufferizationAliasInfo class maintains a list of buffer aliases and
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
 
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 
 namespace mlir {
 struct LogicalResult;
@@ -31,13 +30,12 @@
 /// inserted except two cases:
 /// - `options.copyBeforeWrite` is set, in which case buffers are copied before
 ///   every write.
-/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true
-///   for some FuncOps. These FuncOps were not analyzed. Buffer copies will be
-///   inserted only to these FuncOps.
-LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationStatistics *statistics = nullptr,
-                  OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
+/// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter`
+///   is not empty. The FuncOps it contains were not analyzed. Buffer copies
+///   will be inserted only into these FuncOps.
+LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
+                                const OneShotBufferizationOptions &options,
+                                BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
 void removeBufferizationAttributesInModule(ModuleOp moduleOp);
@@ -49,8 +47,7 @@
 
 LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics = nullptr,
-    OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -215,6 +215,7 @@
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
+      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
 
       // Configure type converter.
       LayoutMapOption unknownTypeConversionOption =
@@ -249,25 +250,12 @@
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      OpFilter::Entry::FilterFn analysisFilterFn = nullptr;
-      // FuncOps whose names are specified in noAnalysisFuncFilter will not be
-      // analyzed. Ops in these FuncOps will not be analyzed as well.
-      if (this->noAnalysisFuncFilter.hasValue())
-        analysisFilterFn = [=](Operation *op) {
-          auto func = dyn_cast<func::FuncOp>(op);
-          if (!func)
-            func = op->getParentOfType<func::FuncOp>();
-          if (func)
-            return llvm::is_contained(noAnalysisFuncFilter, func.getSymName());
-          return false;
-        };
-      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics,
-                                           analysisFilterFn))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
     } else {
-      assert(!this->noAnalysisFuncFilter.hasValue() &&
+      assert(opt.noAnalysisFuncFilter.empty() &&
              "invalid combination of bufferization flags");
       if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -413,8 +413,7 @@
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics,
-    OpFilter::Entry::FilterFn analysisFilterFn) {
+    BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
@@ -432,8 +431,9 @@
   for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    bool copyBeforeWrite = options.copyBeforeWrite ||
-                           (analysisFilterFn && analysisFilterFn(funcOp));
+    bool copyBeforeWrite =
+        options.copyBeforeWrite ||
+        llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName());
     if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
                            /*opFilter=*/nullptr, statistics)))
       return failure();
@@ -451,17 +451,29 @@
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-    BufferizationStatistics *statistics,
-    OpFilter::Entry::FilterFn analysisFilterFn) {
+    BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
          "invalid combination of bufferization flags");
   if (!options.copyBeforeWrite) {
-    if (!analysisFilterFn) {
+    if (options.noAnalysisFuncFilter.empty()) {
       if (failed(insertTensorCopies(moduleOp, options, statistics)))
         return failure();
     } else {
+      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
+      // not be analyzed. Ops in these FuncOps will not be analyzed as well.
+      // Capture only the (non-owning) list of names: a by-copy capture of
+      // `options` would copy the entire options struct into the filter.
+      OpFilter::Entry::FilterFn analysisFilterFn =
+          [funcNames = options.noAnalysisFuncFilter](Operation *op) {
+            auto func = dyn_cast<func::FuncOp>(op);
+            if (!func)
+              func = op->getParentOfType<func::FuncOp>();
+            if (func)
+              return llvm::is_contained(funcNames, func.getSymName());
+            return false;
+          };
       OneShotBufferizationOptions updatedOptions(options);
       updatedOptions.opFilter.denyOperation(analysisFilterFn);
       if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
@@ -470,9 +482,7 @@
   }
   if (options.testAnalysisOnly)
     return success();
-
-  if (failed(
-          bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn)))
+  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
     return failure();
   return success();
 }