diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -64,8 +64,12 @@
 /// This iterator enumerates the elements in "forward" order.
 struct ForwardIterator {
-  template <typename RangeT>
-  static constexpr RangeT &makeRange(RangeT &range) {
+  /// Make operations iterable: return the list of regions.
+  static MutableArrayRef<Region> makeIterable(Operation &range);
+
+  /// Regions and block are already iterable.
+  template <typename T>
+  static constexpr T &makeIterable(T &range) {
     return range;
   }
 };
@@ -74,9 +78,10 @@
 /// llvm::reverse.
 struct ReverseIterator {
   template <typename RangeT>
-  static constexpr auto makeRange(RangeT &&range) {
+  static constexpr auto makeIterable(RangeT &&range) {
     // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(std::forward<RangeT>(range));
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
   }
 };
@@ -141,12 +146,58 @@
 /// pre-order erasure.
 template <typename Iterator>
 void walk(Operation *op, function_ref<void(Region *)> callback,
-          WalkOrder order);
+          WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    if (order == WalkOrder::PreOrder)
+      callback(&region);
+    for (auto &block : Iterator::makeIterable(region)) {
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        walk(&nestedOp, callback, order);
+    }
+    if (order == WalkOrder::PostOrder)
+      callback(&region);
+  }
+}
+
 template <typename Iterator>
-void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+void walk(Operation *op, function_ref<void(Block *)> callback,
+          WalkOrder order) {
+  for (auto &region : Iterator::makeIterable(*op)) {
+    // Early increment here in the case where the block is erased.
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+      if (order == WalkOrder::PreOrder)
+        callback(&block);
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        walk(&nestedOp, callback, order);
+      if (order == WalkOrder::PostOrder)
+        callback(&block);
+    }
+  }
+}
+
 template <typename Iterator>
 void walk(Operation *op, function_ref<void(Operation *)> callback,
-          WalkOrder order);
+          WalkOrder order) {
+  if (order == WalkOrder::PreOrder)
+    callback(op);
+
+  // TODO: This walk should be iterative over the operations.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    for (auto &block : Iterator::makeIterable(region)) {
+      // Early increment here in the case where the operation is erased.
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeIterable(block)))
+        walk(&nestedOp, callback, order);
+    }
+  }
+
+  if (order == WalkOrder::PostOrder)
+    callback(op);
+}
+
 /// Walk all of the regions, blocks, or operations nested under (and including)
 /// the given operation. The order in which regions, blocks and operations at
 /// the same nesting level are visited (e.g., lexicographical or reverse
@@ -159,13 +210,88 @@
 /// * the walk is in pre-order and the walk is skipped after the erasure.
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    if (order == WalkOrder::PreOrder) {
+      WalkResult result = callback(&region);
+      if (result.wasSkipped())
+        continue;
+      if (result.wasInterrupted())
+        return WalkResult::interrupt();
+    }
+    for (auto &block : Iterator::makeIterable(region)) {
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+    if (order == WalkOrder::PostOrder) {
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
+      // We don't check if this region was skipped because its walk already
+      // finished and the walk will continue with the next region.
+    }
+  }
+  return WalkResult::advance();
+}
+
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  for (auto &region : Iterator::makeIterable(*op)) {
+    // Early increment here in the case where the block is erased.
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+      if (order == WalkOrder::PreOrder) {
+        WalkResult result = callback(&block);
+        if (result.wasSkipped())
+          continue;
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      }
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+      if (order == WalkOrder::PostOrder) {
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
+        // We don't check if this block was skipped because its walk already
+        // finished and the walk will continue with the next block.
+      }
+    }
+  }
+  return WalkResult::advance();
+}
+
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  if (order == WalkOrder::PreOrder) {
+    WalkResult result = callback(op);
+    // If skipped, caller will continue the walk on the next operation.
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
+      return WalkResult::interrupt();
+  }
+
+  // TODO: This walk should be iterative over the operations.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    for (auto &block : Iterator::makeIterable(region)) {
+      // Early increment here in the case where the operation is erased.
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeIterable(block))) {
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+      }
+    }
+  }
+
+  if (order == WalkOrder::PostOrder)
+    return callback(op);
+  return WalkResult::advance();
+}
 
 // Below are a set of functions to walk nested operations. Users should favor
 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -14,92 +14,9 @@
 WalkStage::WalkStage(Operation *op)
     : numRegions(op->getNumRegions()), nextRegion(0) {}
 
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The order in which regions, blocks and operations at the
-/// same nesting level are visited (e.g., lexicographical or reverse
-/// lexicographical order) is determined by 'Iterator'. The walk order for
-/// enclosing regions, blocks and operations with respect to their nested ones
-/// is specified by 'order'. These methods are invoked for void-returning
-/// callbacks. A callback on a block or operation is allowed to erase that block
-/// or operation only if the walk is in post-order. See non-void method for
-/// pre-order erasure.
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Region *)> callback,
-                  WalkOrder order) {
-  // We don't use early increment for regions because they can't be erased from
-  // a callback.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    if (order == WalkOrder::PreOrder)
-      callback(&region);
-    for (auto &block : Iterator::makeRange(region)) {
-      for (auto &nestedOp : Iterator::makeRange(block))
-        walk(&nestedOp, callback, order);
-    }
-    if (order == WalkOrder::PostOrder)
-      callback(&region);
-  }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Region *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Region *)>,
-                                            WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Block *)> callback,
-                  WalkOrder order) {
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    // Early increment here in the case where the block is erased.
-    for (auto &block :
-         llvm::make_early_inc_range(Iterator::makeRange(region))) {
-      if (order == WalkOrder::PreOrder)
-        callback(&block);
-      for (auto &nestedOp : Iterator::makeRange(block))
-        walk(&nestedOp, callback, order);
-      if (order == WalkOrder::PostOrder)
-        callback(&block);
-    }
-  }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Block *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Block *)>,
-                                            WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
-                  WalkOrder order) {
-  if (order == WalkOrder::PreOrder)
-    callback(op);
-
-  // TODO: This walk should be iterative over the operations.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    for (auto &block : Iterator::makeRange(region)) {
-      // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp :
-           llvm::make_early_inc_range(Iterator::makeRange(block)))
-        walk(&nestedOp, callback, order);
-    }
-  }
-
-  if (order == WalkOrder::PostOrder)
-    callback(op);
+MutableArrayRef<Region> ForwardIterator::makeIterable(Operation &range) {
+  return range.getRegions();
 }
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Operation *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Operation *)>,
-                                            WalkOrder);
 
 void detail::walk(Operation *op,
                   function_ref<void(Operation *, const WalkStage &)> callback) {
@@ -120,128 +37,6 @@
   callback(op, stage);
 }
 
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback. Walks on regions, blocks and operations may
-/// also be skipped if the callback returns a skip result. Regions, blocks and
-/// operations at the same nesting level are visited in lexicographical order.
-/// The walk order for enclosing regions, blocks and operations with respect to
-/// their nested ones is specified by 'order'. A callback on a block or
-/// operation is allowed to erase that block or operation if either:
-///   * the walk is in post-order, or
-///   * the walk is in pre-order and the walk is skipped after the erasure.
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Region *)> callback,
-                        WalkOrder order) {
-  // We don't use early increment for regions because they can't be erased from
-  // a callback.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    if (order == WalkOrder::PreOrder) {
-      WalkResult result = callback(&region);
-      if (result.wasSkipped())
-        continue;
-      if (result.wasInterrupted())
-        return WalkResult::interrupt();
-    }
-    for (auto &block : Iterator::makeRange(region)) {
-      for (auto &nestedOp : Iterator::makeRange(block))
-        if (walk(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-    }
-    if (order == WalkOrder::PostOrder) {
-      if (callback(&region).wasInterrupted())
-        return WalkResult::interrupt();
-      // We don't check if this region was skipped because its walk already
-      // finished and the walk will continue with the next region.
-    }
-  }
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Region *)>,
-                              WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Region *)>,
-                              WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Block *)> callback,
-                        WalkOrder order) {
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    // Early increment here in the case where the block is erased.
-    for (auto &block :
-         llvm::make_early_inc_range(Iterator::makeRange(region))) {
-      if (order == WalkOrder::PreOrder) {
-        WalkResult result = callback(&block);
-        if (result.wasSkipped())
-          continue;
-        if (result.wasInterrupted())
-          return WalkResult::interrupt();
-      }
-      for (auto &nestedOp : Iterator::makeRange(block))
-        if (walk(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-      if (order == WalkOrder::PostOrder) {
-        if (callback(&block).wasInterrupted())
-          return WalkResult::interrupt();
-        // We don't check if this block was skipped because its walk already
-        // finished and the walk will continue with the next block.
-      }
-    }
-  }
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Block *)>,
-                              WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Block *)>,
-                              WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Operation *)> callback,
-                        WalkOrder order) {
-  if (order == WalkOrder::PreOrder) {
-    WalkResult result = callback(op);
-    // If skipped, caller will continue the walk on the next operation.
-    if (result.wasSkipped())
-      return WalkResult::advance();
-    if (result.wasInterrupted())
-      return WalkResult::interrupt();
-  }
-
-  // TODO: This walk should be iterative over the operations.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    for (auto &block : Iterator::makeRange(region)) {
-      // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp :
-           llvm::make_early_inc_range(Iterator::makeRange(block))) {
-        if (walk(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-      }
-    }
-  }
-
-  if (order == WalkOrder::PostOrder)
-    return callback(op);
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *,
-                              function_ref<WalkResult(Operation *)>, WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *,
-                              function_ref<WalkResult(Operation *)>, WalkOrder);
-
 WalkResult detail::walk(
     Operation *op,
     function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -71,6 +71,23 @@
 // CHECK: Visiting region 0 from operation 'func.func'
 // CHECK: Visiting region 0 from operation 'builtin.module'
 
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use3'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use2'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use1'
+// CHECK: Visiting op 'scf.if'
+// CHECK: Visiting op 'use0'
+// CHECK: Visiting op 'scf.for'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'func.func'
+// CHECK: Visiting op 'builtin.module'
+
 // CHECK-LABEL: Op pre-order erasures
 // CHECK: Erasing op 'scf.for'
 // CHECK: Erasing op 'func.return'
@@ -172,6 +189,29 @@
 // CHECK: Visiting region 0 from operation 'func.func'
 // CHECK: Visiting region 0 from operation 'builtin.module'
 
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'func.func'
+// CHECK: Visiting op 'builtin.module'
+
+// CHECK-LABEL: Block reverse post-order visits
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func.func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.module'
+
+// CHECK-LABEL: Region reverse post-order visits
+// CHECK: Visiting region 0 from operation 'regionOp0'
+// CHECK: Visiting region 0 from operation 'func.func'
+// CHECK: Visiting region 0 from operation 'builtin.module'
+
 // CHECK-LABEL: Op pre-order erasures (skip)
 // CHECK: Erasing op 'regionOp0'
 // CHECK: Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -64,6 +64,16 @@
     llvm::outs() << "Region post-order visits"
                  << "\n";
     op->walk<WalkOrder::PostOrder>(regionPure);
+
+    llvm::outs() << "Op reverse post-order visits"
+                 << "\n";
+    op->walk<WalkOrder::PostOrder, ReverseIterator>(opPure);
+    llvm::outs() << "Block reverse post-order visits"
+                 << "\n";
+    op->walk<WalkOrder::PostOrder, ReverseIterator>(blockPure);
+    llvm::outs() << "Region reverse post-order visits"
+                 << "\n";
+    op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
   }
 
   /// Tests erasure callbacks that skip the walk.