diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -23,28 +23,13 @@
 class BufferizableOpInterface;
 struct DialectAnalysisState;
 
-/// Options for BufferizableOpInterface-based bufferization.
-struct BufferizationOptions {
-  /// Allocator function: Generate a memref allocation with the given type,
-  /// dynamic extents and alignment.
-  using AllocationFn = std::function<FailureOr<Value>(
-      OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
-  /// Deallocator function: Deallocate a buffer that was allocated with
-  /// AllocatorFn.
-  using DeallocationFn =
-      std::function<LogicalResult(OpBuilder &, Location, Value)>;
-  /// Memcpy function: Generate a memcpy between two buffers.
-  using MemCpyFn =
-      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
-  /// Initializer function for analysis state.
-  using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Initializer function for dialect-specific analysis state.
-  using DialectStateInitFn =
-      std::function<std::unique_ptr<DialectAnalysisState>()>;
-
+/// A filter that decides whether an op is allowed or denied, based on a list
+/// of ALLOW and DENY entries. See `isOpAllowed` for the exact semantics.
+class OpFilter {
+public:
   /// An op filter entry. Filters can be used to specify which ops should be
   /// processed by the bufferization.
-  struct OpFilterEntry {
+  struct Entry {
     /// If the filter function evaluates to `true`, the filter matches.
     using FilterFn = std::function<bool(Operation *)>;
 
@@ -55,116 +40,169 @@
     FilterType type;
   };
 
-  enum class LayoutMapOption : int8_t {
-    InferLayoutMap = 0,
-    IdentityLayoutMap = 1,
-    FullyDynamicLayoutMap = 2
-  };
-
-  BufferizationOptions();
-
-  /// Return `true` if the filter has at least one ALLOW rule.
-  bool filterHasAllowRule() const {
-    for (const OpFilterEntry &e : opFilter)
-      if (e.type == OpFilterEntry::FilterType::ALLOW)
-        return true;
-    return false;
-  }
-
-  /// Return whether the op should be bufferized or not.
+  /// Return whether the op is allowed or not.
   ///
-  /// If the filter does not have an ALLOW rule, ops are bufferized by default,
+  /// If the filter does not have an ALLOW rule, ops are allowed by default,
   /// unless they are explicitly marked as DENY. If the filter has at least one
-  /// ALLOW rule, ops are ignored by default and only bufferized if they match
+  /// ALLOW rule, ops are denied by default and only allowed if they match
   /// an ALLOW rule and no DENY rule.
   bool isOpAllowed(Operation *op) const;
 
-  /// Allow the given dialects in the filter.
+  /// Allow the given dialects.
   ///
-  /// This function adds one or multiple ALLOW filters.
-  template <typename... DialectTs>
-  void allowDialectInFilter() {
-    // The following expands a call to allowDialectInFilterImpl for each dialect
+  /// This function adds one or multiple ALLOW entries.
+  template <typename... DialectTs> void allowDialect() {
+    // The following expands a call to allowDialectImpl for each dialect
     // in 'DialectTs'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
+    (void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
   }
 
-  /// Deny the given dialects in the filter.
+  /// Deny the given dialects.
   ///
-  /// This function adds one or multiple DENY filters.
-  template <typename... DialectTs> void denyDialectInFilter() {
+  /// This function adds one or multiple DENY entries.
+  template <typename... DialectTs> void denyDialect() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (denyDialectInFilterImpl<DialectTs>(), 0)...};
+    (void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
   }
 
-  /// Allow the given dialect in the filter.
+  /// Allow the given dialect.
   ///
-  /// This function adds an ALLOW filter.
-  void allowDialectInFilter(StringRef dialectNamespace) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
-      return op->getDialect()->getNamespace() == dialectNamespace;
+  /// This function adds an ALLOW entry.
+  void allowDialect(StringRef dialectNamespace) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
+      // getDialect() may be null (e.g., unregistered ops); use the op name.
+      return op->getName().getDialectNamespace() == dialectNamespace;
     };
-    opFilter.push_back(
-        OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
+    entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
   }
 
+  /// Deny the given dialect.
+  ///
+  /// This function adds a DENY entry. It is also the overload that
+  /// `denyDialect<DialectTs...>()` dispatches to via `denyDialectImpl`.
+  void denyDialect(StringRef dialectNamespace) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
+      // getDialect() may be null (e.g., unregistered ops); use the op name.
+      return op->getName().getDialectNamespace() == dialectNamespace;
+    };
+    entries.push_back(Entry{filterFn, Entry::FilterType::DENY});
+  }
+
-  /// Allow the given ops in the filter.
+  /// Allow the given ops.
   ///
-  /// This function adds one or multiple ALLOW filters.
-  template <typename... OpTys>
-  void allowOperationInFilter() {
+  /// This function adds one or multiple ALLOW entries.
+  template <typename... OpTys> void allowOperation() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (allowOperationInFilterImpl<OpTys>(), 0)...};
+    (void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
   }
 
-  /// Deny the given ops in the filter.
+  /// Deny the given ops.
   ///
-  /// This function adds one or multiple DENY filters.
-  template <typename... OpTys> void denyOperationInFilter() {
+  /// This function adds one or multiple DENY entries.
+  template <typename... OpTys> void denyOperation() {
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-    (void)std::initializer_list<int>{
-        0, (denyOperationInFilterImpl<OpTys>(), 0)...};
+    (void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
   }
 
-  /// Allow the given op in the filter.
+  /// Allow the given op.
   ///
-  /// This function adds an ALLOW filter.
-  void allowOperationInFilter(StringRef opName) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+  /// This function adds an ALLOW entry.
+  void allowOperation(StringRef opName) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
       return op->getName().getStringRef() == opName;
     };
-    allowOperationInFilter(filterFn);
+    allowOperation(filterFn);
   }
 
-  /// Deny the given op in the filter.
+  /// Deny the given op.
   ///
-  /// This function adds a DENY filter.
-  void denyOperationInFilter(StringRef opName) {
-    OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
+  /// This function adds a DENY entry.
+  void denyOperation(StringRef opName) {
+    Entry::FilterFn filterFn = [=](Operation *op) {
       return op->getName().getStringRef() == opName;
     };
-    denyOperationInFilter(filterFn);
+    denyOperation(filterFn);
   }
 
-  /// Allow ops that are matched by `fn` in the filter.
+  /// Allow ops that are matched by `fn`.
   ///
-  /// This function adds an ALLOW filter.
-  void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
-    opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
+  /// This function adds an ALLOW entry.
+  void allowOperation(Entry::FilterFn fn) {
+    entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
   }
 
-  /// Deny ops that are matched by `fn` in the filter.
+  /// Deny ops that are matched by `fn`.
   ///
-  /// This function adds a DENY filter.
-  void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
-    opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
+  /// This function adds a DENY entry.
+  void denyOperation(Entry::FilterFn fn) {
+    entries.push_back(Entry{fn, Entry::FilterType::DENY});
+  }
+
+private:
+  /// Return `true` if the filter has at least one ALLOW rule.
+  bool hasAllowRule() const {
+    for (const Entry &e : entries)
+      if (e.type == Entry::FilterType::ALLOW)
+        return true;
+    return false;
+  }
+
+  /// Allow a dialect.
+  template <typename DialectT> void allowDialectImpl() {
+    allowDialect(DialectT::getDialectNamespace());
+  }
+
+  /// Deny a dialect.
+  template <typename DialectT> void denyDialectImpl() {
+    denyDialect(DialectT::getDialectNamespace());
+  }
+
+  /// Allow an op.
+  template <typename OpTy> void allowOperationImpl() {
+    allowOperation(OpTy::getOperationName());
   }
 
+  /// Deny an op.
+  template <typename OpTy> void denyOperationImpl() {
+    denyOperation(OpTy::getOperationName());
+  }
+
+  /// A list of filter entries that determine whether an op should be allowed or
+  /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
+  /// denied are allowed. If the filter does not have an ALLOW rule, only ops
+  /// that are not denied are allowed.
+  SmallVector<Entry> entries;
+};
+
+/// Options for BufferizableOpInterface-based bufferization.
+struct BufferizationOptions {
+  /// Allocator function: Generate a memref allocation with the given type,
+  /// dynamic extents and alignment.
+  using AllocationFn = std::function<FailureOr<Value>(
+      OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
+  /// Deallocator function: Deallocate a buffer that was allocated with
+  /// AllocatorFn.
+  using DeallocationFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value)>;
+  /// Memcpy function: Generate a memcpy between two buffers.
+  using MemCpyFn =
+      std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+  /// Initializer function for analysis state.
+  using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
+  /// Initializer function for dialect-specific analysis state.
+  using DialectStateInitFn =
+      std::function<std::unique_ptr<DialectAnalysisState>()>;
+
+  enum class LayoutMapOption : int8_t {
+    InferLayoutMap = 0,
+    IdentityLayoutMap = 1,
+    FullyDynamicLayoutMap = 2
+  };
+
+  BufferizationOptions();
+
   /// Try to cast the given op to BufferizableOpInterface if the op is allow
   /// listed.
   BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
@@ -173,6 +211,15 @@
   /// listed.
   BufferizableOpInterface dynCastBufferizableOp(Value value) const;
 
+  /// A filter that specifies which ops should be bufferized and which ops
+  /// should be ignored.
+  OpFilter opFilter;
+
+  /// Return `true` if the given op should be bufferized. Ops of the `func`
+  /// dialect are rejected if `bufferizeFunctionBoundaries` is not set;
+  /// otherwise, `opFilter` decides.
+  bool isOpAllowed(Operation *op) const;
+
   /// Helper functions for allocation, deallocation, memory copying.
   Optional<AllocationFn> allocationFn;
   Optional<DeallocationFn> deallocationFn;
@@ -276,12 +323,6 @@
   /// Buffer alignment for new memory allocations.
   unsigned int bufferAlignment = 128;
 
-  /// A list of op filters that determine whether an op should be processed or
-  /// ignored by the bufferization. If the filter has an ALLOW rule, only ops
-  /// that are allowed and not denied are bufferized. If the filter does not
-  /// have an ALLOW rule, only ops that are not denied are bufferized.
-  SmallVector<OpFilterEntry> opFilter;
-
   /// Initializer functions for analysis state. These can be used to
   /// initialize dialect-specific analysis state.
   SmallVector<AnalysisStateInitFn> stateInitializers;
@@ -289,29 +330,6 @@
-  /// Add a analysis state initializer that initializes the specified
+  /// Add an analysis state initializer that initializes the specified
   /// dialect-specific analysis state.
   void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
-
-private:
-  /// Allow a dialect.
-  template <typename DialectT>
-  void allowDialectInFilterImpl() {
-    allowDialectInFilter(DialectT::getDialectNamespace());
-  }
-
-  /// Deny a dialect.
-  template <typename DialectT> void denyDialectInFilterImpl() {
-    denyDialectInFilter(DialectT::getDialectNamespace());
-  }
-
-  /// Allow an op.
-  template <typename OpTy>
-  void allowOperationInFilterImpl() {
-    allowOperationInFilter(OpTy::getOperationName());
-  }
-
-  /// Deny an op.
-  template <typename OpTy> void denyOperationInFilterImpl() {
-    denyOperationInFilter(OpTy::getOperationName());
-  }
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -32,9 +32,9 @@
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
     if (constantOpOnly) {
-      options.allowOperationInFilter<arith::ConstantOp>();
+      options.opFilter.allowOperation<arith::ConstantOp>();
     } else {
-      options.allowDialectInFilter<arith::ArithmeticDialect>();
+      options.opFilter.allowDialect<arith::ArithmeticDialect>();
     }
     options.bufferAlignment = alignment;
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -45,28 +45,19 @@
 static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
 
 //===----------------------------------------------------------------------===//
-// BufferizationOptions
+// OpFilter
 //===----------------------------------------------------------------------===//
 
-// Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
-
-bool BufferizationOptions::isOpAllowed(Operation *op) const {
-  // Special case: If function boundary bufferization is deactivated, do not
-  // allow ops that belong to the `func` dialect.
-  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
-  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
-    return false;
-
-  // All other ops: Allow/disallow according to filter.
-  bool isAllowed = !filterHasAllowRule();
-  for (const OpFilterEntry &entry : opFilter) {
+bool OpFilter::isOpAllowed(Operation *op) const {
+  // If there is no ALLOW entry, ops are allowed by default.
+  bool isAllowed = !hasAllowRule();
+  for (const Entry &entry : entries) {
     bool filterResult = entry.fn(op);
     switch (entry.type) {
-    case OpFilterEntry::ALLOW:
+    case Entry::ALLOW:
       isAllowed |= filterResult;
       break;
-    case OpFilterEntry::DENY:
+    case Entry::DENY:
       if (filterResult)
-        // DENY filter matches. This op is no allowed. (Even if other ALLOW
+        // DENY filter matches. This op is not allowed. (Even if other ALLOW
         // filters may match.)
@@ -76,6 +67,23 @@
   return isAllowed;
 }
 
+//===----------------------------------------------------------------------===//
+// BufferizationOptions
+//===----------------------------------------------------------------------===//
+
+// Default constructor for BufferizationOptions.
+BufferizationOptions::BufferizationOptions() = default;
+
+bool BufferizationOptions::isOpAllowed(Operation *op) const {
+  // Special case: If function boundary bufferization is deactivated, do not
+  // allow ops that belong to the `func` dialect.
+  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
+  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
+    return false;
+
+  return opFilter.isOpAllowed(op);
+}
+
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -194,7 +194,7 @@
       opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
       opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
 
-      BufferizationOptions::OpFilterEntry::FilterFn filterFn =
+      OpFilter::Entry::FilterFn filterFn =
           [&](Operation *op) {
             // Filter may be specified via options.
             if (this->dialectFilter.hasValue())
@@ -204,7 +204,7 @@
             // No filter specified: All other ops are allowed.
             return true;
           };
-      opt.allowOperationInFilter(filterFn);
+      opt.opFilter.allowOperation(filterFn);
     } else {
       opt = *options;
     }
@@ -242,7 +242,7 @@
     : public BufferizationBufferizeBase<BufferizationBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<BufferizationDialect>();
+    options.opFilter.allowDialect<BufferizationDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -28,7 +28,7 @@
 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<linalg::LinalgDialect>();
+    options.opFilter.allowDialect<linalg::LinalgDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -23,7 +23,7 @@
 struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<shape::ShapeDialect>();
+    options.opFilter.allowDialect<shape::ShapeDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -30,7 +30,7 @@
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<tensor::TensorDialect>();
+    options.opFilter.allowDialect<tensor::TensorDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
@@ -27,7 +27,7 @@
 struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
   void runOnOperation() override {
     BufferizationOptions options = getPartialBufferizationOptions();
-    options.allowDialectInFilter<vector::VectorDialect>();
+    options.opFilter.allowDialect<vector::VectorDialect>();
 
     if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();