diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -26,11 +26,11 @@ namespace bufferization { +class AnalysisState; class BufferizableOpInterface; -class BufferizationState; -struct DialectBufferizationState; +struct DialectAnalysisState; -/// Options for ComprehensiveBufferize. +/// Options for BufferizableOpInterface-based bufferization. struct BufferizationOptions { /// Allocator function: Generate a memref allocation with the given type, /// dynamic extents and alignment. @@ -43,11 +43,11 @@ /// Memcpy function: Generate a memcpy between two buffers. using MemCpyFn = std::function; - /// Initializer function for bufferization state. - using BufferizationStateInitFn = std::function; - /// Initializer function for dialect-specific bufferization state. + /// Initializer function for analysis state. + using AnalysisStateInitFn = std::function; + /// Initializer function for dialect-specific analysis state. using DialectStateInitFn = - std::function()>; + std::function()>; /// An op filter entry. Filters can be used to specify which ops should be /// processed by the bufferization. @@ -232,12 +232,12 @@ /// DENY-filtered and have at least one matching ALLOW filter are processed. SmallVector opFilter; - /// Initializer functions for bufferization state. These can be used to - /// initialize dialect-specific bufferization state. - SmallVector stateInitializers; + /// Initializer functions for analysis state. These can be used to + /// initialize dialect-specific analysis state. + SmallVector stateInitializers; - /// Add a bufferization state initializer that initializes the specified - /// dialect-specific bufferization state. 
+ /// Add an analysis state initializer that initializes the specified + /// dialect-specific analysis state. void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn); private: @@ -265,21 +265,21 @@ /// Return `true` if the given value is a BlockArgument of a FuncOp. bool isFunctionArgument(Value value); -/// Dialect-specific bufferization state. Analysis/bufferization information +/// Dialect-specific analysis state. Analysis/bufferization information /// that is specific to ops from a certain dialect can be stored in derived /// variants of this struct. -struct DialectBufferizationState { - DialectBufferizationState() = default; +struct DialectAnalysisState { + DialectAnalysisState() = default; - virtual ~DialectBufferizationState() = default; + virtual ~DialectAnalysisState() = default; // Copying state is forbidden. Always pass as reference. - DialectBufferizationState(const DialectBufferizationState &) = delete; + DialectAnalysisState(const DialectAnalysisState &) = delete; }; -/// BufferizationState provides a variety of helper functions for dealing with -/// tensor values and memref buffers. -class BufferizationState { +/// AnalysisState provides a variety of helper functions for dealing with +/// tensor values. +class AnalysisState { public: /// Determine which OpOperand* will alias with `result` if the op is /// bufferized in place. Return an empty vector if the op is not bufferizable. @@ -348,15 +348,7 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0; - /// Return the buffer (memref) for a given OpOperand (tensor). Allocate - /// a new buffer and copy over data from the existing buffer if out-of-place - /// bufferization was decided. - FailureOr - getBuffer(RewriterBase &rewriter, OpOperand &opOperand, - bool forceInPlace = false, - Optional customCopyInsertionPoint = None) const; - - /// Return dialect-specific bufferization state.
+ /// Return dialect-specific analysis state. template Optional getDialectState(StringRef name) const { auto it = dialectState.find(name); @@ -365,7 +357,7 @@ return static_cast(it->getSecond().get()); } - /// Return dialect-specific bufferization state or create one if none exists. + /// Return dialect-specific analysis state or create one if none exists. template StateT &getOrCreateDialectState(StringRef name) { // Create state if it does not exist yet. @@ -375,7 +367,7 @@ } void insertDialectState(StringRef name, - std::unique_ptr state) { + std::unique_ptr state) { assert(!dialectState.count(name) && "dialect state already initialized"); dialectState[name] = std::move(state); } @@ -384,31 +376,31 @@ const BufferizationOptions &getOptions() const { return options; } protected: - explicit BufferizationState(const BufferizationOptions &options); + explicit AnalysisState(const BufferizationOptions &options); - // BufferizationState should be passed as a reference. - BufferizationState(const BufferizationState &) = delete; + // AnalysisState should be passed as a reference. + AnalysisState(const AnalysisState &) = delete; - ~BufferizationState() = default; + ~AnalysisState() = default; private: - /// Dialect-specific bufferization state. - DenseMap> dialectState; + /// Dialect-specific analysis state. + DenseMap> dialectState; /// A reference to current bufferization options. const BufferizationOptions &options; }; -/// This a "no analysis, always copy" BufferizationState. In the absence of an +/// This is a "no analysis, always copy" AnalysisState. In the absence of an /// analysis, a buffer must be copied each time it is written to. Therefore, all /// OpOperands that bufferize to a memory write must bufferize out-of-place.
-class AlwaysCopyBufferizationState : public BufferizationState { +class AlwaysCopyAnalysisState : public AnalysisState { public: - explicit AlwaysCopyBufferizationState(const BufferizationOptions &options); + explicit AlwaysCopyAnalysisState(const BufferizationOptions &options); - AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; + AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete; - virtual ~AlwaysCopyBufferizationState() = default; + virtual ~AlwaysCopyAnalysisState() = default; /// Return `true` if the given OpResult has been decided to bufferize inplace. bool isInPlace(OpOperand &opOperand) const override; @@ -417,6 +409,35 @@ bool areEquivalentBufferizedValues(Value v1, Value v2) const override; }; +/// BufferizationState provides helper functions for performing bufferization +/// rewrites and handling memref buffers. +struct BufferizationState { + BufferizationState(const AnalysisState &analysisState) + : analysisState(analysisState) {} + + /// Return the buffer (memref) for a given OpOperand (tensor). Allocate + /// a new buffer and copy over data from the existing buffer if out-of-place + /// bufferization was decided. + FailureOr + getBuffer(RewriterBase &rewriter, OpOperand &opOperand, + bool forceInPlace = false, + Optional customCopyInsertionPoint = None) const; + + /// Return a reference to the BufferizationOptions. + const BufferizationOptions &getOptions() const { + return analysisState.getOptions(); + } + + const AnalysisState &getAnalysisState() const { return analysisState; } + +protected: + // BufferizationState should be passed as a reference. + BufferizationState(const BufferizationState &) = delete; + +private: + const AnalysisState &analysisState; +}; + /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. 
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, @@ -503,39 +524,38 @@ : public BufferizableOpInterface::ExternalModel< AllocationHoistingBarrierOnly, OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return {}; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::None; } bool isWritable(Operation *op, Value value, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { return failure(); } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -33,7 +33,7 @@ /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryRead", /*args=*/(ins "OpOperand &":$opOperand, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor 
OpOperands. @@ -62,7 +62,7 @@ /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryWrite", /*args=*/(ins "OpOperand &":$opOperand, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -85,7 +85,7 @@ /*retType=*/"bool", /*methodName=*/"isMemoryWrite", /*args=*/(ins "OpResult":$opResult, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ auto bufferizableOp = @@ -112,7 +112,7 @@ /*retType=*/"bool", /*methodName=*/"mustBufferizeInPlace", /*args=*/(ins "OpOperand &":$opOperand, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return false; @@ -127,7 +127,7 @@ /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpResult", /*args=*/(ins "OpOperand &":$opOperand, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. 
@@ -151,7 +151,7 @@ /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpOperand", /*args=*/(ins "OpResult":$opResult, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getType().isa() && @@ -185,7 +185,7 @@ /*retType=*/"BufferRelation", /*methodName=*/"bufferRelation", /*args=*/(ins "OpResult":$opResult, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpResults @@ -220,7 +220,7 @@ /*retType=*/"LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "RewriterBase &":$rewriter, - "const BufferizationState &":$state), + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); @@ -246,7 +246,7 @@ /*retType=*/"bool", /*methodName=*/"isWritable", /*args=*/(ins "Value":$value, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return value.isa(); @@ -285,7 +285,7 @@ /*methodName=*/"isNotConflicting", /*args=*/(ins "OpOperand *":$uRead, "OpOperand *":$uWrite, - "const BufferizationState &":$state), + "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return false; @@ -302,7 +302,7 @@ }], /*retType=*/"LogicalResult", /*methodName=*/"verifyAnalysis", - /*args=*/(ins "const BufferizationState &":$state), + /*args=*/(ins "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return success(); @@ -318,7 +318,7 @@ /// /// Examples of such ops are `tensor.extract_slice` and `tensor.cast`. 
bool bufferizesToAliasOnly(OpOperand &opOperand, - const BufferizationState &state) { + const AnalysisState &state) { auto bufferizableOp = cast(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -125,7 +125,7 @@ // results as not writable enforces a buffer copy and has the same effect. LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { // to_tensor cannot be bufferized. However, other ops that are using // to_tensor's result will eventually be bufferized. At that point, they // will start using to_tensor's memref operand. Once all users of @@ -136,7 +136,7 @@ return failure(); } - bool isWritable(Value value, const BufferizationState &state) const { + bool isWritable(Value value, const AnalysisState &state) const { // It is unknown whether the memref operand is writable or not. return false; } @@ -194,30 +194,30 @@ // but such IR may no longer be analyzable by One-Shot analysis. bool bufferizesToMemoryRead(OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // It is unknown whether the resulting memref will be read or not. return true; } bool bufferizesToMemoryWrite(OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // It is unknown whether the resulting MemRef will be written or not. return true; } bool mustBufferizeInPlace(OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // ToMemrefOps always bufferize inplace. 
return true; } SmallVector getAliasingOpResult( - OpOperand &opOperand, const BufferizationState &state) const { + OpOperand &opOperand, const AnalysisState &state) const { return {}; } LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationState &state); + BufferizationState &state); }]; let assemblyFormat = "$tensor attr-dict `:` type($memref)"; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -28,7 +28,8 @@ namespace mlir { namespace bufferization { -class BufferizationState; +class AnalysisState; +struct BufferizationState; struct BufferizationOptions; /// A helper type converter class that automatically populates the relevant @@ -67,7 +68,14 @@ /// layouts after transformations. Combinations of memref.cast + /// canonicalization are responsible for clean ups. // TODO: Extract `options` from `state` and pass as separate argument. -LogicalResult bufferizeOp(Operation *op, const BufferizationState &state); +LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState); + +/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Reuse an existing `BufferizationState`. +/// +/// Note: This function overload is useful for extending the bufferization. +LogicalResult bufferizeOp(Operation *op, + BufferizationState &bufferizationState); /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// Buffers are duplicated and copied before any tensor use that bufferizes to @@ -77,11 +85,6 @@ /// can be used to implement partial bufferization passes. LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options); -/// Populate the pattern set with a pattern that bufferizes ops that implement -/// `BufferizableOpInterface`. 
-void populateBufferizationPattern(const BufferizationState &state, - RewritePatternSet &patterns); - BufferizationOptions getPartialBufferizationOptions(); } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -16,9 +16,9 @@ namespace mlir { namespace bufferization { -class AnalysisBufferizationState; +struct OneShotBufferizationOptions; class BufferizationAliasInfo; -struct AnalysisBufferizationOptions; +class OneShotAnalysisState; /// PostAnalysisStepFns can be registered with `BufferizationOptions` and are /// executed after the analysis, but before bufferization. They can be used to @@ -26,14 +26,14 @@ /// must keep `aliasInfo` consistent. Newly created operations and operations /// that should be re-analyzed must be added to `newOps`. using PostAnalysisStepFn = std::function &)>; using PostAnalysisStepList = SmallVector; /// Options for analysis-enabled bufferization. -struct AnalysisBufferizationOptions : public BufferizationOptions { - AnalysisBufferizationOptions() = default; +struct OneShotBufferizationOptions : public BufferizationOptions { + OneShotBufferizationOptions() = default; /// Register a "post analysis" step. Such steps are executed after the /// analysis, but before bufferization. @@ -68,7 +68,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpOperand &operand, BufferizationState &state); + void bufferizeInPlace(OpOperand &operand, AnalysisState &state); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpOperand &operand); @@ -135,14 +135,14 @@ /// State for analysis-enabled bufferization. 
This class keeps track of alias /// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize /// in-place. -class AnalysisBufferizationState : public BufferizationState { +class OneShotAnalysisState : public AnalysisState { public: - AnalysisBufferizationState(Operation *op, - const AnalysisBufferizationOptions &options); + OneShotAnalysisState(Operation *op, + const OneShotBufferizationOptions &options); - AnalysisBufferizationState(const AnalysisBufferizationState &) = delete; + OneShotAnalysisState(const OneShotAnalysisState &) = delete; - virtual ~AnalysisBufferizationState() = default; + virtual ~OneShotAnalysisState() = default; /// Return a reference to the BufferizationAliasInfo. BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } @@ -161,11 +161,11 @@ /// Analyze `op` and its nested ops. Bufferization decisions are stored in /// `state`. -LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state); +LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state); /// Run One-Shot Bufferize on the given op: Analysis + Bufferization LogicalResult runOneShotBufferize(Operation *op, - const AnalysisBufferizationOptions &options); + const OneShotBufferizationOptions &options); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -5,7 +5,7 @@ namespace mlir { namespace bufferization { -struct AnalysisBufferizationOptions; +struct OneShotBufferizationOptions; //===----------------------------------------------------------------------===// // Passes @@ -37,7 +37,7 @@ /// Create a pass that bufferizes all ops that implement BufferizableOpInterface /// with One-Shot Bufferize and the specified bufferization options. 
std::unique_ptr -createOneShotBufferizePass(const AnalysisBufferizationOptions &options); +createOneShotBufferizePass(const OneShotBufferizationOptions &options); /// Creates a pass that promotes heap-based allocations to stack-based ones. /// Only buffers smaller than the provided size are promoted. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h @@ -18,7 +18,7 @@ class ModuleOp; namespace bufferization { -struct AnalysisBufferizationOptions; +struct OneShotBufferizationOptions; } // namespace bufferization namespace linalg { @@ -29,7 +29,7 @@ /// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize. LogicalResult runModuleBufferize(ModuleOp moduleOp, - bufferization::AnalysisBufferizationOptions options); + bufferization::OneShotBufferizationOptions options); namespace std_ext { diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -19,7 +19,7 @@ namespace mlir { namespace bufferization { -struct AnalysisBufferizationOptions; +struct OneShotBufferizationOptions; } // namespace bufferization std::unique_ptr createConvertElementwiseToLinalgPass(); @@ -64,7 +64,7 @@ /// with the 'inplaceable' attribute. std::unique_ptr createLinalgComprehensiveModuleBufferizePass(); std::unique_ptr createLinalgComprehensiveModuleBufferizePass( - const bufferization::AnalysisBufferizationOptions &options); + const bufferization::OneShotBufferizationOptions &options); /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h @@ -36,7 +36,7 @@ /// * The result of `rewriteFunc` must usually be analyzed for inplacability. /// This analysis can be skipped with `skipAnalysis`. LogicalResult -eliminateInitTensors(Operation *op, bufferization::BufferizationState &state, +eliminateInitTensors(Operation *op, bufferization::AnalysisState &state, bufferization::BufferizationAliasInfo &aliasInfo, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, SmallVector &newOps); @@ -45,7 +45,7 @@ /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). LogicalResult insertSliceAnchoredInitTensorEliminationStep( - Operation *op, bufferization::BufferizationState &state, + Operation *op, bufferization::AnalysisState &state, bufferization::BufferizationAliasInfo &aliasInfo, SmallVector &newOps); diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h @@ -9,16 +9,9 @@ #ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H #define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" - namespace mlir { class DialectRegistry; -namespace bufferization { -class BufferizationState; -class BufferizationAliasInfo; -} // namespace bufferization - namespace scf { void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace scf diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp 
b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -23,7 +23,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto constantOp = cast(op); // Only ranked tensors are supported. @@ -49,7 +49,7 @@ } bool isWritable(Operation *op, Value value, - const BufferizationState &state) const { + const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. assert(value.isa()); return false; @@ -60,28 +60,27 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {op->getResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto castOp = cast(op); Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); @@ -106,30 +105,29 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) 
const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {op->getOpResult(0) /*result*/}; } SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return {&op->getOpOperand(1) /*true_value*/, &op->getOpOperand(2) /*false_value*/}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto selectOp = cast(op); // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. @@ -147,7 +145,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::None; } }; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -67,7 +67,7 @@ void BufferizationOptions::addDialectStateInitializer( StringRef name, const DialectStateInitFn &fn) { stateInitializers.push_back( - [=](BufferizationState &state) { state.insertDialectState(name, fn()); }); + [=](AnalysisState &state) { state.insertDialectState(name, fn()); }); } //===----------------------------------------------------------------------===// @@ -85,7 +85,7 @@ /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. 
SmallVector -BufferizationState::getAliasingOpOperand(OpResult result) const { +AnalysisState::getAliasingOpOperand(OpResult result) const { if (Operation *op = result.getDefiningOp()) if (auto bufferizableOp = dyn_cast(op)) return bufferizableOp.getAliasingOpOperand(result, *this); @@ -95,7 +95,7 @@ /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector -BufferizationState::getAliasingOpResult(OpOperand &opOperand) const { +AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.getAliasingOpResult(opOperand, *this); @@ -104,7 +104,7 @@ /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. -bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const { +bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); @@ -116,7 +116,7 @@ /// Return true if `opOperand` bufferizes to a memory write. Return /// `true` if the op is not bufferizable. -bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const { +bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); @@ -128,7 +128,7 @@ /// Return true if `opOperand` does neither read nor write but bufferizes to an /// alias. Return false if the op is not bufferizable. 
-bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const { +bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); @@ -141,7 +141,7 @@ /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). -bool BufferizationState::isValueRead(Value value) const { +bool AnalysisState::isValueRead(Value value) const { assert(value.getType().isa() && "expected TensorType"); SmallVector workingSet; for (OpOperand &use : value.getUses()) @@ -165,7 +165,7 @@ // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any // further. -llvm::SetVector BufferizationState::findValueInReverseUseDefChain( +llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -193,7 +193,7 @@ // Find the Values of the last preceding write of a given Value. 
llvm::SetVector -BufferizationState::findLastPrecedingWrite(Value value) const { +AnalysisState::findLastPrecedingWrite(Value value) const { return findValueInReverseUseDefChain(value, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) @@ -205,9 +205,9 @@ }); } -BufferizationState::BufferizationState(const BufferizationOptions &options) +AnalysisState::AnalysisState(const BufferizationOptions &options) : options(options) { - for (const BufferizationOptions::BufferizationStateInitFn &fn : + for (const BufferizationOptions::AnalysisStateInitFn &fn : options.stateInitializers) fn(*this); } @@ -246,13 +246,14 @@ FailureOr BufferizationState::getBuffer( RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace, Optional customCopyInsertionPoint) const { + const BufferizationOptions &options = analysisState.getOptions(); OpBuilder::InsertionGuard guard(rewriter); Operation *op = opOperand.getOwner(); Location loc = op->getLoc(); Value operand = opOperand.get(); Value operandBuffer = lookupBuffer(rewriter, operand, options); - if (forceInPlace || isInPlace(opOperand)) + if (forceInPlace || analysisState.isInPlace(opOperand)) return operandBuffer; // Bufferizing out-of-place: Allocate a new buffer. @@ -269,22 +270,26 @@ // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. - SetVector lastWrites = findLastPrecedingWrite(operand); + SetVector lastWrites = analysisState.findLastPrecedingWrite(operand); if (llvm::none_of(lastWrites, [&](Value lastWrite) { if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) return bufferizableOp.isMemoryWrite(lastWrite.cast(), - *this); + analysisState); return true; })) return resultBuffer; // Do not copy if the copied data is never read. 
- SmallVector aliasingOpResults = getAliasingOpResult(opOperand); - if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) && - llvm::none_of(aliasingOpResults, - [&](OpResult opResult) { return isValueRead(opResult); })) + SmallVector aliasingOpResults = + analysisState.getAliasingOpResult(opOperand); + if (!aliasingOpResults.empty() && + !analysisState.bufferizesToMemoryRead(opOperand) && + llvm::none_of(aliasingOpResults, [&](OpResult opResult) { + return analysisState.isValueRead(opResult); + })) return resultBuffer; // Do not copy if this op does not read the data, but writes it. - if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) + if (analysisState.bufferizesToMemoryWrite(opOperand) && + !analysisState.bufferizesToMemoryRead(opOperand)) return resultBuffer; if (customCopyInsertionPoint) { @@ -330,20 +335,20 @@ rewriter.replaceOp(op, replacements); } -AlwaysCopyBufferizationState::AlwaysCopyBufferizationState( +AlwaysCopyAnalysisState::AlwaysCopyAnalysisState( const BufferizationOptions &options) - : BufferizationState(options) {} + : AnalysisState(options) {} /// Return `true` if the given OpResult has been decided to bufferize inplace. -bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const { +bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const { // OpOperands that bufferize to a memory write are out-of-place, i.e., an // alloc and copy is inserted. return !bufferizesToMemoryWrite(opOperand); } /// Return true if `v1` and `v2` bufferize to equivalent buffers. -bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues( - Value v1, Value v2) const { +bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1, + Value v2) const { // There is no analysis, so we do not know if the values are equivalent. The // conservative answer is "false". 
return false; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -349,7 +349,7 @@ } LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, - const BufferizationState &state) { + BufferizationState &state) { // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. return foldToMemrefToTensorPair(rewriter, *this); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -153,7 +153,7 @@ : public OneShotBufferizeBase { OneShotBufferizePass() : OneShotBufferizeBase() {} - explicit OneShotBufferizePass(const AnalysisBufferizationOptions &options) + explicit OneShotBufferizePass(const OneShotBufferizationOptions &options) : options(options) {} void getDependentDialects(DialectRegistry ®istry) const override { @@ -161,7 +161,7 @@ } void runOnOperation() override { - AnalysisBufferizationOptions opt; + OneShotBufferizationOptions opt; if (!options) { // Make new bufferization options if none were provided when creating the // pass. @@ -209,7 +209,7 @@ } private: - llvm::Optional options; + llvm::Optional options; }; } // namespace @@ -218,7 +218,7 @@ } std::unique_ptr mlir::bufferization::createOneShotBufferizePass( - const AnalysisBufferizationOptions &options) { + const OneShotBufferizationOptions &options) { return std::make_unique(options); } @@ -243,23 +243,25 @@ /// Rewrite pattern that bufferizes bufferizable ops. 
struct BufferizationPattern : public OpInterfaceRewritePattern { - BufferizationPattern(MLIRContext *context, const BufferizationState &state, + BufferizationPattern(MLIRContext *context, BufferizationState &state, PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), - state(state) {} + state(&state) {} LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp, PatternRewriter &rewriter) const override { + const BufferizationOptions &options = state->getOptions(); + // No tensors => no buffers. if (!hasTensorSemantics(bufferizableOp.getOperation())) return failure(); - if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation())) + if (!options.isOpAllowed(bufferizableOp.getOperation())) return failure(); - return bufferizableOp.bufferize(rewriter, state); + return bufferizableOp.bufferize(rewriter, *state); } private: - const BufferizationState &state; + BufferizationState *const state; }; /// Check the result of bufferization. Return an error if an op was not @@ -298,10 +300,17 @@ } LogicalResult bufferization::bufferizeOp(Operation *op, - const BufferizationState &state) { + const AnalysisState &analysisState) { + BufferizationState bufferizationState(analysisState); + return bufferizeOp(op, bufferizationState); +} + +LogicalResult +bufferization::bufferizeOp(Operation *op, + BufferizationState &bufferizationState) { // Bufferize the op and its nested ops. RewritePatternSet patterns(op->getContext()); - populateBufferizationPattern(state, patterns); + patterns.add(patterns.getContext(), bufferizationState); // Bufferize ops top-to-bottom. When creating a new op, we should ideally // know the exact memref type of all operands. 
Otherwise, we have to use a @@ -323,21 +332,21 @@ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) return failure(); - return checkBufferizationResult(op, state.getOptions()); + return checkBufferizationResult(op, bufferizationState.getOptions()); } namespace { -/// This a "no analysis, always copy" BufferizationState. In the absence of an +/// This is a "no analysis, always copy" AnalysisState. In the absence of an /// analysis, a buffer must be copied each time it is written to. Therefore, all /// OpOperands that bufferize to a memory write must bufferize out-of-place. -class AlwaysCopyBufferizationState : public BufferizationState { +class AlwaysCopyAnalysisState : public AnalysisState { public: - AlwaysCopyBufferizationState(const BufferizationOptions &options) - : BufferizationState(options) {} + AlwaysCopyAnalysisState(const BufferizationOptions &options) + : AnalysisState(options) {} - AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete; + AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete; - virtual ~AlwaysCopyBufferizationState() = default; + virtual ~AlwaysCopyAnalysisState() = default; /// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override { @@ -357,15 +366,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options) { - AlwaysCopyBufferizationState state(options); + AlwaysCopyAnalysisState state(options); return bufferizeOp(op, state); } -void bufferization::populateBufferizationPattern( - const BufferizationState &state, RewritePatternSet &patterns) { - patterns.add(patterns.getContext(), state); -} - BufferizationOptions bufferization::getPartialBufferizationOptions() { BufferizationOptions options; options.allowReturnMemref = true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -26,7 +26,7 @@ // ops) and then bufferizes it. // // Inplace bufferization decisions are passed from the analysis to the -// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`. +// bufferization phase via `AnalysisState` and `BufferizationAliasInfo`. // They can be printed for debugging purposes with `testAnalysisOnly`. // // Ops that do not implement `BufferizableOpInterface` can be analyzed but are @@ -138,7 +138,7 @@ /// Set the inPlace bufferization spec to true. 
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, - BufferizationState &state) { + AnalysisState &state) { markInPlace(operand); for (OpResult result : state.getAliasingOpResult(operand)) aliasInfo.unionSets(result, operand.get()); @@ -182,12 +182,12 @@ } //===----------------------------------------------------------------------===// -// AnalysisBufferizationState +// OneShotAnalysisState //===----------------------------------------------------------------------===// -AnalysisBufferizationState::AnalysisBufferizationState( - Operation *op, const AnalysisBufferizationOptions &options) - : BufferizationState(options), aliasInfo(op) { +OneShotAnalysisState::OneShotAnalysisState( + Operation *op, const OneShotBufferizationOptions &options) + : AnalysisState(options), aliasInfo(op) { // Set up alias sets for OpResults that must bufferize in-place. This should // be done before making any other bufferization decisions. op->walk([&](BufferizableOpInterface bufferizableOp) { @@ -206,12 +206,12 @@ }); } -bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const { +bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { return aliasInfo.isInPlace(opOperand); } -bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1, - Value v2) const { +bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, + Value v2) const { return aliasInfo.areEquivalentBufferizedValues(v1, v2); } @@ -222,7 +222,7 @@ /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { // OpOperands that do not bufferize to a memory write do not write in-place. if (!state.bufferizesToMemoryWrite(opOperand)) return false; @@ -234,7 +234,7 @@ /// is not writable. 
static bool aliasesNonWritableBuffer(Value value, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { bool foundNonWritableBuffer = false; aliasInfo.applyOnAliases(value, [&](Value v) { // Query BufferizableOpInterface to see if the value is writable. @@ -260,7 +260,7 @@ /// to some buffer write. static bool aliasesInPlaceWrite(Value value, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { bool foundInplaceWrite = false; aliasInfo.applyOnAliases(value, [&](Value v) { for (auto &use : v.getUses()) { @@ -331,7 +331,7 @@ static bool hasReadAfterWriteInterference( const DenseSet &usesRead, const DenseSet &usesWrite, const DominanceInfo &domInfo, - BufferizationState &state, const BufferizationAliasInfo &aliasInfo) { + AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { const BufferizationOptions &options = state.getOptions(); for (OpOperand *uRead : usesRead) { @@ -452,7 +452,7 @@ /// OpResult. In that case, only the consistency of bufferization decisions /// involving aliases of the given OpOperand are checked. static bool wouldCreateReadAfterWriteInterference( - OpOperand &operand, const DominanceInfo &domInfo, BufferizationState &state, + OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state, const BufferizationAliasInfo &aliasInfo, bool checkConsistencyOnly = false) { // Helper function to iterate on aliases of `root` and capture the reads. @@ -495,7 +495,7 @@ static bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { // Certain buffers are not writeable: // 1. A function bbArg that is not inplaceable or // 2. A constant op. @@ -520,8 +520,8 @@ /// Determine if `operand` can be bufferized in-place. 
static LogicalResult bufferizableInPlaceAnalysisImpl( - OpOperand &operand, BufferizationAliasInfo &aliasInfo, - BufferizationState &state, const DominanceInfo &domInfo) { + OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state, + const DominanceInfo &domInfo) { bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) || wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); @@ -554,7 +554,7 @@ /// RaW dependence violations. static LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, - BufferizationState &state, + AnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { if (analysisFuzzerSeed) { @@ -587,7 +587,7 @@ /// Analyze all ops that are contained in `op`. static LogicalResult inPlaceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, - BufferizationState &state, + AnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { // Collect ops so we can build our own reverse traversal. @@ -605,7 +605,7 @@ /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. static void equivalenceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { for (Operation *op : ops) if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) for (OpResult opResult : op->getOpResults()) @@ -622,7 +622,7 @@ /// in `op`. static void equivalenceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { // Traverse ops in PostOrder: Nested ops first, then enclosing ops. SmallVector ops; op->walk([&](Operation *op) { @@ -638,7 +638,7 @@ /// Assert that the current bufferization decisions are consistent. 
static LogicalResult checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, - BufferizationState &state, + AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { const BufferizationOptions &options = state.getOptions(); Operation *inconsistentOp = nullptr; @@ -668,7 +668,7 @@ static void annotateOpsWithBufferizationMarkers(Operation *op, const BufferizationAliasInfo &aliasInfo, - BufferizationState &state) { + AnalysisState &state) { op->walk([&](Operation *op) { if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) for (OpOperand &opOperand : op->getOpOperands()) @@ -701,7 +701,7 @@ // for aliasing values because the analysis is a maybe-alias analysis and we // need a must-alias analysis here. static LogicalResult -assertDestinationPassingStyle(Operation *op, BufferizationState &state, +assertDestinationPassingStyle(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { LogicalResult status = success(); @@ -748,11 +748,11 @@ } LogicalResult bufferization::analyzeOp(Operation *op, - AnalysisBufferizationState &state) { + OneShotAnalysisState &state) { DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); const auto &options = - static_cast(state.getOptions()); + static_cast(state.getOptions()); if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); @@ -796,9 +796,10 @@ return success(!failedAnalysis); } -LogicalResult bufferization::runOneShotBufferize( - Operation *op, const AnalysisBufferizationOptions &options) { - AnalysisBufferizationState state(op, options); +LogicalResult +bufferization::runOneShotBufferize(Operation *op, + const OneShotBufferizationOptions &options) { + OneShotAnalysisState state(op, options); if (failed(analyzeOp(op, state))) return failure(); if (options.testAnalysisOnly) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp 
b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -17,7 +17,7 @@ // // After analyzing a FuncOp, additional information about its bbArgs is // gathered through PostAnalysisStepFns and stored in -// `ModuleBufferizationState`. +// `ModuleAnalysisState`. // // * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each // tensor return value (if any). @@ -90,9 +90,9 @@ /// The state of analysis of a FuncOp. enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed }; -/// Extra bufferization state that is required for bufferization of function +/// Extra analysis state that is required for bufferization of function /// boundaries. -struct ModuleBufferizationState : public DialectBufferizationState { +struct ModuleAnalysisState : public DialectAnalysisState { /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg /// indices. DenseMap> equivalentFuncArgs; @@ -117,28 +117,26 @@ }; } // namespace -/// Get ModuleBufferizationState. -static const ModuleBufferizationState & -getModuleBufferizationState(const BufferizationState &state) { - Optional maybeState = - state.getDialectState( +/// Get ModuleAnalysisState. +static const ModuleAnalysisState & +getModuleAnalysisState(const AnalysisState &state) { + Optional maybeState = + state.getDialectState( func::FuncDialect::getDialectNamespace()); - assert(maybeState.hasValue() && "ModuleBufferizationState does not exist"); + assert(maybeState.hasValue() && "ModuleAnalysisState does not exist"); return **maybeState; } -/// Get or create ModuleBufferizationState. -static ModuleBufferizationState & -getModuleBufferizationState(BufferizationState &state) { - return state.getOrCreateDialectState( +/// Get or create ModuleAnalysisState. 
+static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) { + return state.getOrCreateDialectState( func::FuncDialect::getDialectNamespace()); } /// Return the state (phase) of analysis of the FuncOp. -static FuncOpAnalysisState -getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) { - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); +static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, + FuncOp funcOp) { + const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); auto it = moduleState.analyzedFuncOps.find(funcOp); if (it == moduleState.analyzedFuncOps.end()) return FuncOpAnalysisState::NotAnalyzed; @@ -183,12 +181,12 @@ } /// Store function BlockArguments that are equivalent to a returned value in -/// ModuleBufferizationState. +/// ModuleAnalysisState. static LogicalResult -equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state, +equivalentFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + ModuleAnalysisState &moduleState = getModuleAnalysisState(state); // Support only single return-terminated block in the function. auto funcOp = cast(op); @@ -213,7 +211,7 @@ /// Return true if the buffer of the given tensor value is written to. Must not /// be called for values inside not yet analyzed functions. (Post-analysis /// steps do not have to be run yet, i.e., "in progress" is also OK.) -static bool isValueWritten(Value value, const BufferizationState &state, +static bool isValueWritten(Value value, const AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { #ifndef NDEBUG assert(value.getType().isa() && "expected TensorType"); @@ -259,10 +257,10 @@ /// PostAnalysisStepFn is run on a function with unknown ops, it will /// conservatively assume that such ops bufferize to a read + write. 
static LogicalResult -funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state, +funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + ModuleAnalysisState &moduleState = getModuleAnalysisState(state); auto funcOp = cast(op); // If the function has no body, conservatively assume that all args are @@ -349,7 +347,7 @@ // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - ModuleBufferizationState &moduleState) { + ModuleAnalysisState &moduleState) { funcOp->walk([&](func::CallOp callOp) { FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called FuncOp"); @@ -387,7 +385,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, RewriterBase &rewriter, BufferizationState &state) { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = + getModuleAnalysisState(state.getAnalysisState()); // If nothing to do then we are done. if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && @@ -439,8 +438,9 @@ } // If return operand is equivalent to some bbArg, no need to return it. - if (moduleState.equivalentFuncArgs[funcOp].count( - returnOperand.getOperandNumber())) + auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp); + if (funcOpIt != moduleState.equivalentFuncArgs.end() && + funcOpIt->second.count(returnOperand.getOperandNumber())) continue; // Cast values at the call site if necessary. @@ -674,7 +674,7 @@ /// Return the index of the bbArg in the given FuncOp that is equivalent to the /// specified return value (if any). 
static Optional -getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state, +getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state, int64_t returnValIdx) { auto funcOpIt = state.equivalentFuncArgs.find(funcOp); if (funcOpIt == state.equivalentFuncArgs.end()) @@ -693,13 +693,12 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is read. return true; @@ -709,13 +708,12 @@ } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is written. 
return true; @@ -724,14 +722,12 @@ funcOp.getArgument(opOperand.getOperandNumber())); } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); SmallVector result; for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults(); @@ -746,12 +742,11 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); // TODO: We should be looking for aliasing block arguments here. The current // condition is actually stronger than neccesary. Once we check for aliasing @@ -766,7 +761,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } @@ -774,14 +769,14 @@ /// marked inplaceable. For now, it is the responsibility of the `callOp` /// bufferization to allow FuncOp that are inplaceable to write inPlace. 
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { func::CallOp callOp = cast(op); unsigned numResults = callOp.getNumResults(); unsigned numOperands = callOp->getNumOperands(); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleBufferizationState &moduleState = - getModuleBufferizationState(state); + const ModuleAnalysisState &moduleState = + getModuleAnalysisState(state.getAnalysisState()); // Result types of the bufferized CallOp. SmallVector resultTypes; @@ -906,23 +901,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { #ifndef NDEBUG auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && @@ -935,13 +929,13 @@ struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { return failure(); } /// Return `true` if the given function argument is writable. 
bool isWritable(Operation *op, Value value, - const BufferizationState &state) const { + const AnalysisState &state) const { auto funcOp = cast(op); BlockArgument bbArg = value.dyn_cast(); assert(bbArg && "expected BlockArgument"); @@ -982,9 +976,8 @@ } /// Annotate the IR with the result of the analysis. For testing/debugging only. -static void -annotateOpsWithBufferizationMarkers(FuncOp funcOp, - const BufferizationState &state) { +static void annotateOpsWithBufferizationMarkers(FuncOp funcOp, + const AnalysisState &state) { auto bufferizableOp = cast(funcOp.getOperation()); for (BlockArgument bbArg : funcOp.getArguments()) if (bbArg.getType().isa()) @@ -992,11 +985,12 @@ } LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( - ModuleOp moduleOp, AnalysisBufferizationOptions options) { + ModuleOp moduleOp, OneShotBufferizationOptions options) { IRRewriter rewriter(moduleOp.getContext()); - AnalysisBufferizationState state(moduleOp, options); - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); + OneShotAnalysisState analysisState(moduleOp, options); + BufferizationState bufferizationState(analysisState); + ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState); + BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps, moduleState.callerMap))) @@ -1016,7 +1010,7 @@ moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; // Analyze funcOp. - if (failed(analyzeOp(funcOp, state))) + if (failed(analyzeOp(funcOp, analysisState))) return failure(); // Gather equivalence info for CallOps. @@ -1028,7 +1022,7 @@ // Add annotations to function arguments. 
if (options.testAnalysisOnly) - annotateOpsWithBufferizationMarkers(funcOp, state); + annotateOpsWithBufferizationMarkers(funcOp, analysisState); } if (options.testAnalysisOnly) @@ -1040,7 +1034,7 @@ if (funcOp.getBody().empty()) continue; - if (failed(bufferizeOp(funcOp, state))) + if (failed(bufferizeOp(funcOp, bufferizationState))) return failure(); } @@ -1048,7 +1042,7 @@ for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state))) + if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState))) return failure(); if (!options.allowReturnMemref && diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -25,7 +25,7 @@ /// Generic conversion for any LinalgOp on tensors. static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, - const BufferizationState &state) { + BufferizationState &state) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -56,7 +56,7 @@ SmallVector newOutputBuffers; for (OpResult opResult : op->getOpResults()) { SmallVector aliasingOpOperands = - state.getAliasingOpOperand(opResult); + state.getAnalysisState().getAliasingOpOperand(opResult); assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand"); FailureOr resultBuffer = state.getBuffer(rewriter, *aliasingOpOperands.front()); @@ -156,14 +156,14 @@ : public BufferizableOpInterface::ExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // Operand is read if it is used in the computation. 
auto genericOp = cast(op); return genericOp.payloadUsesValueFromOperand(&opOperand); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // Operand is written to if it has an aliasing OpResult. auto bufferizableOp = cast(op); return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); @@ -171,7 +171,7 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { auto genericOp = cast(op); // By default, the i-th OpResult may alias with the i-th "out" tensor. @@ -188,9 +188,8 @@ return {}; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { auto genericOp = cast(op); // By default, the i-th "out" tensor may alias with the i-th OpResult. @@ -209,12 +208,12 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { return bufferizeLinalgOp(rewriter, cast(op), state); } }; @@ -223,13 +222,13 @@ : public BufferizableOpInterface::ExternalModel { bool isMemoryWrite(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // InitTensorOps allocate but do not write. return false; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto initTensorOp = cast(op); // The InitTensorOp may have been eliminated. 
@@ -345,7 +344,7 @@ /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. LogicalResult mlir::linalg::eliminateInitTensors( - Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, + Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, SmallVector &newOps) { OpBuilder b(op->getContext()); @@ -447,7 +446,7 @@ /// Note that the newly inserted ExtractSliceOp may have to bufferize /// out-of-place due to RaW conflicts. LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep( - Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, + Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { return eliminateInitTensors( op, state, aliasInfo, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -40,7 +40,7 @@ const LinalgComprehensiveModuleBufferize &p) = default; explicit LinalgComprehensiveModuleBufferize( - const AnalysisBufferizationOptions &options) + const OneShotBufferizationOptions &options) : options(options) {} void runOnOperation() override; @@ -61,7 +61,7 @@ } private: - llvm::Optional options; + llvm::Optional options; }; } // namespace @@ -81,7 +81,7 @@ } void LinalgComprehensiveModuleBufferize::runOnOperation() { - AnalysisBufferizationOptions opt; + OneShotBufferizationOptions opt; if (!options) { // Make new bufferization options if none were provided when creating the // pass. 
@@ -129,6 +129,6 @@ } std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass( - const AnalysisBufferizationOptions &options) { + const OneShotBufferizationOptions &options) { return std::make_unique(options); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -40,7 +40,7 @@ scf::ExecuteRegionOp> { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value @@ -60,7 +60,7 @@ // TODO: For better bufferization results, this could return `true` only if // there is a memory write in the region. bool isMemoryWrite(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // Similar to scf.if, results of this op are always considered memory writes // in the analysis. This is a useful pattern for all ops that have tensor // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is @@ -70,7 +70,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto executeRegionOp = cast(op); // Compute new result types. 
@@ -125,7 +125,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } }; @@ -135,7 +135,7 @@ : public BufferizableOpInterface::ExternalModel { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // IfOps do not have tensor OpOperands. The yielded value can be any SSA // value that is in scope. To allow for use-def chain traversal through // IfOps in the analysis, both corresponding yield values from the then/else @@ -152,7 +152,7 @@ // allowed at the moment, we should never encounter scf.ifs that yield // unmodified tensors. Such scf.yield ops could just fold away. bool isMemoryWrite(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // IfOp results are always considered memory writes in the analysis. This // design decision simplifies the analysis considerably. E.g., consider the // following test case: @@ -179,7 +179,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto ifOp = cast(op); // Compute new types of the bufferized scf.if op. @@ -244,7 +244,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // IfOp results are equivalent to their corresponding yield values if both // yield values are equivalent to each other. auto bufferizableOp = cast(op); @@ -263,7 +263,7 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of // its matching bbArg may. 
auto forOp = cast(op); @@ -271,16 +271,15 @@ } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // Tensor iter_args of scf::ForOps are always considered as a write. This is // to simplify the analysis. // TODO: Consider doing sth. like isValueWritten. return true; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) return {}; @@ -288,7 +287,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // ForOp results are equivalent to their corresponding init_args if the // corresponding iter_args and yield values are equivalent. auto forOp = cast(op); @@ -301,7 +300,7 @@ } bool isWritable(Operation *op, Value value, - const BufferizationState &state) const { + const AnalysisState &state) const { // Interestingly, scf::ForOp's bbArg can **always** be viewed // inplace from the perspective of ops nested under: // 1. Either the matching iter operand is not bufferized inplace and an @@ -312,7 +311,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto forOp = cast(op); Block *oldLoopBody = &forOp.getLoopBody().front(); @@ -391,7 +390,7 @@ /// scf.for op is currently assumed to alias with the i-th iter_arg (in the /// absence of conflicts). 
LogicalResult verifyAnalysis(Operation *op, - const BufferizationState &state) const { + const AnalysisState &state) const { auto forOp = cast(op); auto yieldOp = cast(forOp.getLoopBody().front().getTerminator()); @@ -424,18 +423,17 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { if (isa(op->getParentOp())) return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; if (isa(op->getParentOp())) @@ -444,7 +442,7 @@ } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // Yield operands always bufferize inplace. Otherwise, an alloc + copy // may be generated inside the block. We should not return/yield allocations // when possible. 
@@ -452,7 +450,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto yieldOp = cast(op); if (!isa( yieldOp->getParentOp())) diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -29,7 +29,7 @@ shape::AssumingOp> { SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // AssumingOps do not have tensor OpOperands. The yielded value can be any // SSA value that is in scope. To allow for use-def chain traversal through // AssumingOps in the analysis, the corresponding yield value is considered @@ -49,7 +49,7 @@ // TODO: For better bufferization results, this could return `true` only if // there is a memory write in the region. bool isMemoryWrite(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { // Similar to scf.if, results of this op are always considered memory writes // in the analysis. This is a useful pattern for all ops that have tensor // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is @@ -59,7 +59,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto assumingOp = cast(op); // Compute new result types. 
@@ -115,7 +115,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } }; @@ -126,25 +126,24 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { assert(isa(op->getParentOp()) && "expected that parent is an AssumingOp"); return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { // Yield operands always bufferize inplace. Otherwise, an alloc + copy // may be generated inside the block. We should not return/yield allocations // when possible. @@ -152,7 +151,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { // Op is bufferized as part of AssumingOp. 
return failure(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -26,28 +26,27 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {op->getResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto castOp = cast(op); // The result buffer still has the old (pre-cast) type. 
@@ -85,30 +84,29 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { if (&opOperand == &op->getOpOperand(0) /*src*/) return {op->getOpResult(0)}; return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto collapseShapeOp = cast(op); Value buffer = *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/); @@ -125,23 +123,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto dimOp = cast(op); Value v = *state.getBuffer(rewriter, 
dimOp->getOpOperand(0) /*source*/); replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); @@ -154,30 +151,29 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { if (&opOperand == &op->getOpOperand(0) /*src*/) return {op->getOpResult(0)}; return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto expandShapeOp = cast(op); Value buffer = *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); @@ -194,30 +190,29 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { if (&opOperand == &op->getOpOperand(0) /*source*/) return {op->getOpResult(0)}; return {}; } BufferRelation 
bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::None; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto extractSliceOp = cast(op); Location loc = extractSliceOp.getLoc(); Value srcMemref = @@ -228,7 +223,8 @@ extractSliceOp.result().getType().cast(); // If not inplaceable, alloc. - bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0)); + bool inplace = + state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0)); Value alloc; if (!inplace) { FailureOr allocOrFailure = @@ -264,7 +260,7 @@ // If not inplaceable, copy. if (!inplace) { // Do not copy if the copied data is never read. - if (state.isValueRead(extractSliceOp.result())) + if (state.getAnalysisState().isValueRead(extractSliceOp.result())) if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc, state.getOptions()))) return failure(); @@ -281,23 +277,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto extractOp = cast(op); Value srcMemref = *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); @@ -334,7 +329,7 @@ : public 
BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto fromElementsOp = cast(op); // Allocate a buffer for the result. @@ -387,7 +382,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto generateOp = cast(op); // Allocate memory. @@ -446,18 +441,17 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { assert(&opOperand == &op->getOpOperand(1) /*dest*/ && "expected dest OpOperand"); return {op->getOpResult(0)}; @@ -465,12 +459,12 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return {&op->getOpOperand(1) /*dest*/}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto insertOp = cast(op); FailureOr destMemref = state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); @@ -483,7 +477,7 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } }; @@ -494,7 +488,7 @@ /// This is one particular type of relationship between ops 
on tensors that /// reduce to an equivalence on buffers. This should be generalized and /// exposed as interfaces on the proper types. -static bool areEquivalentExtractSliceOps(const BufferizationState &state, +static bool areEquivalentExtractSliceOps(const AnalysisState &state, ExtractSliceOp st, InsertSliceOp sti) { if (!st || !sti) return false; @@ -508,8 +502,8 @@ /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. -static bool hasMatchingExtractSliceOp(const BufferizationState &state, - Value value, InsertSliceOp insertOp) { +static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value, + InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) @@ -527,31 +521,30 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return &opOperand == &op->getOpOperand(1) /*dest*/; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { if (&opOperand == &op->getOpOperand(1) /*dest*/) return {op->getResult(0)}; return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } bool isNotConflicting(Operation *op, OpOperand *uRead, OpOperand *uConflictingWrite, - const BufferizationState &state) const { + const AnalysisState &state) const { Operation *readingOp = uRead->getOwner(); Operation 
*conflictingWritingOp = uConflictingWrite->getOwner(); @@ -626,7 +619,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is // generally a deal breaker. When used with loops, this ends up cloning the // whole tensor on every single iteration and is a symptom of a @@ -683,23 +676,22 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto rankOp = cast(op); Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -27,27 +27,26 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool 
bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return false; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); @@ -69,34 +68,33 @@ : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } - SmallVector - getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return {op->getOpResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationState &state) const { + const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { + BufferizationState &state) const { auto writeOp = cast(op); assert(writeOp.getShapedType().isa() && "only tensor types expected");