diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -12,6 +12,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfoVariant.h" #include "llvm/ADT/SetVector.h" #include @@ -546,6 +547,22 @@ TypeID getType() const { return type; } + /// Return the closest enclosing repetitive region around the given op. + Region *getEnclosingRepetitiveRegion(Operation *op, + const BufferizationOptions &options); + + /// Return the closest enclosing repetitive region around the place where the + /// given value is defined. + Region *getEnclosingRepetitiveRegion(Value value, + const BufferizationOptions &options); + + /// Return the closest enclosing repetitive region around the given block. + Region *getEnclosingRepetitiveRegion(Block *block, + const BufferizationOptions &options); + + /// Reset cached data structures. + virtual void resetCache(); + protected: AnalysisState(const BufferizationOptions &options, TypeID type); @@ -555,6 +572,10 @@ /// The type of analysis. TypeID type; + + /// Cache: op/block/region/value -> closest enclosing repetitive region. + DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *> + enclosingRepetitiveRegionCache; }; /// Create an AllocTensorOp for the given shaped value (memref or tensor). @@ -652,19 +673,6 @@ /// owner of the block. In case of an OpResult that is the defining op. Operation *getOwnerOfValue(Value value); -/// Return the closest enclosing repetitive region around the given op. -Region *getEnclosingRepetitiveRegion(Operation *op, - const BufferizationOptions &options); - -/// Return the closest enclosing repetitive region around the place where the -/// given value is defined. 
-Region *getEnclosingRepetitiveRegion(Value value, - const BufferizationOptions &options); - -/// Return the closest enclosing repetitive region around the given block. -Region *getEnclosingRepetitiveRegion(Block *block, - const BufferizationOptions &options); - /// Assuming that the given region is repetitive, find the next enclosing /// repetitive region. Region *getNextEnclosingRepetitiveRegion(Region *region, diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -130,7 +130,7 @@ const SetVector &findDefinitionsCached(Value value); /// Reset cached data structures. - void resetCache(); + void resetCache() override; /// Union the alias sets of `v1` and `v2`. void unionAliasSets(Value v1, Value v2); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -50,36 +50,65 @@ return false; } -Region *bufferization::getEnclosingRepetitiveRegion( +Region *AnalysisState::getEnclosingRepetitiveRegion( Operation *op, const BufferizationOptions &options) { if (!op->getBlock()) return nullptr; - return getEnclosingRepetitiveRegion(op->getBlock(), options); + if (auto iter = enclosingRepetitiveRegionCache.find_as(op); + iter != enclosingRepetitiveRegionCache.end()) + return iter->second; + return enclosingRepetitiveRegionCache[op] = + getEnclosingRepetitiveRegion(op->getBlock(), options); } -Region *bufferization::getEnclosingRepetitiveRegion( +Region *AnalysisState::getEnclosingRepetitiveRegion( Value value, const BufferizationOptions &options) { + if (auto iter = enclosingRepetitiveRegionCache.find_as(value); 
+ iter != enclosingRepetitiveRegionCache.end()) + return iter->second; + Region *region = value.getParentRegion(); + // Collect all visited regions: the repetitive region they map to is only + // known once the walk has finished. + SmallVector<Region *> visitedRegions; while (region) { + visitedRegions.push_back(region); if (isRepetitiveRegion(region, options)) - return region; + break; region = region->getParentRegion(); } - return nullptr; + enclosingRepetitiveRegionCache[value] = region; + for (Region *r : visitedRegions) + enclosingRepetitiveRegionCache[r] = region; + return region; } -Region *bufferization::getEnclosingRepetitiveRegion( +Region *AnalysisState::getEnclosingRepetitiveRegion( Block *block, const BufferizationOptions &options) { + if (auto iter = enclosingRepetitiveRegionCache.find_as(block); + iter != enclosingRepetitiveRegionCache.end()) + return iter->second; + Region *region = block->getParent(); Operation *op = nullptr; + // Collect all visited regions: the repetitive region they map to is only + // known once the walk has finished. + SmallVector<Region *> visitedRegions; do { + visitedRegions.push_back(region); op = region->getParentOp(); if (isRepetitiveRegion(region, options)) - return region; + break; } while ((region = op->getParentRegion())); - return nullptr; + + enclosingRepetitiveRegionCache[block] = region; + for (Region *r : visitedRegions) + enclosingRepetitiveRegionCache[r] = region; + return region; } +void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); } + Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -383,13 
+383,14 @@ /// regions. 
I.e., we can rule out a RaW conflict if READ happensBefore WRITE /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.) /// -bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, - const SetVector<Value> &definitions, - const AnalysisState &state) { +static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, + const SetVector<Value> &definitions, + AnalysisState &state) { const BufferizationOptions &options = state.getOptions(); for (Value def : definitions) { - Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options); - Region *rDef = getEnclosingRepetitiveRegion(def, options); + Region *rRead = + state.getEnclosingRepetitiveRegion(uRead->getOwner(), options); + Region *rDef = state.getEnclosingRepetitiveRegion(def, options); // READ and DEF are in the same repetitive region. `happensBefore` can be // used to rule out RaW conflicts due to op ordering. @@ -782,7 +783,10 @@ return cachedDefinitions[value]; } -void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); } +void OneShotAnalysisState::resetCache() { + AnalysisState::resetCache(); + cachedDefinitions.clear(); +} /// Determine if `operand` can be bufferized in-place. static LogicalResult