diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -522,6 +522,22 @@
 /// owner of the block. In case of an OpResult that is the defining op.
 Operation *getOwnerOfValue(Value value);
 
+/// Return the closest enclosing repetitive region around the given op, or
+/// nullptr if there is none (e.g., if `op` is not inserted into a block).
+Region *getEnclosingRepetitiveRegion(Operation *op,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around the place where the
+/// given value is defined, or nullptr if there is none.
+Region *getEnclosingRepetitiveRegion(Value value,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around the given block, or
+/// nullptr if there is none. The region that directly contains `block` is
+/// considered as well.
+Region *getEnclosingRepetitiveRegion(Block *block,
+                                     const BufferizationOptions &options);
+
 namespace detail {
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -41,6 +41,47 @@
 using namespace mlir;
 using namespace bufferization;
 
+/// Walk up the region tree starting at (and including) `region`. Return the
+/// first region that its parent op reports as repetitive, or nullptr if there
+/// is none. A null `region`, or a region that is not attached to an op, has no
+/// enclosing repetitive region.
+static Region *
+findEnclosingRepetitiveRegion(Region *region,
+                              const BufferizationOptions &options) {
+  while (region) {
+    Operation *op = region->getParentOp();
+    // A region that is not attached to an op has no parent op; stop instead
+    // of handing a null op to dynCastBufferizableOp / getRegionNumber.
+    if (!op)
+      return nullptr;
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Operation *op, const BufferizationOptions &options) {
+  // An op that is not inserted into a block is not nested in any region.
+  if (!op->getBlock())
+    return nullptr;
+  return getEnclosingRepetitiveRegion(op->getBlock(), options);
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Value value, const BufferizationOptions &options) {
+  return findEnclosingRepetitiveRegion(value.getParentRegion(), options);
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Block *block, const BufferizationOptions &options) {
+  // `block->getParent()` is null for a block that is not inserted into a
+  // region; the helper handles that instead of dereferencing it.
+  return findEnclosingRepetitiveRegion(block->getParent(), options);
+}
+
 Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -355,31 +355,6 @@
   return false;
 }
 
-static Region *
-getEnclosingRepetitiveRegion(Operation *op,
-                             const BufferizationOptions &options) {
-  while (Region *region = op->getParentRegion()) {
-    op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-  }
-  return nullptr;
-}
-
-static Region *
-getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
-  Region *region = value.getParentRegion();
-  while (region) {
-    Operation *op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-    region = op->getParentRegion();
-  }
-  return nullptr;
-}
-
 /// Return `true` if the given tensor value is a memory write. Most values are
 /// tensor writes, but ops that define a tensor SSA value without specifying its
 /// contents (e.g., alloc_tensor) are not.