diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -28,8 +28,8 @@
 namespace bufferization {
 
 class BufferizableOpInterface;
-struct BufferizationOptions;
 class BufferizationState;
+struct DialectBufferizationState;
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
@@ -44,6 +44,11 @@
   /// Memcpy function: Generate a memcpy between two buffers.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
+  /// Initializer function for bufferization state.
+  using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
+  /// Initializer function for dialect-specific bufferization state.
+  using DialectStateInitFn =
+      std::function<std::unique_ptr<DialectBufferizationState>()>;
 
   /// An op filter entry. Filters can be used to specify which ops should be
   /// processed by the bufferization.
@@ -228,6 +233,18 @@
   /// DENY-filtered and have at least one matching ALLOW filter are processed.
   SmallVector<OpFilterEntry> opFilter;
 
+  /// Initializer functions for bufferization state. These can be used to
+  /// initialize dialect-specific bufferization state.
+  SmallVector<BufferizationStateInitFn> stateInitializers;
+
+  /// Add a bufferization state initializer that initializes the specified
+  /// dialect-specific bufferization state. The initializer runs when a
+  /// BufferizationState is constructed from these options.
+  ///
+  /// `name` is captured as a StringRef, which does not own its data, so the
+  /// referenced string must outlive these options (e.g. a string literal).
+  void addDialectStateInitializer(StringRef name, DialectStateInitFn fn);
+
 private:
   /// Allow a dialect.
   template <typename DialectT>
@@ -362,6 +379,15 @@
     return static_cast<StateT &>(*dialectState[name]);
   }
 
+  /// Insert dialect-specific bufferization state under `name`. Asserts that
+  /// no state exists for `name` yet and that `state` is non-null.
+  void insertDialectState(StringRef name,
+                          std::unique_ptr<DialectBufferizationState> state) {
+    assert(!dialectState.count(name) && "dialect state already initialized");
+    assert(state && "expected non-null dialect state");
+    dialectState[name] = std::move(state);
+  }
+
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -64,6 +64,16 @@
   return nullptr;
 }
 
+void BufferizationOptions::addDialectStateInitializer(StringRef name,
+                                                      DialectStateInitFn fn) {
+  // `name` is a non-owning StringRef; callers must keep the referenced string
+  // alive while these options are in use.
+  stateInitializers.push_back(
+      [name, fn = std::move(fn)](BufferizationState &state) {
+        state.insertDialectState(name, fn());
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions for BufferizableOpInterface
 //===----------------------------------------------------------------------===//
@@ -200,7 +210,12 @@
 }
 
 BufferizationState::BufferizationState(const BufferizationOptions &options)
-    : options(options) {}
+    : options(options) {
+  // Run all registered initializers, e.g., to set up dialect-specific state.
+  for (const BufferizationOptions::BufferizationStateInitFn &fn :
+       options.stateInitializers)
+    fn(*this);
+}
 
 // bufferization.to_memref is not allowed to change the rank.
 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {