diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -365,6 +365,96 @@
   return false;
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// RegionBranchOpInterface::getSuccessorRegions of their common parent op.
+/// Every region is considered reachable from itself. Successors that branch
+/// back to the parent op are not followed.
+///
+/// The successor graph is explored with an explicit worklist and every region
+/// is expanded at most once, so cyclic graphs (e.g., a loop region that is
+/// its own successor) terminate instead of recursing without bound.
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin && r && "expected non-null regions");
+  assert(begin->getParentOp() == r->getParentOp() &&
+         "expected that both regions belong to the same op");
+  if (begin == r)
+    return true;
+
+  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
+  SmallVector<bool> visited(op->getNumRegions(), false);
+  SmallVector<unsigned> worklist;
+  worklist.push_back(begin->getRegionNumber());
+  while (!worklist.empty()) {
+    unsigned index = worklist.pop_back_val();
+    if (visited[index])
+      continue;
+    visited[index] = true;
+
+    // Retrieve all successors of the region.
+    SmallVector<RegionSuccessor> successors;
+    op.getSuccessorRegions(index, successors);
+    for (RegionSuccessor &successor : successors) {
+      Region *next = successor.getSuccessor();
+      // Null successors (branches back to the parent op) are not followed.
+      if (!next)
+        continue;
+      if (next == r)
+        return true;
+      worklist.push_back(next->getRegionNumber());
+    }
+  }
+  return false;
+}
+
+/// Return `true` if `a` and `b` are in mutually exclusive regions.
+///
+/// 1. Find the closest common ancestor of `a` and `b` that implements
+///    RegionBranchOpInterface.
+/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
+///    contained.
+/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
+///    mutually exclusive if they are not reachable from each other as per
+///    RegionBranchOpInterface::getSuccessorRegions.
+static bool insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+  assert(a && "expected non-empty operation");
+  assert(b && "expected non-empty operation");
+
+  auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
+  while (branchOp) {
+    // Check if b is inside branchOp. (We already know that a is.)
+    if (!branchOp->isProperAncestor(b)) {
+      // Check next enclosing RegionBranchOpInterface.
+      branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
+      continue;
+    }
+
+    // b is contained in branchOp. Retrieve the regions in which `a` and `b`
+    // are contained.
+    Region *regionA = nullptr, *regionB = nullptr;
+    for (Region &r : branchOp->getRegions()) {
+      if (r.findAncestorOpInRegion(*a)) {
+        assert(!regionA && "already found a region for a");
+        regionA = &r;
+      }
+      if (r.findAncestorOpInRegion(*b)) {
+        assert(!regionB && "already found a region for b");
+        regionB = &r;
+      }
+    }
+    assert(regionA && regionB && "could not find region of op");
+
+    // `a` and `b` are in mutually exclusive regions if neither region is
+    // reachable from the other region. A region is reachable from itself, so
+    // ops in the same region are never mutually exclusive.
+    return !isRegionReachable(regionA, regionB) &&
+           !isRegionReachable(regionB, regionA);
+  }
+
+  // Could not find a common RegionBranchOpInterface among a's and b's
+  // ancestors.
+  return false;
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -430,9 +520,8 @@
                                   aliasInfo))
         continue;
 
-      // Special rules for branches.
-      // TODO: Use an interface.
-      if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
+      // Ops are not conflicting if they are in mutually exclusive regions.
+      if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
         continue;
 
       LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");