diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h --- a/mlir/include/mlir/Analysis/Liveness.h +++ b/mlir/include/mlir/Analysis/Liveness.h @@ -86,7 +86,7 @@ private: /// Initializes the internal mappings. - void build(MutableArrayRef regions); + void build(); private: /// The operation this analysis was constructed from. diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -21,6 +21,8 @@ class Diagnostic; class InFlightDiagnostic; class Operation; +class Block; +class Region; /// A utility result that is used to signal if a walk method should be /// interrupted or advance. @@ -61,12 +63,19 @@ template using first_argument = decltype(first_argument_type(std::declval())); -/// Walk all of the operations nested under and including the given operation. +/// Walk all of the regions/blocks/operations nested under the given +/// operation; operation walks also visit the given operation itself. +void walkOperations(Operation *op, function_ref callback); +void walkOperations(Operation *op, function_ref callback); void walkOperations(Operation *op, function_ref callback); -/// Walk all of the operations nested under and including the given operation. -/// This methods walks operations until an interrupt result is returned by the -/// callback. +/// Walk all of the regions/blocks/operations nested under the given +/// operation (operation walks also visit it). These functions walk until an +/// interrupt result is returned by the callback.
+WalkResult walkOperations(Operation *op, + function_ref callback); +WalkResult walkOperations(Operation *op, + function_ref callback); WalkResult walkOperations(Operation *op, function_ref callback); @@ -83,7 +92,10 @@ template < typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> -typename std::enable_if::value, RetT>::type +typename std::enable_if::value || + std::is_same::value || + std::is_same::value, + RetT>::type walkOperations(Operation *op, FuncTy &&callback) { return detail::walkOperations(op, function_ref(callback)); } @@ -98,6 +110,8 @@ typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value && std::is_same::value, RetT>::type walkOperations(Operation *op, FuncTy &&callback) { @@ -122,6 +136,8 @@ typename FuncTy, typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if::value && + !std::is_same::value && + !std::is_same::value && std::is_same::value, RetT>::type walkOperations(Operation *op, FuncTy &&callback) { diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -125,31 +125,17 @@ }; } // namespace -/// Walks all regions (including nested regions recursively) and invokes the -/// given function for every block. -template -static void walkRegions(MutableArrayRef regions, const FuncT &func) { - for (Region &region : regions) - for (Block &block : region) { - func(block); - - // Traverse all nested regions. - for (Operation &operation : block) - walkRegions(operation.getRegions(), func); - } -} - /// Builds the internal liveness block mapping.
-static void buildBlockMapping(MutableArrayRef regions, +static void buildBlockMapping(Operation *operation, DenseMap &builders) { llvm::SetVector toProcess; - walkRegions(regions, [&](Block &block) { + operation->walk([&](Block *block) { BlockInfoBuilder &builder = - builders.try_emplace(&block, &block).first->second; + builders.try_emplace(block, block).first->second; if (builder.updateLiveIn()) - toProcess.insert(block.pred_begin(), block.pred_end()); + toProcess.insert(block->pred_begin(), block->pred_end()); }); // Propagate the in and out-value sets (fixpoint iteration) @@ -172,14 +158,14 @@ /// Creates a new Liveness analysis that computes liveness information for all /// associated regions. -Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); } +Liveness::Liveness(Operation *op) : operation(op) { build(); } /// Initializes the internal mappings. -void Liveness::build(MutableArrayRef regions) { +void Liveness::build() { // Build internal block mapping. DenseMap builders; - buildBlockMapping(regions, builders); + buildBlockMapping(operation, builders); // Store internal block data. for (auto &entry : builders) { @@ -284,11 +270,11 @@ DenseMap blockIds; DenseMap operationIds; DenseMap valueIds; - walkRegions(operation->getRegions(), [&](Block &block) { - blockIds.insert({&block, blockIds.size()}); - for (BlockArgument argument : block.getArguments()) + operation->walk([&](Block *block) { + blockIds.insert({block, blockIds.size()}); + for (BlockArgument argument : block->getArguments()) valueIds.insert({argument, valueIds.size()}); - for (Operation &operation : block) { + for (Operation &operation : *block) { operationIds.insert({&operation, operationIds.size()}); for (Value result : operation.getResults()) valueIds.insert({result, valueIds.size()}); @@ -318,9 +304,9 @@ }; // Dump information about in and out values.
- walkRegions(operation->getRegions(), [&](Block &block) { - os << "// - Block: " << blockIds[&block] << "\n"; - auto liveness = getLiveness(&block); + operation->walk([&](Block *block) { + os << "// - Block: " << blockIds[block] << "\n"; + auto liveness = getLiveness(block); os << "// --- LiveIn: "; printValueRefs(liveness->inValues); os << "\n// --- LiveOut: "; @@ -329,7 +315,7 @@ // Print liveness intervals. os << "// --- BeginLiveness"; - for (Operation &op : block) { + for (Operation &op : *block) { if (op.getNumResults() < 1) continue; os << "\n"; diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -11,21 +11,74 @@ using namespace mlir; -/// Walk all of the operations nested under and including the given operations. +/// Walk all of the regions/blocks/operations nested under the given +/// operation; operation walks also visit the given operation itself. +void detail::walkOperations(Operation *op, + function_ref callback) { + for (auto &region : op->getRegions()) { + callback(&region); + for (auto &block : region) { + for (auto &nestedOp : block) + walkOperations(&nestedOp, callback); + } + } +} + +void detail::walkOperations(Operation *op, + function_ref callback) { + for (auto &region : op->getRegions()) { + for (auto &block : region) { + callback(&block); + for (auto &nestedOp : block) + walkOperations(&nestedOp, callback); + } + } +} + void detail::walkOperations(Operation *op, function_ref callback) { // TODO: This walk should be iterative over the operations. - for (auto &region : op->getRegions()) - for (auto &block : region) + for (auto &region : op->getRegions()) { + for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) walkOperations(&nestedOp, callback); - + } + } callback(op); } -/// Walk all of the operations nested under and including the given operations.
-/// This methods walks operations until an interrupt signal is received. +/// Walk all of the regions/blocks/operations nested under the given +/// operation (operation walks also visit it). These functions walk until an +/// interrupt result is returned by the callback, propagating it outward. +WalkResult +detail::walkOperations(Operation *op, + function_ref callback) { + for (auto &region : op->getRegions()) { + if (callback(&region).wasInterrupted()) + return WalkResult::interrupt(); + for (auto &block : region) + for (auto &nestedOp : block) + if (walkOperations(&nestedOp, callback).wasInterrupted()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); +} + +WalkResult +detail::walkOperations(Operation *op, + function_ref callback) { + for (auto &region : op->getRegions()) + for (auto &block : region) { + if (callback(&block).wasInterrupted()) + return WalkResult::interrupt(); + for (auto &nestedOp : block) + if (walkOperations(&nestedOp, callback).wasInterrupted()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); +} + WalkResult detail::walkOperations(Operation *op, function_ref callback) { @@ -33,9 +86,10 @@ for (auto &region : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. - for (auto &nestedOp : llvm::make_early_inc_range(block)) + for (auto &nestedOp : llvm::make_early_inc_range(block)) { if (walkOperations(&nestedOp, callback).wasInterrupted()) return WalkResult::interrupt(); + } } } return callback(op);