diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -49,11 +49,10 @@ /// `alias`. Additionally, merge their equivalence classes. void insertNewBufferEquivalence(Value newValue, Value alias); - /// Set the inPlace bufferization spec to true. - /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpOperand &operand, AnalysisState &state); + /// Mark the given OpOperand as in-place. + void bufferizeInPlace(OpOperand &operand); - /// Set the inPlace bufferization spec to false. + /// Mark the given OpOperand as out-of-place. void bufferizeOutOfPlace(OpOperand &operand); /// Return true if `v1` and `v2` may bufferize to aliasing buffers. @@ -80,9 +79,6 @@ /// Apply `fun` to all aliases of `v`. void applyOnAliases(Value v, function_ref fun) const; - /// Mark a value as in-place bufferized. - void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); } - /// Return `true` if a value was marked as in-place bufferized. bool isInPlace(OpOperand &opOperand) const; @@ -136,6 +132,13 @@ return base->getType() == TypeID::get(); } + /// Mark the given OpOperand as in-place, merge its alias set with those of + /// its aliasing OpResults, and notify all extensions. + void bufferizeInPlace(OpOperand &operand); + + /// Mark the given OpOperand as out-of-place and notify all extensions. + void bufferizeOutOfPlace(OpOperand &operand); + /// Return a reference to the BufferizationAliasInfo. BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } @@ -213,6 +216,12 @@ /// Provides read-only access to the parent OneShotAnalysisState object. const OneShotAnalysisState &getAnalysisState() const { return state; } + /// Notify this extension of an in-place bufferization decision. 
+ virtual void notifyBufferizeInPlace(OpOperand &operand) {} + + /// Notify this extension of an out-of-place bufferization decision. + virtual void notifyBufferizeOutOfPlace(OpOperand &operand) {} + private: /// Back-reference to the state that is being extended. OneShotAnalysisState &state; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -138,15 +138,10 @@ return inplaceBufferized.contains(&operand); } -/// Set the inPlace bufferization spec to true. -void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, - AnalysisState &state) { - markInPlace(operand); - for (OpResult result : state.getAliasingOpResult(operand)) - aliasInfo.unionSets(result, operand.get()); +void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand) { + inplaceBufferized.insert(&operand); } -/// Set the inPlace bufferization spec to false. 
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { assert(!inplaceBufferized.contains(&operand) && "OpOperand was already decided to bufferize inplace"); @@ -202,13 +197,27 @@ for (OpResult opResult : bufferizableOp.getAliasingOpResult(opOperand, *this)) aliasInfo.unionAliasSets(opOperand.get(), opResult); - aliasInfo.markInPlace(opOperand); + aliasInfo.bufferizeInPlace(opOperand); } } return WalkResult::advance(); }); } +void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) { + aliasInfo.bufferizeInPlace(operand); + for (OpResult result : getAliasingOpResult(operand)) + aliasInfo.unionAliasSets(result, operand.get()); + for (auto &it : extensions) + it.getSecond()->notifyBufferizeInPlace(operand); +} + +void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) { + aliasInfo.bufferizeOutOfPlace(operand); + for (auto &it : extensions) + it.getSecond()->notifyBufferizeOutOfPlace(operand); +} + bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { return aliasInfo.isInPlace(opOperand); } @@ -837,9 +846,9 @@ wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); if (foundInterference) - aliasInfo.bufferizeOutOfPlace(operand); + state.bufferizeOutOfPlace(operand); else - aliasInfo.bufferizeInPlace(operand, state); + state.bufferizeInPlace(operand); return success(); }