diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -249,34 +249,40 @@
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this block in postorder, calling the callback for
-  /// each operation.
+  /// Walk the operations in this block, calling the callback for each
+  /// operation. The walk order for regions, blocks and operations is specified
+  /// by 'Order' (post-order by default).
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return walk(begin(), end(), std::forward<FnT>(callback));
+    return walk<Order>(begin(), end(), std::forward<FnT>(callback));
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block in
-  /// postorder, calling the callback for each operation. This method is invoked
-  /// for void return callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block,
+  /// calling the callback for each operation. The walk order for regions,
+  /// blocks and operations is specified by 'Order' (post-order by default).
+  /// This method is invoked for void return callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walk(&op, callback);
+      detail::walk<Order>(&op, callback);
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block in
-  /// postorder, calling the callback for each operation. This method is invoked
-  /// for interruptible callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block,
+  /// calling the callback for each operation. The walk order for regions,
+  /// blocks and operations is specified by 'Order' (post-order by default).
+  /// This method is invoked for interruptible callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      if (detail::walk(&op, callback).wasInterrupted())
+      if (detail::walk<Order>(&op, callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -166,12 +166,14 @@
   /// handlers that may be listening.
   InFlightDiagnostic emitRemark(const Twine &message = {});
 
-  /// Walk the operation in postorder, calling the callback for each nested
-  /// operation(including this one).
+  /// Walk the operation by calling the callback for each nested
+  /// operation (including this one). The walk order for regions, blocks and
+  /// operations is specified by 'Order' (post-order by default).
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return state->walk(std::forward<FnT>(callback));
+    return state->walk<Order>(std::forward<FnT>(callback));
   }
 
   // These are default implementations of customization hooks.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -481,9 +481,10 @@
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operation in postorder, calling the callback for each nested
-  /// operation(including this one). The callback method can take any of the
-  /// following forms:
+  /// Walk the operation by calling the callback for each nested operation
+  /// (including this one). The walk order for regions, blocks and operations is
+  /// specified by 'Order' (post-order by default). The callback method can take
+  /// any of the following forms:
   ///   void(Operation*) : Walk all operations opaquely.
   ///     * op->walk([](Operation *nestedOp) { ...});
   ///   void(OpT) : Walk all operations of the given derived type.
@@ -496,9 +497,10 @@
   ///           return WalkResult::interrupt();
   ///         return WalkResult::advance();
   ///       });
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return detail::walk(this, std::forward<FnT>(callback));
+    return detail::walk<Order>(this, std::forward<FnT>(callback));
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -243,23 +243,29 @@
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. This method is invoked for void-returning callbacks.
+  /// Walk the operations in this region, calling the callback for each
+  /// operation. The walk order for regions, blocks and operations is specified
+  /// by 'Order' (post-order by default). This method is invoked for
+  /// void-returning callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
   walk(FnT &&callback) {
     for (auto &block : *this)
-      block.walk(callback);
+      block.walk<Order>(callback);
   }
 
-  /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. This method is invoked for interruptible callbacks.
+  /// Walk the operations in this region, calling the callback for each
+  /// operation. The walk order for regions, blocks and operations is specified
+  /// by 'Order' (post-order by default). This method is invoked for
+  /// interruptible callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
   walk(FnT &&callback) {
     for (auto &block : *this)
-      if (block.walk(callback).wasInterrupted())
+      if (block.walk<Order>(callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -49,6 +49,9 @@
   bool wasInterrupted() const { return result == Interrupt; }
 };
 
+/// Traversal order for region, block and operation walk utilities.
+enum class WalkOrder { PreOrder, PostOrder };
+
 namespace detail {
 /// Helper templates to deduce the first argument of a callback parameter.
 template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
@@ -64,16 +67,22 @@
 using first_argument = decltype(first_argument_type(std::declval<T>()));
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation.
+/// the given operation. The walk order is specified by 'Order'.
+template <WalkOrder Order>
 void walk(Operation *op, function_ref<void(Region *)> callback);
+template <WalkOrder Order>
 void walk(Operation *op, function_ref<void(Block *)> callback);
+template <WalkOrder Order>
 void walk(Operation *op, function_ref<void(Operation *)> callback);
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. These functions walk until an interrupt result is
-/// returned by the callback.
+/// the given operation. The walk order is specified by 'Order'. These functions
+/// walk until an interrupt result is returned by the callback.
+template <WalkOrder Order>
 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
+template <WalkOrder Order>
 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
+template <WalkOrder Order>
 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
 
 // Below are a set of functions to walk nested operations. Users should favor
@@ -82,7 +91,8 @@
 // upon the type of the callback function.
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. This method is selected for callbacks that operate on
+/// the given operation. The walk order is specified by 'Order' (post-order
+/// by default). This method is selected for callbacks that operate on
 /// Region*, Block*, and Operation*.
 ///
 /// Example:
@@ -90,22 +100,25 @@
 ///   op->walk([](Block *b) { ... });
 ///   op->walk([](Operation *op) { ... });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
 walk(Operation *op, FuncTy &&callback) {
-  return walk(op, function_ref<RetT(ArgT)>(callback));
+  return walk<Order>(op, function_ref<RetT(ArgT)>(callback));
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for void returning callbacks that
-/// operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// void returning callbacks that operate on a specific derived operation type.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) { ... });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -116,12 +129,14 @@
     if (auto derivedOp = dyn_cast<ArgT>(op))
       callback(derivedOp);
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+  return detail::walk<Order>(op, function_ref<RetT(Operation *)>(wrapperFn));
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for WalkReturn returning
-/// interruptible callbacks that operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// WalkReturn returning interruptible callbacks that operate on a specific
+/// derived operation type.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) {
@@ -130,7 +145,8 @@
 ///     return WalkResult::advance();
 ///   });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -142,7 +158,7 @@
       return callback(derivedOp);
     return WalkResult::advance();
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+  return detail::walk<Order>(op, function_ref<RetT(Operation *)>(wrapperFn));
 }
 
 /// Utility to provide the return type of a templated walk method.
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -130,7 +130,7 @@
                               DenseMap<Block *, BlockInfoBuilder> &builders) {
   llvm::SetVector<Block *> toProcess;
 
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     BlockInfoBuilder &builder =
         builders.try_emplace(block, block).first->second;
 
@@ -270,7 +270,7 @@
   DenseMap<Block *, size_t> blockIds;
   DenseMap<Operation *, size_t> operationIds;
   DenseMap<Value, size_t> valueIds;
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     blockIds.insert({block, blockIds.size()});
     for (BlockArgument argument : block->getArguments())
       valueIds.insert({argument, valueIds.size()});
@@ -304,7 +304,7 @@
   };
 
   // Dump information about in and out values.
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     os << "// - Block: " << blockIds[block] << "\n";
     const auto *liveness = getLiveness(block);
     os << "// --- LiveIn: ";
diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
--- a/mlir/lib/Analysis/NumberOfExecutions.cpp
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -115,7 +115,7 @@
 /// Creates a new NumberOfExecutions analysis that computes how many times a
 /// block within a region is executed for all associated regions.
 NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
-  operation->walk([&](Region *region) {
+  operation->walk<WalkOrder::PreOrder>([&](Region *region) {
     computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
   });
 }
@@ -191,7 +191,7 @@
     raw_ostream &os, Region *perEntryOfThisRegion) const {
   unsigned blockId = 0;
 
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     llvm::errs() << "Block: " << blockId++ << "\n";
     llvm::errs() << "Number of executions: ";
     if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
@@ -203,7 +203,7 @@
 
 void NumberOfExecutions::printOperationExecutions(
     raw_ostream &os, Region *perEntryOfThisRegion) const {
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     block->walk([&](Operation *operation) {
       // Skip the operation that was used to build the analysis.
       if (operation == this->operation)
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -12,79 +12,146 @@
 using namespace mlir;
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation.
+/// given operation. The walk order is specified by 'Order'.
+template <WalkOrder Order>
 void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
   for (auto &region : op->getRegions()) {
-    callback(&region);
+    if (Order == WalkOrder::PreOrder)
+      callback(&region);
     for (auto &block : region) {
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk<Order>(&nestedOp, callback);
     }
+    if (Order == WalkOrder::PostOrder)
+      callback(&region);
   }
 }
+template void detail::walk<WalkOrder::PreOrder>(Operation *,
+                                                function_ref<void(Region *)>);
+template void detail::walk<WalkOrder::PostOrder>(Operation *,
+                                                 function_ref<void(Region *)>);
 
+template <WalkOrder Order>
 void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
-      callback(&block);
+      if (Order == WalkOrder::PreOrder)
+        callback(&block);
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk<Order>(&nestedOp, callback);
+      if (Order == WalkOrder::PostOrder)
+        callback(&block);
     }
   }
 }
+template void detail::walk<WalkOrder::PreOrder>(Operation *,
+                                                function_ref<void(Block *)>);
+template void detail::walk<WalkOrder::PostOrder>(Operation *,
+                                                 function_ref<void(Block *)>);
 
+template <WalkOrder Order>
 void detail::walk(Operation *op, function_ref<void(Operation *)> callback) {
+  if (Order == WalkOrder::PreOrder)
+    callback(op);
+
   // TODO: This walk should be iterative over the operations.
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
       // Early increment here in the case where the operation is erased.
       for (auto &nestedOp : llvm::make_early_inc_range(block))
-        walk(&nestedOp, callback);
+        walk<Order>(&nestedOp, callback);
     }
   }
-  callback(op);
+
+  if (Order == WalkOrder::PostOrder)
+    callback(op);
 }
+template void
+detail::walk<WalkOrder::PreOrder>(Operation *, function_ref<void(Operation *)>);
+template void
+detail::walk<WalkOrder::PostOrder>(Operation *,
+                                   function_ref<void(Operation *)>);
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback.
+/// given operation. The walk order is specified by 'Order'. These functions
+/// walk operations until an interrupt result is returned by the callback.
+template <WalkOrder Order>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Region *)> callback) {
   for (auto &region : op->getRegions()) {
-    if (callback(&region).wasInterrupted())
-      return WalkResult::interrupt();
+    if (Order == WalkOrder::PreOrder)
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
     for (auto &block : region) {
-      for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+      for (auto &nestedOp : block) {
+        if (walk<Order>(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+      }
     }
+    if (Order == WalkOrder::PostOrder)
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
   }
   return WalkResult::advance();
 }
+template WalkResult
+detail::walk<WalkOrder::PreOrder>(Operation *,
+                                  function_ref<WalkResult(Region *)>);
+template WalkResult
+detail::walk<WalkOrder::PostOrder>(Operation *,
+                                   function_ref<WalkResult(Region *)>);
 
+template <WalkOrder Order>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Block *)> callback) {
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
-      if (callback(&block).wasInterrupted())
-        return WalkResult::interrupt();
-      for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+      if (Order == WalkOrder::PreOrder)
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
+      for (auto &nestedOp : block) {
+        if (walk<Order>(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+      }
+      if (Order == WalkOrder::PostOrder)
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
     }
   }
   return WalkResult::advance();
 }
+template WalkResult
+detail::walk<WalkOrder::PreOrder>(Operation *,
+                                  function_ref<WalkResult(Block *)>);
+template WalkResult
+detail::walk<WalkOrder::PostOrder>(Operation *,
+                                   function_ref<WalkResult(Block *)>);
 
+template <WalkOrder Order>
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Operation *)> callback) {
+  if (Order == WalkOrder::PreOrder)
+    if (callback(op).wasInterrupted())
+      return WalkResult::interrupt();
+
   // TODO: This walk should be iterative over the operations.
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
       // Early increment here in the case where the operation is erased.
       for (auto &nestedOp : llvm::make_early_inc_range(block)) {
-        if (walk(&nestedOp, callback).wasInterrupted())
+        if (walk<Order>(&nestedOp, callback).wasInterrupted())
           return WalkResult::interrupt();
       }
     }
   }
-  return callback(op);
+
+  if (Order == WalkOrder::PostOrder)
+    return callback(op);
+  return WalkResult::advance();
 }
+template WalkResult
+detail::walk<WalkOrder::PreOrder>(Operation *,
+                                  function_ref<WalkResult(Operation *)>);
+template WalkResult
+detail::walk<WalkOrder::PostOrder>(Operation *,
+                                   function_ref<WalkResult(Operation *)>);