diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -258,42 +258,47 @@ /// Walk the operations in this block. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'ItOrder'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. - template > RetT walk(FnT &&callback) { - return walk(begin(), end(), std::forward(callback)); + return walk(begin(), end(), std::forward(callback)); } /// Walk the operations in the specified [begin, end) range of this block. The /// callback method is called for each nested region, block or operation, - /// depending on the callback provided. Regions, blocks and operations at the - /// same nesting level are visited in lexicographical order. The walk order + /// depending on the callback provided. The order in which regions, blocks and + /// operations at the same nesting level are visited (e.g., lexicographical or + /// reverse lexicographical order) is determined by 'ItOrder'. The walk order /// for enclosing regions, blocks and operations with respect to their nested /// ones is specified by 'Order' (post-order by default).
This method is /// invoked for void-returning callbacks. A callback on a block or operation /// is allowed to erase that block or operation only if the walk is in /// post-order. See non-void method for pre-order erasure. /// See Operation::walk for more details. - template > std::enable_if_t::value, RetT> walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - detail::walk(&op, callback); + detail::walk(&op, callback); } /// Walk the operations in the specified [begin, end) range of this block. The /// callback method is called for each nested region, block or operation, - /// depending on the callback provided. Regions, blocks and operations at the - /// same nesting level are visited in lexicographical order. The walk order + /// depending on the callback provided. The order in which regions, blocks and + /// operations at the same nesting level are visited (e.g., lexicographical or + /// reverse lexicographical order) is determined by 'ItOrder'. The walk order /// for enclosing regions, blocks and operations with respect to their nested /// ones is specified by 'Order' (post-order by default). This method is /// invoked for skippable or interruptible callbacks. A callback on a block or @@ -301,12 +306,13 @@ /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details.
- template > std::enable_if_t::value, RetT> walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - if (detail::walk(&op, callback).wasInterrupted()) + if (detail::walk(&op, callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -132,20 +132,22 @@ /// Walk the operation by calling the callback for each nested operation /// (including this one), block or region, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting level + /// are visited (e.g., lexicographical or reverse lexicographical order) is + /// determined by 'ItOrder'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. - template > std::enable_if_t>::num_args == 1, RetT> walk(FnT &&callback) { - return state->walk(std::forward(callback)); + return state->walk(std::forward(callback)); } /// Generic walker with a stage aware callback.
Walk the operation by calling diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -589,9 +589,10 @@ /// Walk the operation by calling the callback for each nested operation /// (including this one), block or region, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'ItOrder'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or @@ -613,12 +614,13 @@ /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); - template > std::enable_if_t>::num_args == 1, RetT> walk(FnT &&callback) { - return detail::walk(this, std::forward(callback)); + return detail::walk(this, std::forward(callback)); } /// Generic walker with a stage aware callback. Walk the operation by calling diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -265,37 +265,41 @@ /// Walk the operations in this region. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order.
The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'ItOrder'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). This method is invoked for void-returning /// callbacks. A callback on a block or operation is allowed to erase that /// block or operation only if the walk is in post-order. See non-void method /// for pre-order erasure. See Operation::walk for more details. - template > std::enable_if_t::value, RetT> walk(FnT &&callback) { for (auto &block : *this) - block.walk(callback); + block.walk(callback); } /// Walk the operations in this region. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'ItOrder'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). This method is invoked for skippable or /// interruptible callbacks. A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, /// * or the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details.
- template > std::enable_if_t::value, RetT> walk(FnT &&callback) { for (auto &block : *this) - if (block.walk(callback).wasInterrupted()) + if (block.walk(callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -61,6 +61,7 @@ /// Traversal order for region, block and operation walk utilities. enum class WalkOrder { PreOrder, PostOrder }; +enum class IteratorOrder { Forward, Reverse }; /// A utility class to encode the current walk stage for "generic" walkers. /// When walking an operation, we can either choose a Pre/Post order walker @@ -112,32 +113,58 @@ template using first_argument = decltype(first_argument_type(std::declval())); +/// This iterator enumerates elements in "forward" order. +struct ForwardIterator { + template <typename RangeT> static auto makeRange(RangeT &&range) { + // Use the named `range` for both calls instead of forwarding it twice. + return llvm::make_range(std::begin(range), std::end(range)); + } +}; + +/// This iterator enumerates elements in "reverse" order. It is a wrapper around +/// llvm::reverse applied to the range itself, so a container's own +/// rbegin()/rend() are used; reversing an iterator_range would instead yield a +/// std::reverse_iterator that early-increment erasure would invalidate. +struct ReverseIterator { + template <typename RangeT> static auto makeRange(RangeT &&range) { + return llvm::reverse(std::forward<RangeT>(range)); + } +}; + /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order'. These methods are invoked for void-returning +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'.
The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. These methods are invoked for void-returning /// callbacks. A callback on a block or operation is allowed to erase that block /// or operation only if the walk is in post-order. See non-void method for /// pre-order erasure. +template void walk(Operation *op, function_ref callback, WalkOrder order); +template void walk(Operation *op, function_ref callback, WalkOrder order); +template void walk(Operation *op, function_ref callback, WalkOrder order); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order'. This method is invoked for skippable or interruptible -/// callbacks. A callback on a block or operation is allowed to erase that block -/// or operation if either: +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. This method is invoked for skippable or +/// interruptible callbacks. A callback on a block or operation is allowed to +/// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure.
+template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); +template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); +template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); @@ -147,10 +174,11 @@ // upon the type of the callback function. /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'Order' (post-order by default). A callback on a block or +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'ItOrder'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). A callback on a block or /// operation is allowed to erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. @@ -162,20 +190,30 @@ /// op->walk([](Block *b) { ... }); /// op->walk([](Operation *op) { ...
}); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + WalkOrder Order = WalkOrder::PostOrder, + IteratorOrder ItOrder = IteratorOrder::Forward, typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t::value, RetT> walk(Operation *op, FuncTy &&callback) { - return detail::walk(op, function_ref(callback), Order); + if constexpr (ItOrder == IteratorOrder::Forward) { + return detail::walk(op, function_ref(callback), + Order); + } + if constexpr (ItOrder == IteratorOrder::Reverse) { + return detail::walk(op, function_ref(callback), + Order); + } + llvm_unreachable("unsupported iterator order"); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order' (post-order by default). This method is selected for +/// given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'ItOrder'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). This method is selected for /// void-returning callbacks that operate on a specific derived operation type. /// A callback on an operation is allowed to erase that operation only if the /// walk is in post-order. See non-void method for pre-order erasure. @@ -183,7 +221,8 @@ /// Example: /// op->walk([](ReturnOp op) { ...
}); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + WalkOrder Order = WalkOrder::PostOrder, + IteratorOrder ItOrder = IteratorOrder::Forward, typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t< @@ -195,17 +234,26 @@ if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }; - return detail::walk(op, function_ref(wrapperFn), Order); + if constexpr (ItOrder == IteratorOrder::Forward) { + return detail::walk( + op, function_ref(wrapperFn), Order); + } + if constexpr (ItOrder == IteratorOrder::Reverse) { + return detail::walk( + op, function_ref(wrapperFn), Order); + } + llvm_unreachable("unsupported iterator order"); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. Regions, blocks and operations at the same nesting level -/// are visited in lexicographical order. The walk order for enclosing regions, -/// blocks and operations with respect to their nested ones is specified by -/// 'Order' (post-order by default). This method is selected for WalkReturn -/// returning skippable or interruptible callbacks that operate on a specific -/// derived operation type. A callback on an operation is allowed to erase that -/// operation if either: +/// given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'ItOrder'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). This method is selected for +/// WalkReturn returning skippable or interruptible callbacks that operate on a +/// specific derived operation type. A callback on an operation is allowed to +/// erase that operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure.
/// @@ -218,7 +266,8 @@ /// return WalkResult::advance(); /// }); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + WalkOrder Order = WalkOrder::PostOrder, + IteratorOrder ItOrder = IteratorOrder::Forward, typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t< @@ -231,7 +280,15 @@ return callback(derivedOp); return WalkResult::advance(); }; - return detail::walk(op, function_ref(wrapperFn), Order); + if constexpr (ItOrder == IteratorOrder::Forward) { + return detail::walk( + op, function_ref(wrapperFn), Order); + } + if constexpr (ItOrder == IteratorOrder::Reverse) { + return detail::walk( + op, function_ref(wrapperFn), Order); + } + llvm_unreachable("unsupported iterator order"); } /// Generic walkers with stage aware callbacks. diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -15,60 +15,86 @@ : numRegions(op->getNumRegions()), nextRegion(0) {} /// Walk all of the regions/blocks/operations nested under and including the -/// given operation. Regions, blocks and operations at the same nesting level -/// are visited in lexicographical order. The walk order for enclosing regions, -/// blocks and operations with respect to their nested ones is specified by -/// 'order'. These methods are invoked for void-returning callbacks. A callback -/// on a block or operation is allowed to erase that block or operation only if -/// the walk is in post-order. See non-void method for pre-order erasure. +/// given operation. The order in which regions, blocks and operations at the +/// same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. These methods are invoked for void-returning +/// callbacks.
A callback on a block or operation is allowed to erase that block +/// or operation only if the walk is in post-order. See non-void method for +/// pre-order erasure. +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { // We don't use early increment for regions because they can't be erased from // a callback. - for (auto &region : op->getRegions()) { + for (auto &region : Iterator::makeRange(op->getRegions())) { if (order == WalkOrder::PreOrder) callback(&region); - for (auto &block : region) { - for (auto &nestedOp : block) - walk(&nestedOp, callback, order); + for (auto &block : Iterator::makeRange(region)) { + for (auto &nestedOp : Iterator::makeRange(block)) + walk(&nestedOp, callback, order); } if (order == WalkOrder::PostOrder) callback(&region); } } - +// Explicit template instantiations for all supported iterators. +template void +detail::walk(Operation *, function_ref, + WalkOrder); +template void +detail::walk(Operation *, function_ref, + WalkOrder); + +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { - for (auto &region : op->getRegions()) { + for (auto &region : Iterator::makeRange(op->getRegions())) { // Early increment here in the case where the block is erased. - for (auto &block : llvm::make_early_inc_range(region)) { + for (auto &block : + llvm::make_early_inc_range(Iterator::makeRange(region))) { if (order == WalkOrder::PreOrder) callback(&block); - for (auto &nestedOp : block) - walk(&nestedOp, callback, order); + for (auto &nestedOp : Iterator::makeRange(block)) + walk(&nestedOp, callback, order); if (order == WalkOrder::PostOrder) callback(&block); } } } - +// Explicit template instantiations for all supported iterators.
+template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *, + function_ref, + WalkOrder); + +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { if (order == WalkOrder::PreOrder) callback(op); // TODO: This walk should be iterative over the operations. - for (auto &region : op->getRegions()) { - for (auto &block : region) { + for (auto &region : Iterator::makeRange(op->getRegions())) { + for (auto &block : Iterator::makeRange(region)) { // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) - walk(&nestedOp, callback, order); + for (auto &nestedOp : + llvm::make_early_inc_range(Iterator::makeRange(block))) + walk(&nestedOp, callback, order); } } if (order == WalkOrder::PostOrder) callback(op); } +// Explicit template instantiations for all supported iterators. +template void detail::walk( + Operation *, function_ref, WalkOrder); +template void detail::walk( + Operation *, function_ref, WalkOrder); void detail::walk(Operation *op, function_ref callback) { @@ -99,12 +125,13 @@ /// operation is allowed to erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { // We don't use early increment for regions because they can't be erased from // a callback.
- for (auto &region : op->getRegions()) { + for (auto &region : Iterator::makeRange(op->getRegions())) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(&region); if (result.wasSkipped()) @@ -112,9 +139,9 @@ if (result.wasInterrupted()) return WalkResult::interrupt(); } - for (auto &block : region) { - for (auto &nestedOp : block) - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &block : Iterator::makeRange(region)) { + for (auto &nestedOp : Iterator::makeRange(block)) + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); } if (order == WalkOrder::PostOrder) { @@ -126,13 +153,20 @@ } return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators. +template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); +template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { - for (auto &region : op->getRegions()) { + for (auto &region : Iterator::makeRange(op->getRegions())) { // Early increment here in the case where the block is erased. - for (auto &block : llvm::make_early_inc_range(region)) { + for (auto &block : + llvm::make_early_inc_range(Iterator::makeRange(region))) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(&block); if (result.wasSkipped()) @@ -140,8 +174,8 @@ if (result.wasInterrupted()) return WalkResult::interrupt(); } - for (auto &nestedOp : block) - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &nestedOp : Iterator::makeRange(block)) + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); if (order == WalkOrder::PostOrder) { if (callback(&block).wasInterrupted()) @@ -153,7 +187,13 @@ } return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators.
+template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); +template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { @@ -167,11 +207,12 @@ } // TODO: This walk should be iterative over the operations. - for (auto &region : op->getRegions()) { - for (auto &block : region) { + for (auto &region : Iterator::makeRange(op->getRegions())) { + for (auto &block : Iterator::makeRange(region)) { // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) { - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &nestedOp : + llvm::make_early_inc_range(Iterator::makeRange(block))) { + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); } } @@ -181,6 +222,11 @@ return callback(op); return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators. +template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); +template WalkResult detail::walk( + Operation *, function_ref, WalkOrder); WalkResult detail::walk( Operation *op,