diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -566,6 +566,12 @@
 FailureOr<BaseMemRefType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const DenseMap<Value, BaseMemRefType> &fixedTypes);
+
+/// This is the default implementation of
+/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
+/// places.
+bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
+                               unsigned index);
 } // namespace detail
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -360,6 +360,29 @@
               value, options, fixedTypes);
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the given region of this op is repetitive. By default
+          this information is queried from the `RegionBranchOpInterface`. Ops
+          that do not implement this interface can override this method to
+          declare regions as repetitive.
+
+          The RaW conflict detection of One-Shot Analysis is more strict inside
+          repetitive regions: Op dominance cannot always be used to rule out
+          certain potential conflicts (e.g., a conflicting write happening after
+          a read), because there may not be a meaningful ordering of certain ops
+          that are executed multiple times. This is described in more detail in
+          documentation of One-Shot Analysis.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isRepetitiveRegion",
+        /*args=*/(ins "unsigned":$index),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return mlir::bufferization::detail::defaultIsRepetitiveRegion(
+              cast<BufferizableOpInterface>($_op.getOperation()), index);
+        }]
+      >
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/Support/Debug.h"
 
 //===----------------------------------------------------------------------===//
@@ -784,3 +785,14 @@
                          rankedTensorType.getElementType(), layout,
                          memorySpaceAttr);
 }
+
+bool bufferization::detail::defaultIsRepetitiveRegion(
+    BufferizableOpInterface bufferizableOp, unsigned index) {
+  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
+  auto regionInterface =
+      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
+  // No RegionBranchOpInterface to query: report the region as not repetitive.
+  if (!regionInterface)
+    return false;
+  return regionInterface.isRepetitiveRegion(index);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -351,14 +351,45 @@
   return false;
 }
 
+/// Return the closest enclosing repetitive region of `op`. Parent ops for
+/// which `options.dynCastBufferizableOp` yields nothing are skipped. Return
+/// nullptr if no enclosing region is repetitive.
+static Region *
+getEnclosingRepetitiveRegion(Operation *op,
+                             const BufferizationOptions &options) {
+  while (Region *region = op->getParentRegion()) {
+    op = region->getParentOp();
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+  }
+  return nullptr;
+}
+
+/// Return the closest enclosing repetitive region of `value`, or nullptr if
+/// there is none. The search starts at the region that contains `value`.
+static Region *
+getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
+  Region *region = value.getParentRegion();
+  while (region) {
+    Operation *op = region->getParentOp();
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
 /// For each given value, find the closest enclosing repetitive region. If this
 /// is the same region for each value, return it. Otherwise return None.
 /// Note: If there is no enclosing repetitive region, return nullptr.
 static Optional<Region *>
-getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
+getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values,
+                                   const BufferizationOptions &options) {
   if (values.empty())
     return None;
-  Region *r = getEnclosingRepetitiveRegion(values.front());
+  Region *r = getEnclosingRepetitiveRegion(values.front(), options);
   for (Value value : values.drop_front())
-    if (getEnclosingRepetitiveRegion(value) != r)
+    if (getEnclosingRepetitiveRegion(value, options) != r)
       return None;
@@ -432,7 +463,7 @@
   // Find the inner-most enclosing repetitive region of each alias. If this is
   // the same region for every alias, save it in `repetitiveRegionOfWrites`.
   Optional<Region *> repetitiveRegionOfWrites =
-      getCommonEnclosingRepetitiveRegion(writtenAliases);
+      getCommonEnclosingRepetitiveRegion(writtenAliases, options);
 
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
@@ -497,7 +528,7 @@
       bool canUseOpDominance =
           writtenAliases.empty() ||
           repetitiveRegionOfWrites ==
-              getEnclosingRepetitiveRegion(conflictingWritingOp);
+              getEnclosingRepetitiveRegion(conflictingWritingOp, options);
 
       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
       // write is not visible when reading.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -48,15 +48,14 @@
   AnalysisState state(options);
 
   // Look for repetitive ops (loops).
-  op->walk([&](RegionBranchOpInterface regionBranchOp) {
-    // Skip non-bufferizable ops.
-    auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp);
-    if (!bufferizableOp)
+  op->walk([&](BufferizableOpInterface bufferizableOp) {
+    // Skip filtered ops.
+    if (!options.isOpAllowed(bufferizableOp.getOperation()))
       return WalkResult::advance();
 
-    // Find all operands that are also used inside of a repetitve region of this
-    // op.
-    for (OpOperand &opOperand : regionBranchOp->getOpOperands()) {
+    // Find all operands that are also used inside of a repetitive region of
+    // this op.
+    for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
       Value operand = opOperand.get();
       // Skip non-tensor operands.
       if (!operand.getType().isa<TensorType>())
@@ -69,11 +68,11 @@
       SmallVector<OpOperand *> usesInsideRegion;
       for (OpOperand &use : operand.getUses()) {
         Operation *owner = use.getOwner();
-        if (!regionBranchOp->isProperAncestor(owner))
+        if (!bufferizableOp->isProperAncestor(owner))
           continue;
-        for (Region &r : regionBranchOp->getRegions()) {
+        for (Region &r : bufferizableOp->getRegions()) {
           if (r.findAncestorOpInRegion(*owner) &&
-              regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) {
+              bufferizableOp.isRepetitiveRegion(r.getRegionNumber())) {
             usesInsideRegion.push_back(&use);
             break;
           }
@@ -84,9 +83,9 @@
         continue;
 
       // Insert a tensor copy and replace all uses inside of repetitive regions.
-      rewriter.setInsertionPoint(regionBranchOp);
+      rewriter.setInsertionPoint(bufferizableOp);
       auto tensorCopy = rewriter.create<AllocTensorOp>(
-          regionBranchOp->getLoc(), operand.getType().cast<TensorType>(),
+          bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
           /*dynamicSizes=*/ValueRange(), /*copy=*/operand,
           /*memory_space=*/IntegerAttr());
       for (OpOperand *use : usesInsideRegion)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9088,6 +9088,7 @@
         ":BufferizableOpInterfaceIncGen",
         ":BufferizationBaseIncGen",
         ":BufferizationOpsIncGen",
+        ":ControlFlowInterfaces",
         ":CopyOpInterface",
         ":FuncDialect",
         ":IR",