diff --git a/mlir/include/mlir/IR/Iterators.h b/mlir/include/mlir/IR/Iterators.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/Iterators.h
@@ -0,0 +1,82 @@
+//===- Iterators.h - IR iterators for IR visitors ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The iterators defined in this file can be used together with IR visitors.
+// Note: These iterators cannot be defined in Visitors.h because that would
+// introduce a cyclic header dependency due to Operation.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ITERATORS_H
+#define MLIR_IR_ITERATORS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/DepthFirstIterator.h"
+
+namespace mlir {
+/// This iterator enumerates elements in "reverse" order. It is a wrapper around
+/// llvm::reverse.
+struct ReverseIterator {
+  // llvm::reverse uses RangeT::rbegin and RangeT::rend.
+  template <typename RangeT>
+  static constexpr auto makeIterable(RangeT &&range) {
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
+  }
+};
+
+/// This iterator enumerates elements according to their dominance relationship.
+/// Operations and regions are enumerated in "forward" order. Blocks are
+/// enumerated according to their successor relationship. Unreachable blocks are
+/// not enumerated.
+///
+/// Note: If `NoGraphRegions` is set to "true", this iterator asserts that each
+/// visited region has SSA dominance. In either case, the ops in such regions
+/// are visited in forward order, but for regions without SSA dominance this
+/// does not guarantee that defining ops are visited before their users.
+template <bool NoGraphRegions = false>
+struct ForwardDominanceIterator {
+  /// Operations of a block are enumerated in forward order.
+  static Block &makeIterable(Block &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+
+  /// Blocks of a region are enumerated in depth-first pre-order, following
+  /// successor edges from the entry block. An empty region yields no blocks.
+  static auto makeIterable(Region &region) {
+    if (NoGraphRegions) {
+      // Only regions with SSA dominance are allowed.
+      assert(mayHaveSSADominance(region) && "graph regions are not allowed");
+    }
+
+    // Create DFS iterator. Blocks are enumerated according to their successor
+    // relationship. For an empty region, an empty df_iterator range is built
+    // from a null node so that both branches of the conditional have the same
+    // type.
+    Block *null = nullptr;
+    auto it = region.empty()
+                  ? llvm::make_range(llvm::df_end(null), llvm::df_end(null))
+                  : llvm::depth_first(&region.front());
+
+    // Walk API expects Block references instead of pointers.
+    auto dereference = [](Block *b) -> Block & { return *b; };
+    return llvm::map_range(it, dereference);
+  }
+
+  /// Regions of an operation are enumerated in forward order.
+  static MutableArrayRef<Region> makeIterable(Operation &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+};
+} // namespace mlir
+
+#endif // MLIR_IR_ITERATORS_H
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -38,6 +38,11 @@
 };
 } // namespace OpTrait
 
+/// Return "true" if the given region may have SSA dominance. This function also
+/// returns "true" in case the owner op is an unregistered op or an op that does
+/// not implement the RegionKindInterface.
+bool mayHaveSSADominance(Region &region);
+
 } // namespace mlir
 
 #include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -74,17 +74,6 @@
   }
 };
 
-/// This iterator enumerates elements in "reverse" order. It is a wrapper around
-/// llvm::reverse.
-struct ReverseIterator {
-  template <typename RangeT>
-  static constexpr auto makeIterable(RangeT &&range) {
-    // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(
-        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
-  }
-};
-
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -16,3 +16,13 @@
 using namespace mlir;
 
 #include "mlir/IR/RegionKindInterface.cpp.inc"
+
+bool mlir::mayHaveSSADominance(Region &region) {
+  Operation *parent = region.getParentOp();
+  if (!parent)
+    return true;
+  auto regionKindOp = dyn_cast<RegionKindInterface>(parent);
+  if (!regionKindOp)
+    return true;
+  return regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -27,55 +28,6 @@
 
 #define DEBUG_TYPE "dialect-conversion"
 
-/// Recursively collect all of the operations to convert from within 'region'.
-/// If 'target' is nonnull, operations that are recursively legal have their
-/// regions pre-filtered to avoid considering them for legalization.
-static LogicalResult
-computeConversionSet(iterator_range<Region::iterator> region,
-                     Location regionLoc,
-                     SmallVectorImpl<Operation *> &toConvert,
-                     ConversionTarget *target = nullptr) {
-  if (region.empty())
-    return success();
-
-  // Traverse starting from the entry block.
-  SmallVector<Block *, 16> worklist(1, &*region.begin());
-  DenseSet<Block *> visitedBlocks;
-  visitedBlocks.insert(worklist.front());
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    // Compute the conversion set of each of the nested operations.
-    for (Operation &op : *block) {
-      toConvert.emplace_back(&op);
-
-      // Don't check this operation's children for conversion if the operation
-      // is recursively legal.
-      auto legalityInfo =
-          target ? target->isLegal(&op)
-                 : std::optional<ConversionTarget::LegalOpDetails>();
-      if (legalityInfo && legalityInfo->isRecursivelyLegal)
-        continue;
-      for (auto &region : op.getRegions()) {
-        if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                        toConvert, target)))
-          return failure();
-      }
-    }
-
-    // Recurse to children that haven't been visited.
-    for (Block *succ : block->getSuccessors())
-      if (visitedBlocks.insert(succ).second)
-        worklist.push_back(succ);
-  }
-
-  // Check that all blocks in the region were visited.
-  if (llvm::any_of(llvm::drop_begin(region, 1),
-                   [&](Block &block) { return !visitedBlocks.count(&block); }))
-    return emitError(regionLoc, "unreachable blocks were not converted");
-  return success();
-}
-
 /// A utility function to log a successful result for the given reason.
 template <typename... Args>
 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
@@ -952,10 +904,6 @@
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                         Region::iterator before);
 
-  /// Notifies that the blocks of a region were cloned into another.
-  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
-                                   Location origRegionLoc);
-
   /// Notifies that a pattern match failed for the given reason.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -1463,20 +1411,6 @@
   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
 }
 
-void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
-    iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
-  for (Block &block : blocks)
-    blockActions.push_back(BlockAction::getCreate(&block));
-
-  // Compute the conversion set for the inlined region.
-  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
-
-  // This original region has already had its conversion set computed, so there
-  // shouldn't be any new failures.
-  (void)result;
-  assert(succeeded(result) && "expected region to have no unreachable blocks");
-}
-
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -1630,12 +1564,15 @@
                                                   IRMapping &mapping) {
   if (region.empty())
     return;
+
   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
 
-  // Collect the range of the cloned blocks.
-  auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
-  auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
-  impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
+  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
+    Block *cloned = mapping.lookup(&b);
+    impl->notifyCreatedBlock(cloned);
+    cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+        [&](Operation *op) { notifyOperationInserted(op); });
+  }
 }
 
 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
@@ -2444,11 +2381,16 @@
   // Compute the set of operations and blocks to convert.
   SmallVector<Operation *> toConvert;
   for (auto *op : ops) {
-    toConvert.emplace_back(op);
-    for (auto &region : op->getRegions())
-      if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                      toConvert, &target)))
-        return failure();
+    op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+        [&](Operation *op) {
+          toConvert.push_back(op);
+          // Don't check this operation's children for conversion if the
+          // operation is recursively legal.
+          auto legalityInfo = target.isLegal(op);
+          if (legalityInfo && legalityInfo->isRecursivelyLegal)
+            return WalkResult::skip();
+          return WalkResult::advance();
+        });
   }
 
   // Convert each operation and discard rewrites on failure.
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -250,3 +250,85 @@
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'func.func'
 // CHECK:       Erasing block ^bb0 from region 0 from operation 'builtin.module'
+
+// -----
+
+func.func @unordered_cfg_with_loop() {
+  "regionOp0"() ({
+  ^bb0:
+    %c = "op0"() : () -> (i1)
+    cf.cond_br %c, ^bb2, ^bb3
+  ^bb1:
+    "op1"(%val) : (i32) -> ()
+    cf.br ^bb5
+  ^bb2:
+    %val = "op2"() : () -> (i32)
+    cf.br ^bb1
+  ^bb3:
+    "op3"() : () -> ()
+    cf.br ^bb2
+  ^bb4:
+    "op4"() : () -> ()
+    cf.br ^bb2
+  ^bb5:
+    "op5"() : () -> ()
+    cf.br ^bb7
+  ^bb6:
+    "op6"() : () -> ()
+    cf.br ^bb6
+  ^bb7:
+    "op7"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+//       4
+//       |
+//       v
+//  0 -> 2 --> 1 --> 5 --> 7
+//  |    ^
+//  |    |             6 --
+//  |   /              ^   \
+//  |  /                \  /
+//  v /                  --
+//  3
+
+// CHECK-LABEL: Op forward dominance post-order visits
+// CHECK:       Visiting op 'op0'
+// CHECK:       Visiting op 'cf.cond_br'
+// CHECK:       Visiting op 'op2'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op1'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op5'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op7'
+// CHECK:       Visiting op 'op3'
+// CHECK:       Visiting op 'cf.br'
+// CHECK-NOT:   Visiting op 'op6'
+// CHECK:       Visiting op 'regionOp0'
+// CHECK:       Visiting op 'func.return'
+// CHECK:       Visiting op 'func.func'
+
+// CHECK-LABEL: Block forward dominance post-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb5 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb7 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb3 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func.func'
+
+// CHECK-LABEL: Region forward dominance post-order visits
+// CHECK:       Visiting region 0 from operation 'regionOp0'
+// CHECK:       Visiting region 0 from operation 'func.func'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Cannot erase block ^bb0 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb1 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Cannot erase block ^bb2 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb3 from region 0 from operation 'regionOp0', still has uses
+// CHECK:       Cannot erase block ^bb4 from region 0 from operation 'regionOp0', still has uses
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -74,6 +76,23 @@
   llvm::outs() << "Region reverse post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
+
+  // This test case tests "NoGraphRegions = true", so start the walk with
+  // functions.
+  op->walk([&](FunctionOpInterface funcOp) {
+    llvm::outs() << "Op forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(opPure);
+    llvm::outs() << "Block forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(blockPure);
+    llvm::outs() << "Region forward dominance post-order visits"
+                 << "\n";
+    funcOp->walk<WalkOrder::PostOrder,
+                 ForwardDominanceIterator</*NoGraphRegions=*/true>>(regionPure);
+  });
 }
 
 /// Tests erasure callbacks that skip the walk.
@@ -98,11 +117,17 @@
     if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
       return WalkResult::advance();
 
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
-    return WalkResult::skip();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+      return WalkResult::skip();
+    }
+    llvm::outs() << "Cannot erase ";
+    printBlock(block);
+    llvm::outs() << ", still has uses\n";
+    return WalkResult::advance();
   };
 
   llvm::outs() << "Op pre-order erasures (skip)"
@@ -141,10 +166,16 @@
     op->erase();
   };
   auto noSkipBlockErasure = [](Block *block) {
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+    } else {
+      llvm::outs() << "Cannot erase ";
+      printBlock(block);
+      llvm::outs() << ", still has uses\n";
+    }
   };
 
   llvm::outs() << "Op post-order erasures (no skip)"