diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -78,23 +78,7 @@
   /// unless they are explicitly marked as DENY. If the filter has at least one
   /// ALLOW rule, ops are ignored by default and only bufferized if they match
   /// an ALLOW rule and no DENY rule.
-  bool isOpAllowed(Operation *op) const {
-    bool isAllowed = !filterHasAllowRule();
-    for (const OpFilterEntry &entry : opFilter) {
-      bool filterResult = entry.fn(op);
-      switch (entry.type) {
-      case OpFilterEntry::ALLOW:
-        isAllowed |= filterResult;
-        break;
-      case OpFilterEntry::DENY:
-        if (filterResult)
-          // DENY filter matches. This op is no allowed. (Even if other ALLOW
-          // filters may match.)
-          return false;
-      };
-    }
-    return isAllowed;
-  }
+  bool isOpAllowed(Operation *op) const;
 
   /// Allow the given dialects in the filter.
   ///
@@ -182,6 +166,10 @@
   /// the boundaries.
   bool allowUnknownOps = false;
 
+  /// Specifies whether function boundaries (ops in the func dialect) should be
+  /// bufferized or not.
+  bool bufferizeFunctionBoundaries = false;
+
   /// Specifies whether dealloc ops should be generated along with alloc ops. If
   /// not, new memory allocations will leak.
   bool createDeallocs = true;
@@ -356,6 +344,12 @@
   /// any given tensor.
   virtual bool isTensorYielded(Value tensor) const = 0;
 
+  /// Return `true` if the given dialect state exists.
+  bool hasDialectState(StringRef name) const {
+    auto it = dialectState.find(name);
+    return it != dialectState.end();
+  }
+
   /// Return dialect-specific bufferization state.
   template <typename StateT>
   Optional<const StateT *> getDialectState(StringRef name) const {
@@ -369,7 +363,7 @@
   template <typename StateT>
   StateT &getOrCreateDialectState(StringRef name) {
     // Create state if it does not exist yet.
-    if (!dialectState.count(name))
+    if (!hasDialectState(name))
       dialectState[name] = std::make_unique<StateT>();
     return static_cast<StateT &>(*dialectState[name]);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -51,6 +51,31 @@
 // Default constructor for BufferizationOptions.
 BufferizationOptions::BufferizationOptions() = default;
 
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+  // Special case: If function boundary bufferization is deactivated, do not
+  // allow ops that belong to the `func` dialect.
+  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+    return false;
+
+  // All other ops: Allow/disallow according to filter.
+  bool isAllowed = !filterHasAllowRule();
+  for (const OpFilterEntry &entry : opFilter) {
+    bool filterResult = entry.fn(op);
+    switch (entry.type) {
+    case OpFilterEntry::ALLOW:
+      isAllowed |= filterResult;
+      break;
+    case OpFilterEntry::DENY:
+      if (filterResult)
+        // DENY filter matches. This op is not allowed. (Even if other ALLOW
+        // filters may match.)
+        return false;
+    };
+  }
+  return isAllowed;
+}
+
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
   if (isOpAllowed(op))
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -175,15 +175,10 @@
       opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
+      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
 
       BufferizationOptions::OpFilterEntry::FilterFn filterFn =
           [&](Operation *op) {
-            // Disallow non-func dialect ops. I.e., no ops related to function
-            // calls. (Unless explicitly activated.)
-            bool isFuncBoundaryOp =
-                isa_and_nonnull<func::FuncDialect>(op->getDialect());
-            if (!this->bufferizeFunctionBoundaries && isFuncBoundaryOp)
-              return false;
             // Filter may be specified via options.
             if (this->dialectFilter.hasValue())
               return llvm::find(this->dialectFilter,
@@ -198,7 +193,7 @@
     }
 
     ModuleOp moduleOp = getOperation();
-    if (bufferizeFunctionBoundaries) {
+    if (opt.bufferizeFunctionBoundaries) {
       if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
         signalPassFailure();
         return;
@@ -284,6 +279,12 @@
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const AnalysisState &analysisState) {
+  // Catch incorrect API usage.
+  assert((analysisState.hasDialectState(
+              func::FuncDialect::getDialectNamespace()) ||
+          !analysisState.getOptions().bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   BufferizationState bufferizationState(analysisState);
   if (failed(bufferizeOp(op, bufferizationState)))
     return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dominance.h"
@@ -864,6 +865,11 @@
   const auto &options =
       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
+  // Catch incorrect API usage.
+  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
+          !options.bufferizeFunctionBoundaries) &&
+         "must use ModuleBufferize to bufferize function boundaries");
+
   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -417,6 +417,8 @@
 
 LogicalResult mlir::bufferization::runOneShotModuleBufferize(
     ModuleOp moduleOp, OneShotBufferizationOptions options) {
+  assert(options.bufferizeFunctionBoundaries &&
+         "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
   OneShotAnalysisState analysisState(moduleOp, options);
   BufferizationState bufferizationState(analysisState);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -99,6 +99,7 @@
     opt.printConflicts = printConflicts;
     opt.testAnalysisOnly = testAnalysisOnly;
     opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
+    opt.bufferizeFunctionBoundaries = true;
     if (initTensorElimination) {
       opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
     }