diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -237,6 +237,41 @@
   return success();
 }
 
+/// Return `true` if region `r` is reachable from region `begin` according to
+/// the RegionBranchOpInterface (by taking a branch). Both regions must belong
+/// to the same op. `begin` is considered reachable from itself only if some
+/// sequence of branches leads back into it.
+static bool isRegionReachable(Region *begin, Region *r) {
+  assert(begin->getParentOp() == r->getParentOp() && "expected same parent op");
+  auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
+  SmallVector<bool> visited(op->getNumRegions(), false);
+  visited[begin->getRegionNumber()] = true;
+
+  // Retrieve all successors of the region and enqueue them in the worklist.
+  SmallVector<unsigned> worklist;
+  auto enqueueAllSuccessors = [&](unsigned index) {
+    SmallVector<RegionSuccessor> successors;
+    op.getSuccessorRegions(index, successors);
+    for (RegionSuccessor successor : successors)
+      if (!successor.isParent())
+        worklist.push_back(successor.getSuccessor()->getRegionNumber());
+  };
+  enqueueAllSuccessors(begin->getRegionNumber());
+
+  // Process all regions in the worklist via DFS.
+  while (!worklist.empty()) {
+    unsigned nextRegion = worklist.pop_back_val();
+    if (nextRegion == r->getRegionNumber())
+      return true;
+    if (visited[nextRegion])
+      continue;
+    visited[nextRegion] = true;
+    enqueueAllSuccessors(nextRegion);
+  }
+
+  return false;
+}
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions.
 ///
 /// 1. Find the first common of `a` and `b` (ancestor) that implements
@@ -274,33 +309,9 @@
     }
     assert(regionA && regionB && "could not find region of op");
 
-    // Helper function that checks if region `r` is reachable from region
-    // `begin`.
-    std::function<bool(Region *, Region *)> isRegionReachable =
-        [&](Region *begin, Region *r) {
-          if (begin == r)
-            return true;
-          if (begin == nullptr)
-            return false;
-          // Compute index of region.
-          int64_t beginIndex = -1;
-          for (const auto &it : llvm::enumerate(branchOp->getRegions()))
-            if (&it.value() == begin)
-              beginIndex = it.index();
-          assert(beginIndex != -1 && "could not find region in op");
-          // Retrieve all successors of the region.
-          SmallVector<RegionSuccessor> successors;
-          branchOp.getSuccessorRegions(beginIndex, successors);
-          // Call function recursively on all successors.
-          for (RegionSuccessor successor : successors)
-            if (isRegionReachable(successor.getSuccessor(), r))
-              return true;
-          return false;
-        };
-
-    // `a` and `b` are in mutually exclusive regions if neither region is
-    // reachable from the other region.
-    return !isRegionReachable(regionA, regionB) &&
+    // `a` and `b` are in mutually exclusive regions if both regions are
+    // distinct and neither region is reachable from the other region.
+    return regionA != regionB && !isRegionReachable(regionA, regionB) &&
            !isRegionReachable(regionB, regionA);
   }
 
@@ -310,32 +321,8 @@
 }
 
 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
-  SmallVector<bool> visited(getOperation()->getNumRegions(), false);
-  visited[index] = true;
-
-  // Retrieve all successors of the region and enqueue them in the worklist.
-  SmallVector<unsigned> worklist;
-  auto enqueueAllSuccessors = [&](unsigned index) {
-    SmallVector<RegionSuccessor> successors;
-    this->getSuccessorRegions(index, successors);
-    for (RegionSuccessor successor : successors)
-      if (!successor.isParent())
-        worklist.push_back(successor.getSuccessor()->getRegionNumber());
-  };
-  enqueueAllSuccessors(index);
-
-  // Process all regions in the worklist via DFS.
-  while (!worklist.empty()) {
-    unsigned nextRegion = worklist.pop_back_val();
-    if (nextRegion == index)
-      return true;
-    if (visited[nextRegion])
-      continue;
-    visited[nextRegion] = true;
-    enqueueAllSuccessors(nextRegion);
-  }
-
-  return false;
+  Region *region = &getOperation()->getRegion(index);
+  return isRegionReachable(region, region);
 }
 
 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -65,6 +65,27 @@
   }
 };
 
+/// Each region branches back to itself or the parent.
+struct DoubleLoopRegionsOp
+    : public Op<DoubleLoopRegionsOp, RegionBranchOpInterface::Trait> {
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "cftest.double_loop_regions_op";
+  }
+
+  void getSuccessorRegions(Optional<unsigned> index,
+                           ArrayRef<Attribute> operands,
+                           SmallVectorImpl<RegionSuccessor> &regions) {
+    if (index.hasValue()) {
+      regions.push_back(RegionSuccessor());
+      regions.push_back(RegionSuccessor(&getOperation()->getRegion(*index)));
+    }
+  }
+};
+
 /// Regions are executed sequentially.
 struct SequentialRegionsOp
     : public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@@ -89,7 +110,7 @@
   explicit CFTestDialect(MLIRContext *ctx)
       : Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
     addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
-                  SequentialRegionsOp>();
+                  DoubleLoopRegionsOp, SequentialRegionsOp>();
   }
   static StringRef getDialectNamespace() { return "cftest"; }
 };
@@ -115,6 +136,27 @@
   EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
 }
 
+TEST(RegionBranchOpInterface, MutuallyExclusiveOps2) {
+  const char *ir = R"MLIR(
+"cftest.double_loop_regions_op"() (
+      {"cftest.dummy_op"() : () -> ()},  // op1
+      {"cftest.dummy_op"() : () -> ()}   // op2
+  ) : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<CFTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  Operation *testOp = &module->getBody()->getOperations().front();
+  Operation *op1 = &testOp->getRegion(0).front().front();
+  Operation *op2 = &testOp->getRegion(1).front().front();
+
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
+  EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
+}
+
 TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
   const char *ir = R"MLIR(
 "cftest.sequential_regions_op"() (