diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -45,32 +45,93 @@ using MemCpyFn = std::function; + /// An op filter entry. Filters can be used to specify which ops should be + /// processed by the bufferization. + struct OpFilterEntry { + /// If the filter function evaluates to `true`, the filter matches. + using FilterFn = std::function<bool(Operation *)>; + + /// Filter type: A filter can either be a DENY filter or an ALLOW filter. + enum FilterType : int8_t { DENY = 0, ALLOW = 1 }; + + FilterFn fn; + FilterType type; + }; + BufferizationOptions(); - /// Return `true` if the op is allowed to be bufferized. + /// Return whether the op should be bufferized or not. + /// + /// If no filter is specified (`hasFilter` = false), every op will be + /// bufferized. Otherwise, an op is bufferized if: + /// + /// - At least one ALLOW filter says `true`. + /// - And, no DENY filter says `true`. bool isOpAllowed(Operation *op) const { if (!hasFilter) return true; - return dialectFilter.contains(op->getDialect()->getNamespace()) || - operationFilter.contains(op->getName().getStringRef()); + bool isAllowed = false; + for (const OpFilterEntry &entry : opFilter) { + bool filterResult = entry.fn(op); + switch (entry.type) { + case OpFilterEntry::ALLOW: + isAllowed |= filterResult; + break; + case OpFilterEntry::DENY: + if (filterResult) + // DENY filter matches. This op is not allowed. (Even if other ALLOW + // filters may match.) + return false; + } + } + return isAllowed; } /// Allow the given dialects and activate the filter (`hasFilter`). + /// + /// This function adds one or multiple ALLOW filters.
template <typename... DialectTs> - void addToDialectFilter() { - // The following expands a call to addToDialectFilterImpl for each dialect + void allowDialectInFilter() { + // The following expands a call to allowDialectInFilterImpl for each dialect // in 'DialectTs'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. (void)std::initializer_list<int>{ - 0, (addToDialectFilterImpl<DialectTs>(), 0)...}; + 0, (allowDialectInFilterImpl<DialectTs>(), 0)...}; + } + + /// Allow the given dialect and activate the filter (`hasFilter`). + /// + /// This function adds an ALLOW filter that owns a copy of `dialectNamespace`. + void allowDialectInFilter(StringRef dialectNamespace) { + hasFilter = true; + auto filterFn = [ns = dialectNamespace.str()](Operation *op) { + return op->getName().getDialectNamespace() == ns; + }; + opFilter.push_back( + OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW}); } /// Allow the given ops and activate the filter (`hasFilter`). - template <typename... OpTys> void addToOperationFilter() { + /// + /// This function adds one or multiple ALLOW filters. + template <typename... OpTys> + void allowOperationInFilter() { // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list<int>{0, - (addToOperationFilterImpl<OpTys>(), 0)...}; + (void)std::initializer_list<int>{ + 0, (allowOperationInFilterImpl<OpTys>(), 0)...}; + } + + /// Allow the given op and activate the filter (`hasFilter`). + /// + /// This function adds an ALLOW filter that owns a copy of `opName`. + void allowOperationInFilter(StringRef opName) { + hasFilter = true; + auto filterFn = [name = opName.str()](Operation *op) { + return op->getName().getStringRef() == name; + }; + opFilter.push_back( + OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW}); } /// Try to cast the given op to BufferizableOpInterface if the op is allow @@ -118,33 +179,26 @@ /// Buffer alignment for new memory allocations.
unsigned int bufferAlignment = 128; - /// If set to `true`, only ops that belong to a filtered dialect - /// (`dialectFilter`) and filtered ops (`operationFilter`) are processed. All - /// other ops are ignored. If set to `false`, all ops are bufferized (as long - /// as they implement BufferizableOpInterface). - /// - /// If a filter is specified, `allowUnknownOps` should be enabled. Otherwise, - /// bufferization would fail when encountering a non-filtered op. + /// If set to `false`, all ops are bufferized (as long as they implement + /// BufferizableOpInterface). Otherwise, only filtered ops are bufferized. bool hasFilter = false; - /// A set of allowed dialects. - DenseSet<StringRef> dialectFilter; - - /// A set of allowed ops. - DenseSet<StringRef> operationFilter; + /// A list of op filters that determine whether an op should be processed or + /// ignored by the bufferization. If `hasFilter`, only ops that are not + /// DENY-filtered and have at least one matching ALLOW filter are processed. + SmallVector<OpFilterEntry> opFilter; private: /// Allow a dialect. template <typename DialectT> - void addToDialectFilterImpl() { - hasFilter = true; - dialectFilter.insert(DialectT::getDialectNamespace()); + void allowDialectInFilterImpl() { + allowDialectInFilter(DialectT::getDialectNamespace()); } /// Allow an op.
- template <typename OpTy> void addToOperationFilterImpl() { - hasFilter = true; - operationFilter.insert(OpTy::getOperationName()); + template <typename OpTy> + void allowOperationInFilterImpl() { + allowOperationInFilter(OpTy::getOperationName()); } }; diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp @@ -31,9 +31,9 @@ void runOnOperation() override { BufferizationOptions options = getPartialBufferizationOptions(); if (constantOpOnly) { - options.addToOperationFilter<arith::ConstantOp>(); + options.allowOperationInFilter<arith::ConstantOp>(); } else { - options.addToDialectFilter<arith::ArithmeticDialect>(); + options.allowDialectInFilter<arith::ArithmeticDialect>(); } options.bufferAlignment = alignment; diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -31,7 +31,7 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { void runOnOperation() override { BufferizationOptions options = getPartialBufferizationOptions(); - options.addToDialectFilter<tensor::TensorDialect>(); + options.allowDialectInFilter<tensor::TensorDialect>(); if (failed(bufferizeOp(getOperation(), options))) signalPassFailure(); diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -116,7 +116,7 @@ if (dialectFilter.hasValue()) { options->hasFilter = true; for (const std::string &dialectNamespace : dialectFilter) - options->dialectFilter.insert(dialectNamespace); + options->allowDialectInFilter(dialectNamespace); } Operation *op = getOperation();