diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfoVariant.h"
 #include "llvm/ADT/SetVector.h"
 #include <optional>
 
@@ -546,6 +547,22 @@
 
   TypeID getType() const { return type; }
 
+  /// Return the closest enclosing repetitive region around the given op, or
+  /// nullptr if there is none. The result is memoized.
+  Region *getEnclosingRepetitiveRegion(
+      Operation *op, const BufferizationOptions &options) const;
+
+  /// Return the closest enclosing repetitive region around the place where the
+  /// given value is defined, or nullptr if there is none. The result is
+  /// memoized.
+  Region *getEnclosingRepetitiveRegion(
+      Value value, const BufferizationOptions &options) const;
+
+  /// Return the closest enclosing repetitive region around the given block, or
+  /// nullptr if there is none. The result is memoized.
+  Region *getEnclosingRepetitiveRegion(
+      Block *block, const BufferizationOptions &options) const;
+
 protected:
   AnalysisState(const BufferizationOptions &options, TypeID type);
 
@@ -555,6 +572,23 @@
 
   /// The type of analysis.
   TypeID type;
+
+  /// Key type of `enclosingRepetitiveRegionCache`.
+  using RepetitiveRegionCacheKey =
+      std::variant<Operation *, Block *, Region *, Value>;
+
+  /// Maps ops, blocks, values and the regions visited while querying them to
+  /// their closest enclosing repetitive region (nullptr if there is none).
+  /// It is `mutable` because it only memoizes results of the const
+  /// `getEnclosingRepetitiveRegion` queries; consequently, those queries must
+  /// not be called concurrently on the same state. Entries are keyed by
+  /// pointer identity only: they are not invalidated on IR changes and do not
+  /// depend on the `options` passed to the queries.
+  /// NOTE(review): confirm that queries always use this state's options and
+  /// that the state is not queried after IR modifications that could erase or
+  /// reuse cached ops/blocks/regions.
+  mutable DenseMap<RepetitiveRegionCacheKey, Region *>
+      enclosingRepetitiveRegionCache;
 };
 
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
@@ -652,19 +686,6 @@
 /// owner of the block. In case of an OpResult that is the defining op.
 Operation *getOwnerOfValue(Value value);
 
-/// Return the closest enclosing repetitive region around the given op.
-Region *getEnclosingRepetitiveRegion(Operation *op,
-                                     const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the place where the
-/// given value is defined.
-Region *getEnclosingRepetitiveRegion(Value value,
-                                     const BufferizationOptions &options);
-
-/// Return the closest enclosing repetitive region around the given block.
-Region *getEnclosingRepetitiveRegion(Block *block,
-                                     const BufferizationOptions &options);
-
 /// Assuming that the given region is repetitive, find the next enclosing
 /// repetitive region.
 Region *getNextEnclosingRepetitiveRegion(Region *region,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -50,32 +50,77 @@
   return false;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
-    Operation *op, const BufferizationOptions &options) {
+Region *AnalysisState::getEnclosingRepetitiveRegion(
+    Operation *op, const BufferizationOptions &options) const {
   if (!op->getBlock())
     return nullptr;
-  return getEnclosingRepetitiveRegion(op->getBlock(), options);
+  if (auto it = enclosingRepetitiveRegionCache.find(op);
+      it != enclosingRepetitiveRegionCache.end())
+    return it->second;
+  // Compute the result first: the block query inserts into the same map and
+  // could invalidate a reference obtained from `operator[]`.
+  Region *region = getEnclosingRepetitiveRegion(op->getBlock(), options);
+  enclosingRepetitiveRegionCache[op] = region;
+  return region;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
-    Value value, const BufferizationOptions &options) {
+Region *AnalysisState::getEnclosingRepetitiveRegion(
+    Value value, const BufferizationOptions &options) const {
+  if (auto it = enclosingRepetitiveRegionCache.find(value);
+      it != enclosingRepetitiveRegionCache.end())
+    return it->second;
+
   Region *region = value.getParentRegion();
+  // The answer is only known once the walk terminates; remember every region
+  // on the way so that all of them can be mapped to it.
+  SmallVector<Region *> visitedRegions;
   while (region) {
+    // Reuse the answer of an earlier walk that passed through this region.
+    if (auto it = enclosingRepetitiveRegionCache.find(region);
+        it != enclosingRepetitiveRegionCache.end()) {
+      region = it->second;
+      break;
+    }
+    visitedRegions.push_back(region);
     if (isRepetitiveRegion(region, options))
-      return region;
+      break;
     region = region->getParentRegion();
   }
-  return nullptr;
+  // `region` is now the closest enclosing repetitive region or nullptr. Cache
+  // negative results as well so that they are not recomputed on every query.
+  enclosingRepetitiveRegionCache[value] = region;
+  for (Region *r : visitedRegions)
+    enclosingRepetitiveRegionCache[r] = region;
+  return region;
 }
 
-Region *bufferization::getEnclosingRepetitiveRegion(
-    Block *block, const BufferizationOptions &options) {
+Region *AnalysisState::getEnclosingRepetitiveRegion(
+    Block *block, const BufferizationOptions &options) const {
+  if (auto it = enclosingRepetitiveRegionCache.find(block);
+      it != enclosingRepetitiveRegionCache.end())
+    return it->second;
+
   Region *region = block->getParent();
   Operation *op = nullptr;
+  // The answer is only known once the walk terminates; remember every region
+  // on the way so that all of them can be mapped to it.
+  SmallVector<Region *> visitedRegions;
   do {
+    // Reuse the answer of an earlier walk that passed through this region.
+    if (auto it = enclosingRepetitiveRegionCache.find(region);
+        it != enclosingRepetitiveRegionCache.end()) {
+      region = it->second;
+      break;
+    }
+    visitedRegions.push_back(region);
     op = region->getParentOp();
     if (isRepetitiveRegion(region, options))
-      return region;
+      break;
   } while ((region = op->getParentRegion()));
-  return nullptr;
+  // `region` is now the closest enclosing repetitive region or nullptr. Cache
+  // negative results as well so that they are not recomputed on every query.
+  enclosingRepetitiveRegionCache[block] = region;
+  for (Region *r : visitedRegions)
+    enclosingRepetitiveRegionCache[r] = region;
+  return region;
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -383,13 +383,14 @@
 /// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
 /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
 ///
-bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
-                       const SetVector<Value> &definitions,
-                       const AnalysisState &state) {
+static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
+                              const SetVector<Value> &definitions,
+                              const AnalysisState &state) {
   const BufferizationOptions &options = state.getOptions();
   for (Value def : definitions) {
-    Region *rRead = getEnclosingRepetitiveRegion(uRead->getOwner(), options);
-    Region *rDef = getEnclosingRepetitiveRegion(def, options);
+    Region *rRead =
+        state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
+    Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
 
     // READ and DEF are in the same repetitive region. `happensBefore` can be
     // used to rule out RaW conflicts due to op ordering.