diff --git a/mlir/include/mlir/IR/Iterators.h b/mlir/include/mlir/IR/Iterators.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/Iterators.h
@@ -0,0 +1,112 @@
+//===- Iterators.h - IR iterators for IR visitors ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The iterators defined in this file can be used together with IR visitors.
+// Note: These iterators cannot be defined in Visitors.h because that would
+// introduce a cyclic header dependency due to Operation.h.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_ITERATORS_H
+#define MLIR_IR_ITERATORS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+/// This iterator enumerates elements in "reverse" order. It is a wrapper around
+/// llvm::reverse.
+struct ReverseIterator {
+  // llvm::reverse uses RangeT::rbegin and RangeT::rend.
+  template <typename RangeT>
+  static constexpr auto makeIterable(RangeT &&range) {
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
+  }
+};
+
+/// An iterator that enumerates blocks according to their successor
+/// relationship. Enumeration starts at the first block of the range; the
+/// unclaimed successors of each enumerated block are pushed onto a LIFO
+/// worklist. When the worklist runs dry, the last not-yet-visited block (in
+/// range order) is enumerated next, so blocks that are unreachable from the
+/// first block are enumerated too, after all reachable ones. Every block of
+/// the range is enumerated exactly once.
+///
+/// The iterator stores raw Block pointers: blocks of the range must not be
+/// erased while an iterator over them is in use.
+class ForwardDominanceBlockIterator final
+    : public llvm::iterator_facade_base<ForwardDominanceBlockIterator,
+                                        std::forward_iterator_tag, Block> {
+public:
+  /// Create an "end" iterator.
+  ForwardDominanceBlockIterator();
+
+  /// Create an iterator over `blocks` that is positioned at the first block.
+  /// An empty range yields an "end" iterator.
+  explicit ForwardDominanceBlockIterator(
+      iterator_range<Region::iterator> blocks);
+
+  /// Create an iterator over all blocks of `region`.
+  explicit ForwardDominanceBlockIterator(Region *region)
+      : ForwardDominanceBlockIterator(region->getBlocks()) {}
+
+  using llvm::iterator_facade_base<ForwardDominanceBlockIterator,
+                                   std::forward_iterator_tag,
+                                   Block>::operator++;
+
+  /// Advance to the next block. Must not be called on an "end" iterator.
+  ForwardDominanceBlockIterator &operator++();
+  Block *operator->() const { return &**this; }
+  Block &operator*() const {
+    assert(!worklist.empty() && "cannot dereference an end iterator");
+    return *worklist.back();
+  }
+
+  /// Compare this iterator with another iterator.
+  bool operator==(const ForwardDominanceBlockIterator &rhs) const;
+  bool operator!=(const ForwardDominanceBlockIterator &rhs) const {
+    return !(*this == rhs);
+  }
+
+private:
+  /// A worklist that contains the blocks that will be enumerated next upon
+  /// iterator increment. The last block on the list is the current position
+  /// of the iterator. An iterator with an empty worklist is an "end" iterator.
+  SmallVector<Block *> worklist;
+
+  /// Blocks that have neither been enumerated nor put on the worklist yet.
+  SetVector<Block *> remainingBlocks;
+};
+
+/// This iterator enumerates elements according to their dominance relationship.
+/// Operations and regions are enumerated in "forward" order. Blocks are
+/// enumerated according to their successor relationship (see
+/// ForwardDominanceBlockIterator), so every block that is reachable from the
+/// entry block is enumerated after all of its dominators.
+struct ForwardDominanceIterator {
+  static Block &makeIterable(Block &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+
+  static auto makeIterable(Region &region) {
+    return llvm::make_range(ForwardDominanceBlockIterator(&region), {});
+  }
+
+  static auto makeIterable(iterator_range<Region::iterator> range) {
+    return llvm::make_range(ForwardDominanceBlockIterator(range), {});
+  }
+
+  static MutableArrayRef<Region> makeIterable(Operation &range) {
+    return ForwardIterator::makeIterable(range);
+  }
+};
+} // namespace mlir
+
+#endif // MLIR_IR_ITERATORS_H
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -73,17 +73,6 @@
   }
 };
 
-/// This iterator enumerates elements in "reverse" order. It is a wrapper around
-/// llvm::reverse.
-struct ReverseIterator {
-  template <typename RangeT>
-  static constexpr auto makeIterable(RangeT &&range) {
-    // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(
-        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
-  }
-};
-
 /// A utility class to encode the current walk stage for "generic" walkers.
 /// When walking an operation, we can either choose a Pre/Post order walker
 /// which invokes the callback on an operation before/after all its attached
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -20,6 +20,7 @@
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
   IntegerSet.cpp
+  Iterators.cpp
   Location.cpp
   MLIRContext.cpp
   Operation.cpp
diff --git a/mlir/lib/IR/Iterators.cpp b/mlir/lib/IR/Iterators.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/Iterators.cpp
@@ -0,0 +1,52 @@
+//===- Iterators.cpp - IR iterators for IR visitors -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Iterators.h"
+
+using namespace mlir;
+
+ForwardDominanceBlockIterator::ForwardDominanceBlockIterator() = default;
+
+ForwardDominanceBlockIterator::ForwardDominanceBlockIterator(
+    iterator_range<Region::iterator> blocks) {
+  if (blocks.empty())
+    return;
+
+  // Always start iterating with the first block.
+  worklist.push_back(&*blocks.begin());
+  for (Block &b : llvm::drop_begin(blocks))
+    remainingBlocks.insert(&b);
+}
+
+ForwardDominanceBlockIterator &ForwardDominanceBlockIterator::operator++() {
+  assert(!worklist.empty() && "cannot increment iterator past end");
+  Block *block = worklist.pop_back_val();
+
+  // Claim every successor that has neither been enumerated nor queued yet and
+  // push it onto the worklist. SetVector::remove does a hash lookup before it
+  // touches the vector, so already-claimed successors are rejected in O(1)
+  // instead of requiring a linear scan over all remaining blocks.
+  for (Block *succ : block->getSuccessors())
+    if (remainingBlocks.remove(succ))
+      worklist.push_back(succ);
+
+  // Take the next block from the set of unvisited blocks if we run out of
+  // blocks. None of the remaining blocks is reachable from a block that has
+  // been enumerated so far, so this is how unreachable blocks get enumerated.
+  // The last remaining block (in range order) is picked.
+  if (worklist.empty() && !remainingBlocks.empty())
+    worklist.push_back(remainingBlocks.pop_back_val());
+
+  return *this;
+}
+
+bool ForwardDominanceBlockIterator::operator==(
+    const ForwardDominanceBlockIterator &rhs) const {
+  return llvm::equal(worklist, rhs.worklist) &&
+         llvm::equal(remainingBlocks, rhs.remainingBlocks);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -27,55 +28,6 @@
 
 #define DEBUG_TYPE "dialect-conversion"
 
-/// Recursively collect all of the operations to convert from within 'region'.
-/// If 'target' is nonnull, operations that are recursively legal have their
-/// regions pre-filtered to avoid considering them for legalization.
-static LogicalResult
-computeConversionSet(iterator_range<Region::iterator> region,
-                     Location regionLoc,
-                     SmallVectorImpl<Operation *> &toConvert,
-                     ConversionTarget *target = nullptr) {
-  if (region.empty())
-    return success();
-
-  // Traverse starting from the entry block.
-  SmallVector<Block *, 16> worklist(1, &*region.begin());
-  DenseSet<Block *> visitedBlocks;
-  visitedBlocks.insert(worklist.front());
-  while (!worklist.empty()) {
-    Block *block = worklist.pop_back_val();
-
-    // Compute the conversion set of each of the nested operations.
-    for (Operation &op : *block) {
-      toConvert.emplace_back(&op);
-
-      // Don't check this operation's children for conversion if the operation
-      // is recursively legal.
-      auto legalityInfo =
-          target ? target->isLegal(&op)
-                 : std::optional<ConversionTarget::LegalOpDetails>();
-      if (legalityInfo && legalityInfo->isRecursivelyLegal)
-        continue;
-      for (auto &region : op.getRegions()) {
-        if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                        toConvert, target)))
-          return failure();
-      }
-    }
-
-    // Recurse to children that haven't been visited.
-    for (Block *succ : block->getSuccessors())
-      if (visitedBlocks.insert(succ).second)
-        worklist.push_back(succ);
-  }
-
-  // Check that all blocks in the region were visited.
-  if (llvm::any_of(llvm::drop_begin(region, 1),
-                   [&](Block &block) { return !visitedBlocks.count(&block); }))
-    return emitError(regionLoc, "unreachable blocks were not converted");
-  return success();
-}
-
 /// A utility function to log a successful result for the given reason.
 template <typename... Args>
 static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
@@ -952,10 +904,6 @@
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                         Region::iterator before);
 
-  /// Notifies that the blocks of a region were cloned into another.
-  void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
-                                   Location origRegionLoc);
-
   /// Notifies that a pattern match failed for the given reason.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -1463,20 +1411,6 @@
   blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
 }
 
-void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
-    iterator_range<Region::iterator> &blocks, Location origRegionLoc) {
-  for (Block &block : blocks)
-    blockActions.push_back(BlockAction::getCreate(&block));
-
-  // Compute the conversion set for the inlined region.
-  auto result = computeConversionSet(blocks, origRegionLoc, createdOps);
-
-  // This original region has already had its conversion set computed, so there
-  // shouldn't be any new failures.
-  (void)result;
-  assert(succeeded(result) && "expected region to have no unreachable blocks");
-}
-
 LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
@@ -1630,12 +1564,17 @@
                                                   IRMapping &mapping) {
   if (region.empty())
     return;
+
   PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
 
-  // Collect the range of the cloned blocks.
-  auto clonedBeginIt = mapping.lookup(&region.front())->getIterator();
-  auto clonedBlocks = llvm::make_range(clonedBeginIt, before);
-  impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
+  // Register every cloned block and all operations nested in it. Blocks follow
+  // ForwardDominanceIterator order; ops are reported parents-before-children.
+  for (Block &b : ForwardDominanceIterator::makeIterable(region)) {
+    Block *cloned = mapping.lookup(&b);
+    impl->notifyCreatedBlock(cloned);
+    cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator>(
+        [&](Operation *op) { notifyOperationInserted(op); });
+  }
 }
 
 void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
@@ -2444,11 +2383,17 @@
   // Compute the set of operations and blocks to convert.
   SmallVector<Operation *> toConvert;
   for (auto *op : ops) {
-    toConvert.emplace_back(op);
-    for (auto &region : op->getRegions())
-      if (failed(computeConversionSet(region.getBlocks(), region.getLoc(),
-                                      toConvert, &target)))
-        return failure();
+    // Walk in pre-order: WalkResult::skip() only prunes nested regions that
+    // have not been visited yet, which in a post-order walk is never the case.
+    op->walk<WalkOrder::PreOrder, ForwardDominanceIterator>([&](Operation *op) {
+      toConvert.push_back(op);
+      // Don't check this operation's children for conversion if the operation
+      // is recursively legal.
+      auto legalityInfo = target.isLegal(op);
+      if (legalityInfo && legalityInfo->isRecursivelyLegal)
+        return WalkResult::skip();
+      return WalkResult::advance();
+    });
   }
 
   // Convert each operation and discard rewrites on failure.
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -250,3 +250,98 @@
 // CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
 // CHECK: Erasing block ^bb0 from region 0 from operation 'func.func'
 // CHECK: Erasing block ^bb0 from region 0 from operation 'builtin.module'
+
+// -----
+
+func.func @unordered_cfg_with_loop() {
+  "regionOp0"() ({
+  ^bb0:
+    %c = "op0"() : () -> (i1)
+    cf.cond_br %c, ^bb2, ^bb3
+  ^bb1:
+    "op1"(%val) : (i32) -> ()
+    cf.br ^bb5
+  ^bb2:
+    %val = "op2"() : () -> (i32)
+    cf.br ^bb1
+  ^bb3:
+    "op3"() : () -> ()
+    cf.br ^bb2
+  ^bb4:
+    "op4"() : () -> ()
+    cf.br ^bb2
+  ^bb5:
+    "op5"() : () -> ()
+    cf.br ^bb7
+  ^bb6:
+    "op6"() : () -> ()
+    cf.br ^bb6
+  ^bb7:
+    "op7"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+//      4
+//      |
+//      v
+// 0 -> 2 --> 1 --> 5 --> 7
+// |    ^
+// |    |     6 --
+// |   /      ^   \
+// |  /        \ /
+// v /          --
+// 3
+//
+// ^bb1 uses %val, which is defined textually later in ^bb2, so the blocks are
+// not in dominance order. A forward dominance walk starts at the entry block
+// and follows successor edges; ^bb6 and ^bb4 are unreachable from the entry
+// block and are therefore visited last.
+
+// CHECK-LABEL: Op forward dominance post-order visits
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'cf.cond_br'
+// CHECK: Visiting op 'op3'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op5'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op7'
+// CHECK: Visiting op 'op6'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op4'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'func.func'
+// CHECK: Visiting op 'builtin.module'
+
+// CHECK-LABEL: Block forward dominance post-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb3 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb5 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb7 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb6 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb4 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func.func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.module'
+
+// CHECK-LABEL: Region forward dominance post-order visits
+// CHECK: Visiting region 0 from operation 'regionOp0'
+// CHECK: Visiting region 0 from operation 'func.func'
+// CHECK: Visiting region 0 from operation 'builtin.module'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Cannot erase block ^bb0 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb1 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Cannot erase block ^bb2 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb3 from region 0 from operation 'regionOp0', still has uses
+// CHECK: Cannot erase block ^bb4 from region 0 from operation 'regionOp0', still has uses
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -74,6 +75,16 @@
   llvm::outs() << "Region reverse post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
+
+  llvm::outs() << "Op forward dominance post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ForwardDominanceIterator>(opPure);
+  llvm::outs() << "Block forward dominance post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ForwardDominanceIterator>(blockPure);
+  llvm::outs() << "Region forward dominance post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ForwardDominanceIterator>(regionPure);
 }
 
 /// Tests erasure callbacks that skip the walk.
@@ -98,11 +109,17 @@
     if (isa<ModuleOp>(parentOp) || isa<ModuleOp>(parentOp->getParentOp()))
       return WalkResult::advance();
 
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
-    return WalkResult::skip();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+      return WalkResult::skip();
+    }
+    llvm::outs() << "Cannot erase ";
+    printBlock(block);
+    llvm::outs() << ", still has uses\n";
+    return WalkResult::advance();
   };
 
   llvm::outs() << "Op pre-order erasures (skip)"
@@ -141,10 +158,16 @@
     op->erase();
   };
   auto noSkipBlockErasure = [](Block *block) {
-    llvm::outs() << "Erasing ";
-    printBlock(block);
-    llvm::outs() << "\n";
-    block->erase();
+    if (block->use_empty()) {
+      llvm::outs() << "Erasing ";
+      printBlock(block);
+      llvm::outs() << "\n";
+      block->erase();
+    } else {
+      llvm::outs() << "Cannot erase ";
+      printBlock(block);
+      llvm::outs() << ", still has uses\n";
+    }
   };
 
   llvm::outs() << "Op post-order erasures (no skip)"