diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h
--- a/mlir/include/mlir/Analysis/Dominance.h
+++ b/mlir/include/mlir/Analysis/Dominance.h
@@ -34,12 +34,20 @@
   /// Recalculate the dominance info.
   void recalculate(Operation *op);
 
+  /// Finds the nearest common dominator block for the two given blocks a
+  /// and b. If no common dominator can be found, this function will return
+  /// nullptr.
+  Block *findNearestCommonDominator(Block *a, Block *b) const;
+
   /// Get the root dominance node of the given region.
   DominanceInfoNode *getRootNode(Region *region) {
     assert(dominanceInfos.count(region) != 0);
     return dominanceInfos[region]->getRootNode();
   }
 
+  /// Return the dominance node from the Region containing block A.
+  DominanceInfoNode *getNode(Block *a);
+
 protected:
   using super = DominanceInfoBase<IsPostDom>;
 
@@ -82,9 +90,6 @@
     return super::properlyDominates(a, b);
   }
 
-  /// Return the dominance node from the Region containing block A.
-  DominanceInfoNode *getNode(Block *a);
-
   /// Update the internal DFS numbers for the dominance nodes.
   void updateDFSNumbers();
 };
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Analysis/Dominance.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/GenericDomTreeConstruction.h"
 
 using namespace mlir;
@@ -43,6 +44,99 @@
   });
 }
 
+/// Walks up the list of containers of the given block and calls the
+/// user-defined traversal function for the block itself and for every
+/// ancestor block found during traversal. If the user-defined function returns
+/// true for a given block, traverseAncestors will return that block. Nullptr
+/// otherwise.
+template <typename FuncT>
+static Block *traverseAncestors(Block *block, const FuncT &func) {
+  // Invoke the user-defined traversal function in the beginning for the current
+  // block.
+  if (func(block))
+    return block;
+
+  Region *region = block->getParent();
+  while (region) {
+    Operation *ancestor = region->getParentOp();
+    // If we have reached the top level, stop the traversal.
+    if (!ancestor || !(block = ancestor->getBlock()))
+      break;
+
+    // Update the nested region using the new ancestor block.
+    region = block->getParent();
+
+    // Invoke the user-defined traversal function and check whether we can
+    // already return.
+    if (func(block))
+      return block;
+  }
+  return nullptr;
+}
+
+/// Tries to update the given block references to live in the same region by
+/// exploring the relationship of both blocks with respect to their regions.
+static bool tryGetBlocksInSameRegion(Block *&a, Block *&b) {
+  // If both blocks already live in the same region, there is nothing to do.
+  // Otherwise, we will have to check their parent operations.
+  if (a->getParent() == b->getParent())
+    return true;
+
+  // Iterate over all ancestors of a and insert them into the map. This allows
+  // for efficient lookups to find a commonly shared region.
+  llvm::SmallDenseMap<Region *, Block *, 4> ancestors;
+  traverseAncestors(a, [&](Block *block) {
+    ancestors[block->getParent()] = block;
+    return false;
+  });
+
+  // Try to find a common ancestor starting with block b.
+  b = traverseAncestors(
+      b, [&](Block *block) { return ancestors.count(block->getParent()) > 0; });
+
+  // If there is no match, we will not be able to find a common dominator since
+  // both regions do not share a common parent region.
+  if (!b)
+    return false;
+
+  // We have found a common parent region. Update block a to refer to this
+  // region.
+  auto it = ancestors.find(b->getParent());
+  assert(it != ancestors.end());
+  a = it->second;
+  return true;
+}
+
+template <bool IsPostDom>
+Block *
+DominanceInfoBase<IsPostDom>::findNearestCommonDominator(Block *a,
+                                                         Block *b) const {
+  // If either a or b are null, then conservatively return nullptr.
+  if (!a || !b)
+    return nullptr;
+
+  // Try to find blocks that are in the same region.
+  if (!tryGetBlocksInSameRegion(a, b))
+    return nullptr;
+
+  // Get and verify dominance information of the common parent region.
+  Region *parentRegion = a->getParent();
+  auto infoAIt = dominanceInfos.find(parentRegion);
+  if (infoAIt == dominanceInfos.end())
+    return nullptr;
+
+  // Since the blocks live in the same region, we can rely on already
+  // existing dominance functionality.
+  return infoAIt->second->findNearestCommonDominator(a, b);
+}
+
+template <bool IsPostDom>
+DominanceInfoNode *DominanceInfoBase<IsPostDom>::getNode(Block *a) {
+  auto *region = a->getParent();
+  assert(dominanceInfos.count(region) != 0);
+  return dominanceInfos[region]->getNode(a);
+}
+
 /// Return true if the specified block A properly dominates block B.
 template <bool IsPostDom>
 bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) {
@@ -57,21 +151,17 @@
   // If both blocks are not in the same region, 'a' properly dominates 'b' if
   // 'b' is defined in an operation region that (recursively) ends up being
   // dominated by 'a'. Walk up the list of containers enclosing B.
-  auto *regionA = a->getParent(), *regionB = b->getParent();
-  if (regionA != regionB) {
-    Operation *bAncestor;
-    do {
-      bAncestor = regionB->getParentOp();
-      // If 'bAncestor' is the top level region, then 'a' is a block that post
-      // dominates 'b'.
-      if (!bAncestor || !bAncestor->getBlock())
-        return IsPostDom;
-
-      regionB = bAncestor->getBlock()->getParent();
-    } while (regionA != regionB);
+  auto *regionA = a->getParent();
+  if (regionA != b->getParent()) {
+    b = traverseAncestors(
+        b, [&](Block *block) { return block->getParent() == regionA; });
+
+    // If no ancestor of 'b' lives in 'regionA', then 'a' does not dominate
+    // 'b'; for post-dominance, 'a' is treated as post dominating 'b'.
+    if (!b)
+      return IsPostDom;
 
     // Check to see if the ancestor of 'b' is the same block as 'a'.
-    b = bAncestor->getBlock();
     if (a == b)
       return true;
   }
@@ -132,12 +222,6 @@
   return dominates(a.cast<BlockArgument>().getOwner(), b->getBlock());
 }
 
-DominanceInfoNode *DominanceInfo::getNode(Block *a) {
-  auto *region = a->getParent();
-  assert(dominanceInfos.count(region) != 0);
-  return dominanceInfos[region]->getNode(a);
-}
-
 void DominanceInfo::updateDFSNumbers() {
   for (auto &iter : dominanceInfos)
     iter.second->updateDFSNumbers();
diff --git a/mlir/test/Analysis/test-dominance.mlir b/mlir/test/Analysis/test-dominance.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/test-dominance.mlir
@@ -0,0 +1,207 @@
+// RUN: mlir-opt %s -test-print-dominance -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: Testing : func_condBranch
+func @func_condBranch(%cond : i1) {
+  cond_br %cond, ^bb1, ^bb2
+^bb1:
+  br ^exit
+^bb2:
+  br ^exit
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 0
+// CHECK-NEXT: Nearest(0, 2) = 0
+// CHECK-NEXT: Nearest(0, 3) = 0
+// CHECK: Nearest(1, 0) = 0
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 0
+// CHECK-NEXT: Nearest(1, 3) = 0
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 0
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 0
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 0
+// CHECK-NEXT: Nearest(3, 2) = 0
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 3
+// CHECK-NEXT: Nearest(0, 2) = 3
+// CHECK-NEXT: Nearest(0, 3) = 3
+// CHECK: Nearest(1, 0) = 3
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 3
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 3
+// CHECK-NEXT: Nearest(2, 1) = 3
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : func_loop
+func @func_loop(%arg0 : i32, %arg1 : i32) {
+  br ^loopHeader(%arg0 : i32)
+^loopHeader(%counter : i32):
+  %lessThan = cmpi "slt", %counter, %arg1 : i32
+  cond_br %lessThan, ^loopBody, ^exit
+^loopBody:
+  %const0 = constant 1 : i32
+  %inc = addi %counter, %const0 : i32
+  br ^loopHeader(%inc : i32)
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(1, 0) = 0
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 1
+// CHECK-NEXT: Nearest(1, 3) = 1
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 1
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 1
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 1
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 1
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : nested_region
+func @nested_region(%arg0 : index, %arg1 : index, %arg2 : index) {
+  loop.for %arg3 = %arg0 to %arg1 step %arg2 { }
+  return
+}
+
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+
+// -----
+
+// CHECK-LABEL: Testing : nested_region2
+func @nested_region2(%arg0 : index, %arg1 : index, %arg2 : index) {
+  loop.for %arg3 = %arg0 to %arg1 step %arg2 {
+    loop.for %arg4 = %arg0 to %arg1 step %arg2 {
+      loop.for %arg5 = %arg0 to %arg1 step %arg2 { }
+    }
+  }
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 2
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 2
+// CHECK-NEXT: Nearest(2, 1) = 2
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+// CHECK: Nearest(3, 0) = 3
+// CHECK-NEXT: Nearest(3, 1) = 3
+// CHECK-NEXT: Nearest(3, 2) = 3
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK-NEXT: Nearest(0, 0) = 0
+// CHECK-NEXT: Nearest(0, 1) = 1
+// CHECK-NEXT: Nearest(0, 2) = 2
+// CHECK-NEXT: Nearest(0, 3) = 3
+// CHECK: Nearest(1, 0) = 1
+// CHECK-NEXT: Nearest(1, 1) = 1
+// CHECK-NEXT: Nearest(1, 2) = 2
+// CHECK-NEXT: Nearest(1, 3) = 3
+// CHECK: Nearest(2, 0) = 2
+// CHECK-NEXT: Nearest(2, 1) = 2
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 3
+
+// -----
+
+// CHECK-LABEL: Testing : func_loop_nested_region
+func @func_loop_nested_region(
+  %arg0 : i32,
+  %arg1 : i32,
+  %arg2 : index,
+  %arg3 : index,
+  %arg4 : index) {
+  br ^loopHeader(%arg0 : i32)
+^loopHeader(%counter : i32):
+  %lessThan = cmpi "slt", %counter, %arg1 : i32
+  cond_br %lessThan, ^loopBody, ^exit
+^loopBody:
+  %const0 = constant 1 : i32
+  %inc = addi %counter, %const0 : i32
+  loop.for %arg5 = %arg2 to %arg3 step %arg4 {
+    loop.for %arg6 = %arg2 to %arg3 step %arg4 { }
+  }
+  br ^loopHeader(%inc : i32)
+^exit:
+  return
+}
+// CHECK-LABEL: --- DominanceInfo ---
+// CHECK: Nearest(2, 0) = 0
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 2
+// CHECK-NEXT: Nearest(2, 4) = 2
+// CHECK-NEXT: Nearest(2, 5) = 1
+// CHECK: Nearest(3, 0) = 0
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 2
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-NEXT: Nearest(3, 4) = 4
+// CHECK-NEXT: Nearest(3, 5) = 1
+// CHECK: Nearest(4, 0) = 0
+// CHECK-NEXT: Nearest(4, 1) = 1
+// CHECK-NEXT: Nearest(4, 2) = 2
+// CHECK-NEXT: Nearest(4, 3) = 4
+// CHECK-NEXT: Nearest(4, 4) = 4
+// CHECK-NEXT: Nearest(4, 5) = 1
+// CHECK-LABEL: --- PostDominanceInfo ---
+// CHECK: Nearest(2, 0) = 1
+// CHECK-NEXT: Nearest(2, 1) = 1
+// CHECK-NEXT: Nearest(2, 2) = 2
+// CHECK-NEXT: Nearest(2, 3) = 2
+// CHECK-NEXT: Nearest(2, 4) = 2
+// CHECK-NEXT: Nearest(2, 5) = 5
+// CHECK: Nearest(3, 0) = 1
+// CHECK-NEXT: Nearest(3, 1) = 1
+// CHECK-NEXT: Nearest(3, 2) = 2
+// CHECK-NEXT: Nearest(3, 3) = 3
+// CHECK-NEXT: Nearest(3, 4) = 4
+// CHECK-NEXT: Nearest(3, 5) = 5
+// CHECK: Nearest(4, 0) = 1
+// CHECK-NEXT: Nearest(4, 1) = 1
+// CHECK-NEXT: Nearest(4, 2) = 2
+// CHECK-NEXT: Nearest(4, 3) = 4
+// CHECK-NEXT: Nearest(4, 4) = 4
+// CHECK-NEXT: Nearest(4, 5) = 5
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   TestAllReduceLowering.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
+  TestDominance.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
diff --git a/mlir/test/lib/Transforms/TestDominance.cpp b/mlir/test/lib/Transforms/TestDominance.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDominance.cpp
@@ -0,0 +1,89 @@
+//===- TestDominance.cpp - Test dominance construction and information ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for constructing and resolving dominance
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper class to print dominance information.
+class DominanceTest {
+public:
+  /// Constructs a new test instance using the given operation.
+  explicit DominanceTest(Operation *operation) : operation(operation) {
+    // Create unique ids for each block.
+    operation->walk([&](Operation *nested) {
+      if (blockIds.count(nested->getBlock()) > 0)
+        return;
+      blockIds.insert({nested->getBlock(), blockIds.size()});
+    });
+  }
+
+  /// Prints the nearest common dominator of every pair of blocks.
+  template <typename DominanceT>
+  void printDominance(const DominanceT &dominanceInfo) {
+    DenseSet<Block *> parentVisited;
+    operation->walk([&](Operation *op) {
+      Block *block = op->getBlock();
+      if (!parentVisited.insert(block).second)
+        return;
+
+      DenseSet<Block *> visited;
+      operation->walk([&](Operation *nested) {
+        Block *nestedBlock = nested->getBlock();
+        if (!visited.insert(nestedBlock).second)
+          return;
+        llvm::errs() << "Nearest(" << blockIds[block] << ", "
+                     << blockIds[nestedBlock] << ") = ";
+        Block *dom =
+            dominanceInfo.findNearestCommonDominator(block, nestedBlock);
+        if (dom)
+          llvm::errs() << blockIds[dom];
+        else
+          llvm::errs() << "<no dom>";
+        llvm::errs() << "\n";
+      });
+    });
+  }
+
+private:
+  Operation *operation;
+  DenseMap<Block *, size_t> blockIds;
+};
+
+struct TestDominancePass : public FunctionPass<TestDominancePass> {
+
+  void runOnFunction() override {
+    llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+    DominanceTest dominanceTest(getFunction());
+
+    // Print dominance information.
+    llvm::errs() << "--- DominanceInfo ---\n";
+    dominanceTest.printDominance(getAnalysis<DominanceInfo>());
+
+    llvm::errs() << "--- PostDominanceInfo ---\n";
+    dominanceTest.printDominance(getAnalysis<PostDominanceInfo>());
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestDominancePass() {
+  PassRegistration<TestDominancePass>(
+      "test-print-dominance",
+      "Print the dominance information for multiple regions.");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -41,6 +41,7 @@
 void registerTestAllReduceLoweringPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
+void registerTestDominancePass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestLinalgTransforms();
@@ -95,6 +96,7 @@
   registerTestAllReduceLoweringPass();
   registerTestCallGraphPass();
   registerTestConstantFold();
+  registerTestDominancePass();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLinalgTransforms();