diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" namespace mlir { struct LogicalResult; @@ -27,11 +28,16 @@ /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// /// Note: This function does not run One-Shot Analysis. No buffer copies are -/// inserted unless `options.copyBeforeWrite` is set, in which case buffers are -/// copied before every write. -LogicalResult bufferizeModuleOp(ModuleOp moduleOp, - const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics = nullptr); +/// inserted except in two cases: +/// - `options.copyBeforeWrite` is set, in which case buffers are copied before +/// every write. +/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true +/// for some FuncOps. These FuncOps are not analyzed, so buffers are copied +/// before every write inside them (and only inside them). +LogicalResult +bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, + BufferizationStatistics *statistics = nullptr, + OpFilter::Entry::FilterFn analysisFilterFn = nullptr); /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
void removeBufferizationAttributesInModule(ModuleOp moduleOp); @@ -43,7 +49,8 @@ LogicalResult runOneShotModuleBufferize( ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, - BufferizationStatistics *statistics = nullptr); + BufferizationStatistics *statistics = nullptr, + OpFilter::Entry::FilterFn analysisFilterFn = nullptr); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -297,6 +297,9 @@ "core bufferization passes.">, ListOption<"dialectFilter", "dialect-filter", "std::string", "Restrict bufferization to ops from these dialects.">, + ListOption<"noAnalysisFuncFilter", "no-analysis-func-filter", "std::string", + "Skip analysis of functions with these symbol names. " + "Set copyBeforeWrite to true when bufferizing them.">, Option<"functionBoundaryTypeConversion", "function-boundary-type-conversion", "std::string", /*default=*/"\"infer-layout-map\"", diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -249,11 +249,29 @@ BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { - if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) { + OpFilter::Entry::FilterFn analysisFilterFn = nullptr; + // FuncOps whose names are specified in noAnalysisFuncFilter will not be + // analyzed. Ops in these FuncOps will not be analyzed as well.
+ if (this->noAnalysisFuncFilter.hasValue()) + analysisFilterFn = [&](Operation *op) { + auto func = dyn_cast<func::FuncOp>(op); + if (!func) + func = op->getParentOfType<func::FuncOp>(); + if (func) + return llvm::is_contained(noAnalysisFuncFilter, func.getSymName()); + return false; + }; + if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics, + analysisFilterFn))) { signalPassFailure(); return; } } else { + if (this->noAnalysisFuncFilter.hasValue()) { + moduleOp.emitError("'no-analysis-func-filter' requires " + "'bufferize-function-boundaries'"); + return signalPassFailure(); + } if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) { signalPassFailure(); return; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -378,6 +378,9 @@ // Analyze ops. for (func::FuncOp funcOp : orderedFuncOps) { + if (!state.getOptions().isOpAllowed(funcOp)) + continue; + // Now analyzing function. funcState.startFunctionAnalysis(funcOp); @@ -410,7 +413,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics) { + BufferizationStatistics *statistics, + OpFilter::Entry::FilterFn analysisFilterFn) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); IRRewriter rewriter(moduleOp.getContext()); @@ -428,7 +432,9 @@ for (func::FuncOp funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated.
- if (failed(bufferizeOp(funcOp, options, options.copyBeforeWrite, + bool copyBeforeWrite = options.copyBeforeWrite || + (analysisFilterFn && analysisFilterFn(funcOp)); + if (failed(bufferizeOp(funcOp, options, copyBeforeWrite, /*opFilter=*/nullptr, statistics))) return failure(); // Change buffer return types to more precise layout maps. @@ -445,18 +451,26 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics) { + BufferizationStatistics *statistics, + OpFilter::Entry::FilterFn analysisFilterFn) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && "invalid combination of bufferization flags"); if (!options.copyBeforeWrite) { - if (failed(insertTensorCopies(moduleOp, options, statistics))) - return failure(); + // FuncOps selected by analysisFilterFn are excluded from the analysis; + // they are bufferized with copyBeforeWrite later (see bufferizeModuleOp). + OneShotBufferizationOptions analysisOptions(options); + if (analysisFilterFn) + analysisOptions.opFilter.denyOperation(analysisFilterFn); + if (failed(insertTensorCopies(moduleOp, analysisOptions, statistics))) + return failure(); } if (options.testAnalysisOnly) return success(); - if (failed(bufferizeModuleOp(moduleOp, options, statistics))) + + if (failed( + bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn))) return failure(); return success(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1
no-analysis-func-filter=contains_to_memref_op" -drop-equivalent-buffer-results --split-input-file | FileCheck %s + +// ToMemref ops do not pass the analysis step. copyBeforeWrite is enabled only +// for @contains_to_memref_op because it is listed in no-analysis-func-filter. + +module { + // CHECK-LABEL: func.func @foo( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) { + func.func @foo(%arg0: tensor<?xf32>) -> tensor<?xf32> { + // CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index + %cst = arith.constant 1.000000e+00 : f32 + + // CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32 + %c0 = arith.constant 0 : index + + // CHECK-NEXT: memref.store %[[c1]], %[[arg0]]{{\[}}%[[c0]]] : memref<?xf32, strided<[?], offset: ?>> + %inserted = tensor.insert %cst into %arg0[%c0] : tensor<?xf32> + + return %inserted : tensor<?xf32> + } + + // CHECK-LABEL: func.func @contains_to_memref_op( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>, + // CHECK-SAME: %[[arg1:.*]]: index) -> vector<5xf32> { + func.func @contains_to_memref_op(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> vector<5xf32> { + + %0 = bufferization.to_memref %arg0 : memref<?xf32> + + // CHECK: %[[c0:.*]] = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + + // CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32, strided<[?], offset: ?>> + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32> + // CHECK: memref.copy %[[arg0]], %[[alloc]] : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32> + // CHECK: vector.transfer_read + %1 = vector.transfer_read %0[%arg1], %cst : memref<?xf32>, vector<5xf32> + return %1 : vector<5xf32> + } +}