diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -21,7 +21,6 @@ class AnalysisState; class BufferizableOpInterface; -struct DialectAnalysisState; class OpFilter { public: @@ -179,9 +178,6 @@ std::function; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; - /// Initializer function for dialect-specific analysis state. - using DialectStateInitFn = - std::function()>; /// Tensor -> MemRef type converter. /// Parameters: Value, memory space, bufferization options using UnknownTypeConverterFn = std::function stateInitializers; - - /// Add a analysis state initializer that initializes the specified - /// dialect-specific analysis state. - void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn); }; /// Specify fine-grain relationship between buffers to enable more analysis. @@ -322,18 +314,6 @@ /// Return `true` if the given value is a BlockArgument of a func::FuncOp. bool isFunctionArgument(Value value); -/// Dialect-specific analysis state. Analysis/bufferization information -/// that is specific to ops from a certain dialect can be stored in derived -/// variants of this struct. -struct DialectAnalysisState { - DialectAnalysisState() = default; - - virtual ~DialectAnalysisState() = default; - - // Copying state is forbidden. Always pass as reference. - DialectAnalysisState(const DialectAnalysisState &) = delete; -}; - /// AnalysisState provides a variety of helper functions for dealing with /// tensor values. class AnalysisState { @@ -426,52 +406,27 @@ /// any given tensor. virtual bool isTensorYielded(Value tensor) const; - /// Return `true` if the given dialect state exists.
- bool hasDialectState(StringRef name) const { - auto it = dialectState.find(name); - return it != dialectState.end(); - } - - /// Return dialect-specific bufferization state. - template - Optional getDialectState(StringRef name) const { - auto it = dialectState.find(name); - if (it == dialectState.end()) - return None; - return static_cast(it->getSecond().get()); - } - - /// Return dialect-specific analysis state or create one if none exists. - template - StateT &getOrCreateDialectState(StringRef name) { - // Create state if it does not exist yet. - if (!hasDialectState(name)) - dialectState[name] = std::make_unique(); - return static_cast(*dialectState[name]); - } - - void insertDialectState(StringRef name, - std::unique_ptr state) { - assert(!dialectState.count(name) && "dialect state already initialized"); - dialectState[name] = std::move(state); - } - /// Return a reference to the BufferizationOptions. const BufferizationOptions &getOptions() const { return options; } - explicit AnalysisState(const BufferizationOptions &options); + explicit AnalysisState(const BufferizationOptions &options, + TypeID type = TypeID::get()); // AnalysisState should be passed as a reference. AnalysisState(const AnalysisState &) = delete; virtual ~AnalysisState() = default; -private: - /// Dialect-specific analysis state. - DenseMap> dialectState; + static bool classof(const AnalysisState *base) { return true; } + + TypeID getType() const { return type; } +private: /// A reference to current bufferization options. const BufferizationOptions &options; + + /// TypeID of the concrete analysis state class (see `getType()`). + TypeID type; }; /// Create an AllocTensorOp for the given shaped value (memref or tensor).
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -10,6 +10,8 @@ #define MLIR_BUFFERIZATION_TRANSFORMS_FUNCBUFFERIZABLEOPINTERFACEIMPL_H #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" namespace mlir { class DialectRegistry; @@ -27,7 +29,10 @@ /// Extra analysis state that is required for bufferization of function /// boundaries. -struct FuncAnalysisState : public DialectAnalysisState { +struct FuncAnalysisState : public OneShotAnalysisState::Extension { + explicit FuncAnalysisState(OneShotAnalysisState &state) + : OneShotAnalysisState::Extension(state) {} + // Note: Function arguments and/or function return values may disappear during // bufferization. Functions and their CallOps are analyzed and bufferized // separately. To ensure that a CallOp analysis/bufferization can access an diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -132,6 +132,10 @@ ~OneShotAnalysisState() override = default; + static bool classof(const AnalysisState *base) { + return base->getType() == TypeID::get(); + } + /// Return a reference to the BufferizationAliasInfo. BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } @@ -166,6 +170,92 @@ /// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const; + /// Base class for OneShotAnalysisState extensions that allow + /// OneShotAnalysisState to contain user-specified information in the state + /// object. Clients are expected to derive this class, add the desired fields, + /// and make the derived class compatible with the MLIR TypeID mechanism. + /// + /// ```c++ + /// class MyExtension final : public OneShotAnalysisState::Extension { + /// public: + /// MyExtension(OneShotAnalysisState &state, int myData) + /// : Extension(state) {...} + /// private: + /// int mySupplementaryData; + /// }; + /// ``` + /// + /// Instances of this and derived classes are not expected to be created by + /// the user, instead they are directly constructed within a + /// OneShotAnalysisState. A OneShotAnalysisState can only contain one + /// extension with the given TypeID. Extensions can be obtained from a + /// OneShotAnalysisState instance. + /// + /// ```c++ + /// state.addExtension(/*myData=*/42); + /// MyExtension *ext = state.getExtension(); + /// ext->doSomething(); + /// ``` + class Extension { + // Allow OneShotAnalysisState to allocate Extensions. + friend class OneShotAnalysisState; + + public: + /// Base virtual destructor. + // Out-of-line definition ensures symbols are emitted in a single object + // file. + virtual ~Extension(); + + protected: + /// Constructs an extension of the given OneShotAnalysisState object. + explicit Extension(OneShotAnalysisState &state) : state(state) {} + + /// Provides read-only access to the parent OneShotAnalysisState object. + const OneShotAnalysisState &getAnalysisState() const { return state; } + + private: + /// Back-reference to the state that is being extended. + OneShotAnalysisState &state; + }; + + /// Adds a new Extension of the type specified as template parameter, + /// constructing it with the arguments provided. The extension is owned by the + /// OneShotAnalysisState.
It is expected that the state does not already have + /// an extension of the same type. Extension constructors are expected to take + /// a reference to OneShotAnalysisState as first argument, automatically + /// supplied by this call. + template Ty &addExtension(Args &&...args) { + static_assert( + std::is_base_of::value, + "only a class derived from OneShotAnalysisState::Extension is allowed"); + auto ptr = std::make_unique(*this, std::forward(args)...); + auto result = extensions.try_emplace(TypeID::get(), std::move(ptr)); + assert(result.second && "extension already added"); + return *static_cast(result.first->second.get()); + } + + /// Returns the extension of the specified type or nullptr if not present. + template Ty *getExtension() { + static_assert( + std::is_base_of::value, + "only a class derived from OneShotAnalysisState::Extension is allowed"); + auto iter = extensions.find(TypeID::get()); + if (iter == extensions.end()) + return nullptr; + return static_cast(iter->second.get()); + } + + /// Returns the extension of the specified type or nullptr if not present. + template const Ty *getExtension() const { + static_assert( + std::is_base_of::value, + "only a class derived from OneShotAnalysisState::Extension is allowed"); + auto iter = extensions.find(TypeID::get()); + if (iter == extensions.end()) + return nullptr; + return static_cast(iter->second.get()); + } + private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runOneShotBufferize` may access this object. @@ -177,6 +267,10 @@ /// A set of uses of tensors that have undefined contents. DenseSet undefinedTensorUses; + + /// Extensions attached to the OneShotAnalysisState, identified by the + /// TypeID of their type. Only one extension of any given type is allowed. + DenseMap> extensions; }; /// Analyze `op` and its nested ops.
Bufferization decisions are stored in diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -296,12 +296,6 @@ return nullptr; } -void BufferizationOptions::addDialectStateInitializer( - StringRef name, const DialectStateInitFn &fn) { - stateInitializers.push_back( - [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); -} - //===----------------------------------------------------------------------===// // Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// @@ -449,8 +443,8 @@ }); } -AnalysisState::AnalysisState(const BufferizationOptions &options) - : options(options) { +AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type) + : options(options), type(type) { for (const BufferizationOptions::AnalysisStateInitFn &fn : options.stateInitializers) fn(*this); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -101,22 +101,23 @@ /// Get FuncAnalysisState. static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state) { - Optional maybeState = - state.getDialectState( - func::FuncDialect::getDialectNamespace()); - assert(maybeState && "FuncAnalysisState does not exist"); - return **maybeState; + assert(isa(state) && "expected OneShotAnalysisState"); + auto *result = static_cast(state) + .getExtension(); + assert(result && "FuncAnalysisState does not exist"); + return *result; } /// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp) { - Optional maybeState = - state.getDialectState( - func::FuncDialect::getDialectNamespace()); - if (!maybeState.has_value()) + if (!isa(state)) return FuncOpAnalysisState::NotAnalyzed; - const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps; + auto *funcState = static_cast(state) + .getExtension(); + if (!funcState) + return FuncOpAnalysisState::NotAnalyzed; + const auto &analyzedFuncOps = funcState->analyzedFuncOps; auto it = analyzedFuncOps.find(funcOp); if (it == analyzedFuncOps.end()) return FuncOpAnalysisState::NotAnalyzed; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -189,7 +189,8 @@ OneShotAnalysisState::OneShotAnalysisState( Operation *op, const OneShotBufferizationOptions &options) - : AnalysisState(options), aliasInfo(op) { + : AnalysisState(options, TypeID::get()), + aliasInfo(op) { // Set up alias sets for OpResults that must bufferize in-place. This should // be done before making any other bufferization decisions. op->walk([&](BufferizableOpInterface bufferizableOp) { @@ -321,6 +322,8 @@ return false; } +OneShotAnalysisState::Extension::~Extension() = default; + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -1071,11 +1074,6 @@ const auto &options = static_cast(state.getOptions()); - // Catch incorrect API usage.
- assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) || - !options.bufferizeFunctionBoundaries) && - "must use ModuleBufferize to bufferize function boundaries"); - if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -76,32 +76,13 @@ /// A mapping of FuncOps to their callers. using FuncCallerMap = DenseMap>; -/// Get FuncAnalysisState. -static const FuncAnalysisState & -getFuncAnalysisState(const AnalysisState &state) { - Optional maybeState = - state.getDialectState( - func::FuncDialect::getDialectNamespace()); - assert(maybeState && "FuncAnalysisState does not exist"); - return **maybeState; -} - /// Get or create FuncAnalysisState. -static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { - return state.getOrCreateDialectState( - func::FuncDialect::getDialectNamespace()); -} - -/// Return the state (phase) of analysis of the FuncOp. -/// Used for debug modes. -LLVM_ATTRIBUTE_UNUSED -static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, - func::FuncOp funcOp) { - const FuncAnalysisState &funcState = getFuncAnalysisState(state); - auto it = funcState.analyzedFuncOps.find(funcOp); - if (it == funcState.analyzedFuncOps.end()) - return FuncOpAnalysisState::NotAnalyzed; - return it->second; +static FuncAnalysisState & +getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { + auto *result = state.getExtension(); + if (result) + return *result; + return state.addExtension(); } /// Return the unique ReturnOp that terminates `funcOp`. @@ -143,10 +124,9 @@ /// Store function BlockArguments that are equivalent to/aliasing a returned /// value in FuncAnalysisState.
-static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, - OneShotAnalysisState &state) { - FuncAnalysisState &funcState = getFuncAnalysisState(state); - +static LogicalResult +aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + FuncAnalysisState &funcState) { // Support only single return-terminated block in the function. func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); @@ -190,10 +170,9 @@ /// Determine which FuncOp bbArgs are read and which are written. When run on a /// function with unknown ops, we conservatively assume that such ops bufferize /// to a read + write. -static LogicalResult funcOpBbArgReadWriteAnalysis(FuncOp funcOp, - OneShotAnalysisState &state) { - FuncAnalysisState &funcState = getFuncAnalysisState(state); - +static LogicalResult +funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, + FuncAnalysisState &funcState) { // If the function has no body, conservatively assume that all args are // read + written. if (funcOp.getBody().empty()) { @@ -246,8 +225,8 @@ // TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - OneShotAnalysisState &state) { - FuncAnalysisState &funcState = getFuncAnalysisState(state); + OneShotAnalysisState &state, + FuncAnalysisState &funcState) { funcOp->walk([&](func::CallOp callOp) { func::FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called func::FuncOp"); @@ -360,7 +339,7 @@ static_cast(state.getOptions()); assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - FuncAnalysisState &funcState = getFuncAnalysisState(state); + FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); // A list of functions in the order in which they are analyzed + bufferized. @@ -382,15 +361,15 @@ funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, state); + equivalenceAnalysis(funcOp, aliasInfo, state, funcState); // Analyze funcOp. if (failed(analyzeOp(funcOp, state))) return failure(); // Run some extra function analyses. - if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) || - failed(funcOpBbArgReadWriteAnalysis(funcOp, state))) + if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) || + failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState))) return failure(); // Mark op as fully analyzed.