diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -258,42 +258,47 @@ /// Walk the operations in this block. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'Iterator'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. - template > RetT walk(FnT &&callback) { - return walk(begin(), end(), std::forward(callback)); + return walk(begin(), end(), std::forward(callback)); } /// Walk the operations in the specified [begin, end) range of this block. The /// callback method is called for each nested region, block or operation, - /// depending on the callback provided. Regions, blocks and operations at the - /// same nesting level are visited in lexicographical order. The walk order + /// depending on the callback provided. The order in which regions, blocks and + /// operations at the same nesting level are visited (e.g., lexicographical or + /// reverse lexicographical order) is determined by 'Iterator'. The walk order /// for enclosing regions, blocks and operations with respect to their nested /// ones is specified by 'Order' (post-order by default). 
This method is /// invoked for void-returning callbacks. A callback on a block or operation /// is allowed to erase that block or operation only if the walk is in /// post-order. See non-void method for pre-order erasure. /// See Operation::walk for more details. - template > std::enable_if_t::value, RetT> walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - detail::walk(&op, callback); + detail::walk(&op, callback); } /// Walk the operations in the specified [begin, end) range of this block. The /// callback method is called for each nested region, block or operation, - /// depending on the callback provided. Regions, blocks and operations at the - /// same nesting level are visited in lexicographical order. The walk order + /// depending on the callback provided. The order in which regions, blocks and + /// operations at the same nesting level are visited (e.g., lexicographical or + /// reverse lexicographical order) is determined by 'Iterator'. The walk order /// for enclosing regions, blocks and operations with respect to their nested /// ones is specified by 'Order' (post-order by default). This method is /// invoked for skippable or interruptible callbacks. A callback on a block or @@ -301,12 +306,13 @@ /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. 
- template > std::enable_if_t::value, RetT> walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - if (detail::walk(&op, callback).wasInterrupted()) + if (detail::walk(&op, callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -132,20 +132,22 @@ /// Walk the operation by calling the callback for each nested operation /// (including this one), block or region, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'Iterator'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. - template > std::enable_if_t>::num_args == 1, RetT> walk(FnT &&callback) { - return state->walk(std::forward(callback)); + return state->walk(std::forward(callback)); } /// Generic walker with a stage aware callback.
Walk the operation by calling diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -607,9 +607,10 @@ /// Walk the operation by calling the callback for each nested operation /// (including this one), block or region, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'Iterator'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, or @@ -631,12 +632,13 @@ /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); - template > std::enable_if_t>::num_args == 1, RetT> walk(FnT &&callback) { - return detail::walk(this, std::forward(callback)); + return detail::walk(this, std::forward(callback)); } /// Generic walker with a stage aware callback. Walk the operation by calling diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -265,37 +265,41 @@ /// Walk the operations in this region. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. 
The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'Iterator'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). This method is invoked for void-returning /// callbacks. A callback on a block or operation is allowed to erase that /// block or operation only if the walk is in post-order. See non-void method /// for pre-order erasure. See Operation::walk for more details. - template > std::enable_if_t::value, RetT> walk(FnT &&callback) { for (auto &block : *this) - block.walk(callback); + block.walk(callback); } /// Walk the operations in this region. The callback method is called for each /// nested region, block or operation, depending on the callback provided. - /// Regions, blocks and operations at the same nesting level are visited in - /// lexicographical order. The walk order for enclosing regions, blocks and - /// operations with respect to their nested ones is specified by 'Order' + /// The order in which regions, blocks and operations at the same nesting + /// level are visited (e.g., lexicographical or reverse lexicographical order) + /// is determined by 'Iterator'. The walk order for enclosing regions, blocks + /// and operations with respect to their nested ones is specified by 'Order' /// (post-order by default). This method is invoked for skippable or /// interruptible callbacks. A callback on a block or operation is allowed to /// erase that block or operation if either: /// * the walk is in post-order, /// * or the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. 
- template > std::enable_if_t::value, RetT> walk(FnT &&callback) { for (auto &block : *this) - if (block.walk(callback).wasInterrupted()) + if (block.walk(callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -62,6 +62,24 @@ /// Traversal order for region, block and operation walk utilities. enum class WalkOrder { PreOrder, PostOrder }; +/// This iterator enumerates the elements in "forward" order. +struct ForwardIterator { + template + static constexpr RangeT &makeRange(RangeT &range) { + return range; + } +}; + +/// This iterator enumerates elements in "reverse" order. It is a wrapper around +/// llvm::reverse. +struct ReverseIterator { + template + static constexpr auto makeRange(RangeT &&range) { + // llvm::reverse uses RangeT::rbegin and RangeT::rend. + return llvm::reverse(std::forward(range)); + } +}; + /// A utility class to encode the current walk stage for "generic" walkers. /// When walking an operation, we can either choose a Pre/Post order walker /// which invokes the callback on an operation before/after all its attached @@ -113,31 +131,39 @@ using first_argument = decltype(first_argument_type(std::declval())); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order'. These methods are invoked for void-returning +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. 
The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. These methods are invoked for void-returning /// callbacks. A callback on a block or operation is allowed to erase that block /// or operation only if the walk is in post-order. See non-void method for /// pre-order erasure. +template void walk(Operation *op, function_ref callback, WalkOrder order); +template void walk(Operation *op, function_ref callback, WalkOrder order); +template void walk(Operation *op, function_ref callback, WalkOrder order); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order'. This method is invoked for skippable or interruptible -/// callbacks. A callback on a block or operation is allowed to erase that block -/// or operation if either: +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. This method is invoked for skippable or +/// interruptible callbacks. A callback on a block or operation is allowed to +/// erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. 
+template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); +template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); +template WalkResult walk(Operation *op, function_ref callback, WalkOrder order); @@ -147,10 +173,11 @@ // upon the type of the callback function. /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'Order' (post-order by default). A callback on a block or +/// the given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). A callback on a block or /// operation is allowed to erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. @@ -162,20 +189,21 @@ /// op->walk([](Block *b) { ... }); /// op->walk([](Operation *op) { ... }); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, - typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, + typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t::value, RetT> walk(Operation *op, FuncTy &&callback) { - return detail::walk(op, function_ref(callback), Order); + return detail::walk(op, function_ref(callback), Order); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. 
Regions, blocks and operations at the same nesting -/// level are visited in lexicographical order. The walk order for enclosing -/// regions, blocks and operations with respect to their nested ones is -/// specified by 'order' (post-order by default). This method is selected for +/// given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). This method is selected for /// void-returning callbacks that operate on a specific derived operation type. /// A callback on an operation is allowed to erase that operation only if the /// walk is in post-order. See non-void method for pre-order erasure. @@ -183,8 +211,8 @@ /// Example: /// op->walk([](ReturnOp op) { ... }); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, - typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, + typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t< !llvm::is_one_of::value && @@ -195,17 +223,19 @@ if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }; - return detail::walk(op, function_ref(wrapperFn), Order); + return detail::walk(op, function_ref(wrapperFn), + Order); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. Regions, blocks and operations at the same nesting level -/// are visited in lexicographical order. The walk order for enclosing regions, -/// blocks and operations with respect to their nested ones is specified by -/// 'Order' (post-order by default).
This method is selected for WalkReturn -/// returning skippable or interruptible callbacks that operate on a specific -/// derived operation type. A callback on an operation is allowed to erase that -/// operation if either: +/// given operation. The order in which regions, blocks and operations at +/// the same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'Order' (post-order by default). This method is selected for +/// WalkReturn returning skippable or interruptible callbacks that operate on a +/// specific derived operation type. A callback on an operation is allowed to +/// erase that operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. /// @@ -218,8 +248,8 @@ /// return WalkResult::advance(); /// }); template < - WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, - typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, + typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> std::enable_if_t< !llvm::is_one_of::value && @@ -231,7 +261,8 @@ return callback(derivedOp); return WalkResult::advance(); }; - return detail::walk(op, function_ref(wrapperFn), Order); + return detail::walk(op, function_ref(wrapperFn), + Order); } /// Generic walkers with stage aware callbacks. diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -15,60 +15,91 @@ : numRegions(op->getNumRegions()), nextRegion(0) {} /// Walk all of the regions/blocks/operations nested under and including the -/// given operation. Regions, blocks and operations at the same nesting level -/// are visited in lexicographical order.
The walk order for enclosing regions, -/// blocks and operations with respect to their nested ones is specified by -/// 'order'. These methods are invoked for void-returning callbacks. A callback -/// on a block or operation is allowed to erase that block or operation only if -/// the walk is in post-order. See non-void method for pre-order erasure. +/// given operation. The order in which regions, blocks and operations at the +/// same nesting level are visited (e.g., lexicographical or reverse +/// lexicographical order) is determined by 'Iterator'. The walk order for +/// enclosing regions, blocks and operations with respect to their nested ones +/// is specified by 'order'. These methods are invoked for void-returning +/// callbacks. A callback on a block or operation is allowed to erase that block +/// or operation only if the walk is in post-order. See non-void method for +/// pre-order erasure. +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { // We don't use early increment for regions because they can't be erased from // a callback. - for (auto ®ion : op->getRegions()) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { if (order == WalkOrder::PreOrder) callback(®ion); - for (auto &block : region) { - for (auto &nestedOp : block) - walk(&nestedOp, callback, order); + for (auto &block : Iterator::makeRange(region)) { + for (auto &nestedOp : Iterator::makeRange(block)) + walk(&nestedOp, callback, order); } if (order == WalkOrder::PostOrder) callback(®ion); } } +// Explicit template instantiations for all supported iterators. 
+template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { - for (auto ®ion : op->getRegions()) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { // Early increment here in the case where the block is erased. - for (auto &block : llvm::make_early_inc_range(region)) { + for (auto &block : + llvm::make_early_inc_range(Iterator::makeRange(region))) { if (order == WalkOrder::PreOrder) callback(&block); - for (auto &nestedOp : block) - walk(&nestedOp, callback, order); + for (auto &nestedOp : Iterator::makeRange(block)) + walk(&nestedOp, callback, order); if (order == WalkOrder::PostOrder) callback(&block); } } } +// Explicit template instantiations for all supported iterators. +template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *op, function_ref callback, WalkOrder order) { if (order == WalkOrder::PreOrder) callback(op); // TODO: This walk should be iterative over the operations. - for (auto ®ion : op->getRegions()) { - for (auto &block : region) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { + for (auto &block : Iterator::makeRange(region)) { // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) - walk(&nestedOp, callback, order); + for (auto &nestedOp : + llvm::make_early_inc_range(Iterator::makeRange(block))) + walk(&nestedOp, callback, order); } } if (order == WalkOrder::PostOrder) callback(op); } +// Explicit template instantiations for all supported iterators. 
+template void detail::walk(Operation *, + function_ref, + WalkOrder); +template void detail::walk(Operation *, + function_ref, + WalkOrder); void detail::walk(Operation *op, function_ref callback) { @@ -99,12 +130,14 @@ /// operation is allowed to erase that block or operation if either: /// * the walk is in post-order, or /// * the walk is in pre-order and the walk is skipped after the erasure. +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { // We don't use early increment for regions because they can't be erased from // a callback. - for (auto ®ion : op->getRegions()) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(®ion); if (result.wasSkipped()) @@ -112,9 +145,9 @@ if (result.wasInterrupted()) return WalkResult::interrupt(); } - for (auto &block : region) { - for (auto &nestedOp : block) - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &block : Iterator::makeRange(region)) { + for (auto &nestedOp : Iterator::makeRange(block)) + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); } if (order == WalkOrder::PostOrder) { @@ -126,13 +159,23 @@ } return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators. +template WalkResult +detail::walk(Operation *, function_ref, + WalkOrder); +template WalkResult +detail::walk(Operation *, function_ref, + WalkOrder); +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { - for (auto ®ion : op->getRegions()) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { // Early increment here in the case where the block is erased. 
- for (auto &block : llvm::make_early_inc_range(region)) { + for (auto &block : + llvm::make_early_inc_range(Iterator::makeRange(region))) { if (order == WalkOrder::PreOrder) { WalkResult result = callback(&block); if (result.wasSkipped()) @@ -140,8 +183,8 @@ if (result.wasInterrupted()) return WalkResult::interrupt(); } - for (auto &nestedOp : block) - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &nestedOp : Iterator::makeRange(block)) + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); if (order == WalkOrder::PostOrder) { if (callback(&block).wasInterrupted()) @@ -153,7 +196,15 @@ } return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators. +template WalkResult +detail::walk(Operation *, function_ref, + WalkOrder); +template WalkResult +detail::walk(Operation *, function_ref, + WalkOrder); +template WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { @@ -167,11 +218,13 @@ } // TODO: This walk should be iterative over the operations. - for (auto ®ion : op->getRegions()) { - for (auto &block : region) { + MutableArrayRef regions = op->getRegions(); + for (auto ®ion : Iterator::makeRange(regions)) { + for (auto &block : Iterator::makeRange(region)) { // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) { - if (walk(&nestedOp, callback, order).wasInterrupted()) + for (auto &nestedOp : + llvm::make_early_inc_range(Iterator::makeRange(block))) { + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); } } @@ -181,6 +234,13 @@ return callback(op); return WalkResult::advance(); } +// Explicit template instantiations for all supported iterators. +template WalkResult +detail::walk(Operation *, + function_ref, WalkOrder); +template WalkResult +detail::walk(Operation *, + function_ref, WalkOrder); WalkResult detail::walk( Operation *op,