diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -252,31 +252,34 @@ /// Walk the operations in this block in postorder, calling the callback for /// each operation. /// See Operation::walk for more details. - template > + template > RetT walk(FnT &&callback) { - return walk(begin(), end(), std::forward(callback)); + return walk(begin(), end(), std::forward(callback)); } /// Walk the operations in the specified [begin, end) range of this block in /// postorder, calling the callback for each operation. This method is invoked /// for void return callbacks. /// See Operation::walk for more details. - template > + template > typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - detail::walk(&op, callback); + detail::walk(&op, callback); } /// Walk the operations in the specified [begin, end) range of this block in /// postorder, calling the callback for each operation. This method is invoked /// for interruptible callbacks. /// See Operation::walk for more details. - template > + template > typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - if (detail::walk(&op, callback).wasInterrupted()) + if (detail::walk(&op, callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -169,9 +169,10 @@ /// Walk the operation in postorder, calling the callback for each nested /// operation(including this one). /// See Operation::walk for more details. - template > + template > RetT walk(FnT &&callback) { - return state->walk(std::forward(callback)); + return state->walk(std::forward(callback)); } // These are default implementations of customization hooks. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -496,9 +496,10 @@ /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); - template > + template > RetT walk(FnT &&callback) { - return detail::walk(this, std::forward(callback)); + return detail::walk(this, std::forward(callback)); } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -245,21 +245,23 @@ /// Walk the operations in this region in postorder, calling the callback for /// each operation. This method is invoked for void-returning callbacks. /// See Operation::walk for more details. - template > + template > typename std::enable_if::value, RetT>::type walk(FnT &&callback) { for (auto &block : *this) - block.walk(callback); + block.walk(callback); } /// Walk the operations in this region in postorder, calling the callback for /// each operation. This method is invoked for interruptible callbacks. /// See Operation::walk for more details. - template > + template > typename std::enable_if::value, RetT>::type walk(FnT &&callback) { for (auto &block : *this) - if (block.walk(callback).wasInterrupted()) + if (block.walk(callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -65,15 +65,21 @@ /// Walk all of the regions, blocks, or operations nested under (and including) /// the given operation. +template void walk(Operation *op, function_ref callback); +template void walk(Operation *op, function_ref callback); +template void walk(Operation *op, function_ref callback); /// Walk all of the regions, blocks, or operations nested under (and including) /// the given operation. These functions walk until an interrupt result is /// returned by the callback. +template WalkResult walk(Operation *op, function_ref callback); +template WalkResult walk(Operation *op, function_ref callback); +template WalkResult walk(Operation *op, function_ref callback); // Below are a set of functions to walk nested operations. Users should favor @@ -90,12 +96,13 @@ /// op->walk([](Block *b) { ... }); /// op->walk([](Operation *op) { ... }); template < - typename FuncTy, typename ArgT = detail::first_argument, + bool IsPreOrder = false, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< llvm::is_one_of::value, RetT>::type walk(Operation *op, FuncTy &&callback) { - return walk(op, function_ref(callback)); + return walk(op, function_ref(callback)); } /// Walk all of the operations of type 'ArgT' nested under and including the @@ -105,7 +112,8 @@ /// Example: /// op->walk([](ReturnOp op) { ... }); template < - typename FuncTy, typename ArgT = detail::first_argument, + bool IsPreOrder = false, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< !llvm::is_one_of::value && @@ -116,7 +124,8 @@ if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }; - return detail::walk(op, function_ref(wrapperFn)); + return detail::walk(op, + function_ref(wrapperFn)); } /// Walk all of the operations of type 'ArgT' nested under and including the @@ -130,7 +139,8 @@ /// return WalkResult::advance(); /// }); template < - typename FuncTy, typename ArgT = detail::first_argument, + bool IsPreOrder = false, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< !llvm::is_one_of::value && @@ -142,7 +152,8 @@ return callback(derivedOp); return WalkResult::advance(); }; - return detail::walk(op, function_ref(wrapperFn)); + return detail::walk(op, + function_ref(wrapperFn)); } /// Utility to provide the return type of a templated walk method. diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -13,41 +13,57 @@ /// Walk all of the regions/blocks/operations nested under and including the /// given operation. + +template void detail::walk(Operation *op, function_ref callback) { for (auto ®ion : op->getRegions()) { callback(®ion); for (auto &block : region) { for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback); } } } +template void detail::walk(Operation *, function_ref); +template void detail::walk(Operation *, function_ref); +template void detail::walk(Operation *op, function_ref callback) { for (auto ®ion : op->getRegions()) { for (auto &block : region) { callback(&block); for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback); } } } +template void detail::walk(Operation *, function_ref); +template void detail::walk(Operation *, function_ref); +template void detail::walk(Operation *op, function_ref callback) { + if (IsPreOrder) + callback(op); + // TODO: This walk should be iterative over the operations. for (auto ®ion : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) - walk(&nestedOp, callback); + walk(&nestedOp, callback); } } - callback(op); + + if (!IsPreOrder) + callback(op); } +template void detail::walk(Operation *, function_ref); +template void detail::walk(Operation *, function_ref); /// Walk all of the regions/blocks/operations nested under and including the /// given operation. These functions walk operations until an interrupt result /// is returned by the callback. +template WalkResult detail::walk(Operation *op, function_ref callback) { for (auto ®ion : op->getRegions()) { @@ -55,12 +71,17 @@ return WalkResult::interrupt(); for (auto &block : region) { for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback); } } return WalkResult::advance(); } +template WalkResult detail::walk(Operation *, + function_ref); +template WalkResult detail::walk(Operation *, + function_ref); +template WalkResult detail::walk(Operation *op, function_ref callback) { for (auto ®ion : op->getRegions()) { @@ -68,23 +89,39 @@ if (callback(&block).wasInterrupted()) return WalkResult::interrupt(); for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback); } } return WalkResult::advance(); } +template WalkResult detail::walk(Operation *, + function_ref); +template WalkResult detail::walk(Operation *, + function_ref); +template WalkResult detail::walk(Operation *op, function_ref callback) { + if (IsPreOrder) + if (callback(op).wasInterrupted()) + return WalkResult::interrupt(); + // TODO: This walk should be iterative over the operations. for (auto ®ion : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) { - if (walk(&nestedOp, callback).wasInterrupted()) + if (walk(&nestedOp, callback).wasInterrupted()) return WalkResult::interrupt(); } } } - return callback(op); + + if (!IsPreOrder) + return callback(op); + return WalkResult::advance(); } +template WalkResult detail::walk(Operation *, + function_ref); +template WalkResult detail::walk(Operation *, + function_ref);