diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -249,34 +249,55 @@ // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the operations in this block in postorder, calling the callback for - /// each operation. - /// See Operation::walk for more details. - template > + /// Walk the operations in this block. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either the walk is in post-order or the + /// walk is in pre-order and the walk is pruned after the erasure. See + /// Operation::walk for more details. + template > RetT walk(FnT &&callback) { - return walk(begin(), end(), std::forward(callback)); + return walk(begin(), end(), std::forward(callback)); } - /// Walk the operations in the specified [begin, end) range of this block in - /// postorder, calling the callback for each operation. This method is invoked - /// for void return callbacks. - /// See Operation::walk for more details. - template > + /// Walk the operations in the specified [begin, end) range of this block. The + /// callback method is called for each nested region, block or operation, + /// depending on the callback provided. Regions, blocks and operations at the + /// same nesting level are visited in lexicographical order. The walk order + /// for enclosing regions, blocks and operations with respect to their nested + /// ones is specified by 'Order' (post-order by default). This method is + /// invoked for void-returning callbacks. A callback on a block or operation + /// is allowed to erase that block or operation only if the walk is in + /// post-order. See non-void method for pre-order erasure. See Operation::walk + /// for more details. + template > typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - detail::walk(&op, callback); + detail::walk(&op, callback); } - /// Walk the operations in the specified [begin, end) range of this block in - /// postorder, calling the callback for each operation. This method is invoked - /// for interruptible callbacks. - /// See Operation::walk for more details. - template > + /// Walk the operations in the specified [begin, end) range of this block. The + /// callback method is called for each nested region, block or operation, + /// depending on the callback provided. Regions, blocks and operations at the + /// same nesting level are visited in lexicographical order. The walk order + /// for enclosing regions, blocks and operations with respect to their nested + /// ones is specified by 'Order' (post-order by default). This method is + /// invoked for prunable or interruptible callbacks. A callback on a block or + /// operation is allowed to erase that block or operation if either the walk + /// is in post-order or the walk is in pre-order and the walk is pruned after + /// the erasure. See Operation::walk for more details. 
+ template > typename std::enable_if::value, RetT>::type walk(Block::iterator begin, Block::iterator end, FnT &&callback) { for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) - if (detail::walk(&op, callback).wasInterrupted()) + if (detail::walk(&op, callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -166,12 +166,19 @@ /// handlers that may be listening. InFlightDiagnostic emitRemark(const Twine &message = {}); - /// Walk the operation in postorder, calling the callback for each nested - /// operation(including this one). - /// See Operation::walk for more details. - template > + /// Walk the operation by calling the callback for each nested operation + /// (including this one), block or region, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either the walk is in post-order or the + /// walk is in pre-order and the walk is pruned after the erasure. See + /// Operation::walk for more details. + template > RetT walk(FnT &&callback) { - return state->walk(std::forward(callback)); + return state->walk(std::forward(callback)); } // These are default implementations of customization hooks. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -481,24 +481,34 @@ // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the operation in postorder, calling the callback for each nested - /// operation(including this one). The callback method can take any of the - /// following forms: + /// Walk the operation by calling the callback for each nested operation + /// (including this one), block or region, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either the walk is in post-order or the + /// walk is in pre-order and the walk is pruned after the erasure. The + /// callback method can take any of the following forms: /// void(Operation*) : Walk all operations opaquely. /// * op->walk([](Operation *nestedOp) { ...}); /// void(OpT) : Walk all operations of the given derived type. /// * op->walk([](ReturnOp returnOp) { ...}); /// WalkResult(Operation*|OpT) : Walk operations, but allow for - /// interruption/cancellation. + /// interruption/pruning. /// * op->walk([](... op) { - /// // Interrupt, i.e cancel, the walk based on some invariant. + /// // Prune, i.e. skip, the walk of this op based on some invariant. /// if (some_invariant) + /// return WalkResult::prune(); + /// // Interrupt, i.e cancel, the walk based on some invariant. 
+ /// if (another_invariant) /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); - template > + template > RetT walk(FnT &&callback) { - return detail::walk(this, std::forward(callback)); + return detail::walk(this, std::forward(callback)); } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -242,24 +242,39 @@ // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the operations in this region in postorder, calling the callback for - /// each operation. This method is invoked for void-returning callbacks. - /// See Operation::walk for more details. - template > + /// Walk the operations in this region. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). This method is invoked for void-returning + /// callbacks. A callback on a block or operation is allowed to erase that + /// block or operation only if the walk is in post-order. See non-void method + /// for pre-order erasure. See Operation::walk for more details. + template > typename std::enable_if::value, RetT>::type walk(FnT &&callback) { for (auto &block : *this) - block.walk(callback); + block.walk(callback); } - /// Walk the operations in this region in postorder, calling the callback for - /// each operation. This method is invoked for interruptible callbacks. - /// See Operation::walk for more details. - template > + /// Walk the operations in this region. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). This method is invoked for prunable or + /// interruptible callbacks. A callback on a block or operation is allowed to + /// erase that block or operation if either the walk is in post-order or the + /// walk is in pre-order and the walk is pruned after the erasure. See + /// Operation::walk for more details. + template > typename std::enable_if::value, RetT>::type walk(FnT &&callback) { for (auto &block : *this) - if (block.walk(callback).wasInterrupted()) + if (block.walk(callback).wasInterrupted()) return WalkResult::interrupt(); return WalkResult::advance(); } diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -24,10 +24,15 @@ class Block; class Region; -/// A utility result that is used to signal if a walk method should be -/// interrupted or advance. +/// A utility result that is used to signal how to proceed with an ongoing walk: +/// * Interrupt: the walk will be interrupted and no more operations, regions +/// or blocks will be visited. +/// * Advance: the walk will continue. 
+/// * Prune: the walk of the current operation, region or block and their +/// nested elements that haven't been visited already will be skipped and will +/// continue with the next operation, region or block. class WalkResult { - enum ResultEnum { Interrupt, Advance } result; + enum ResultEnum { Interrupt, Advance, Prune } result; public: WalkResult(ResultEnum result) : result(result) {} @@ -44,11 +49,18 @@ static WalkResult interrupt() { return {Interrupt}; } static WalkResult advance() { return {Advance}; } + static WalkResult prune() { return {Prune}; } /// Returns true if the walk was interrupted. bool wasInterrupted() const { return result == Interrupt; } + + /// Returns true if the walk was pruned. + bool wasPruned() const { return result == Prune; } }; +/// Traversal order for region, block and operation walk utilities. +enum class WalkOrder { PreOrder, PostOrder }; + namespace detail { /// Helper templates to deduce the first argument of a callback parameter. template Arg first_argument_type(Ret (*)(Arg)); @@ -64,17 +76,32 @@ using first_argument = decltype(first_argument_type(std::declval())); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. -void walk(Operation *op, function_ref callback); -void walk(Operation *op, function_ref callback); -void walk(Operation *op, function_ref callback); - +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order'. These methods are invoked for void-returning +/// callbacks. A callback on a block or operation is allowed to erase that block +/// or operation only if the walk is in post-order. See non-void method for +/// pre-order erasure. +void walk(Operation *op, function_ref callback, + WalkOrder order); +void walk(Operation *op, function_ref callback, WalkOrder order); +void walk(Operation *op, function_ref callback, + WalkOrder order); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. These functions walk until an interrupt result is -/// returned by the callback. -WalkResult walk(Operation *op, function_ref callback); -WalkResult walk(Operation *op, function_ref callback); -WalkResult walk(Operation *op, function_ref callback); +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order'. This method is invoked for prunable or interruptible +/// callbacks. A callback on a block or operation is allowed to erase that block +/// or operation if either the walk is in post-order or the walk is in pre-order +/// and the walk is pruned after the erasure. +WalkResult walk(Operation *op, function_ref callback, + WalkOrder order); +WalkResult walk(Operation *op, function_ref callback, + WalkOrder order); +WalkResult walk(Operation *op, function_ref callback, + WalkOrder order); // Below are a set of functions to walk nested operations. Users should favor // the direct `walk` methods on the IR classes(Operation/Block/etc) over these @@ -82,30 +109,43 @@ // upon the type of the callback function. /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. 
This method is selected for callbacks that operate on -/// Region*, Block*, and Operation*. +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'Order' (post-order by default). A callback on a block or +/// operation is allowed to erase that block or operation if either the walk is +/// in post-order or the walk is in pre-order and the walk is pruned after the +/// erasure. This method is selected for callbacks that operate on Region*, +/// Block*, and Operation*. /// /// Example: /// op->walk([](Region *r) { ... }); /// op->walk([](Block *b) { ... }); /// op->walk([](Operation *op) { ... }); template < - typename FuncTy, typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< llvm::is_one_of::value, RetT>::type walk(Operation *op, FuncTy &&callback) { - return walk(op, function_ref(callback)); + return detail::walk(op, function_ref(callback), Order); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. This method is selected for void returning callbacks that -/// operate on a specific derived operation type. +/// given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order' (post-order by default). This method is selected for +/// void-returning callbacks that operate on a specific derived operation type. +/// A callback on an operation is allowed to erase that operation only if the +/// walk is in post-order. See non-void method for pre-order erasure. /// /// Example: /// op->walk([](ReturnOp op) { ... }); template < - typename FuncTy, typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< !llvm::is_one_of::value && @@ -116,21 +156,30 @@ if (auto derivedOp = dyn_cast(op)) callback(derivedOp); }; - return detail::walk(op, function_ref(wrapperFn)); + return detail::walk(op, function_ref(wrapperFn), Order); } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. This method is selected for WalkReturn returning -/// interruptible callbacks that operate on a specific derived operation type. +/// given operation. Regions, blocks and operations at the same nesting level +/// are visited in lexicographical order. The walk order for enclosing regions, +/// blocks and operations with respect to their nested ones is specified by +/// 'Order' (post-order by default). This method is selected for WalkReturn +/// returning prunable or interruptible callbacks that operate on a specific +/// derived operation type. A callback on an operation is allowed to erase that +/// operation if either the walk is in post-order or the walk is in pre-order +/// and the walk is pruned after the erasure. 
/// /// Example: /// op->walk([](ReturnOp op) { /// if (some_invariant) +/// return WalkResult::prune(); +/// if (another_invariant) /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); template < - typename FuncTy, typename ArgT = detail::first_argument, + WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + typename ArgT = detail::first_argument, typename RetT = decltype(std::declval()(std::declval()))> typename std::enable_if< !llvm::is_one_of::value && @@ -142,7 +191,7 @@ return callback(derivedOp); return WalkResult::advance(); }; - return detail::walk(op, function_ref(wrapperFn)); + return detail::walk(op, function_ref(wrapperFn), Order); } /// Utility to provide the return type of a templated walk method. diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -130,7 +130,7 @@ DenseMap &builders) { llvm::SetVector toProcess; - operation->walk([&](Block *block) { + operation->walk([&](Block *block) { BlockInfoBuilder &builder = builders.try_emplace(block, block).first->second; @@ -270,7 +270,7 @@ DenseMap blockIds; DenseMap operationIds; DenseMap valueIds; - operation->walk([&](Block *block) { + operation->walk([&](Block *block) { blockIds.insert({block, blockIds.size()}); for (BlockArgument argument : block->getArguments()) valueIds.insert({argument, valueIds.size()}); @@ -304,7 +304,7 @@ }; // Dump information about in and out values. - operation->walk([&](Block *block) { + operation->walk([&](Block *block) { os << "// - Block: " << blockIds[block] << "\n"; const auto *liveness = getLiveness(block); os << "// --- LiveIn: "; diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp --- a/mlir/lib/Analysis/NumberOfExecutions.cpp +++ b/mlir/lib/Analysis/NumberOfExecutions.cpp @@ -115,7 +115,7 @@ /// Creates a new NumberOfExecutions analysis that computes how many times a /// block within a region is executed for all associated regions. NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) { - operation->walk([&](Region *region) { + operation->walk([&](Region *region) { computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution); }); } @@ -191,7 +191,7 @@ raw_ostream &os, Region *perEntryOfThisRegion) const { unsigned blockId = 0; - operation->walk([&](Block *block) { + operation->walk([&](Block *block) { llvm::errs() << "Block: " << blockId++ << "\n"; llvm::errs() << "Number of executions: "; if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion)) @@ -203,7 +203,7 @@ void NumberOfExecutions::printOperationExecutions( raw_ostream &os, Region *perEntryOfThisRegion) const { - operation->walk([&](Block *block) { + operation->walk([&](Block *block) { block->walk([&](Operation *operation) { // Skip the operation that was used to build the analysis. if (operation == this->operation) diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -12,79 +12,149 @@ using namespace mlir; /// Walk all of the regions/blocks/operations nested under and including the -/// given operation. -void detail::walk(Operation *op, function_ref callback) { +/// given operation. Regions, blocks and operations at the same nesting level +/// are visited in lexicographical order. The walk order for enclosing regions, +/// blocks and operations with respect to their nested ones is specified by +/// 'order'. 
These methods are invoked for void-returning callbacks. A callback +/// on a block or operation is allowed to erase that block or operation only if +/// the walk is in post-order. See non-void method for pre-order erasure. + +void detail::walk(Operation *op, function_ref callback, + WalkOrder order) { + // We don't use early increment for regions because they can't be erased from + // a callback. for (auto ®ion : op->getRegions()) { - callback(®ion); + if (order == WalkOrder::PreOrder) + callback(®ion); for (auto &block : region) { for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback, order); } + if (order == WalkOrder::PostOrder) + callback(®ion); } } -void detail::walk(Operation *op, function_ref callback) { +void detail::walk(Operation *op, function_ref callback, + WalkOrder order) { for (auto ®ion : op->getRegions()) { - for (auto &block : region) { - callback(&block); + // Early increment here in the case where the block is erased. + for (auto &block : llvm::make_early_inc_range(region)) { + if (order == WalkOrder::PreOrder) + callback(&block); for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback, order); + if (order == WalkOrder::PostOrder) + callback(&block); } } } -void detail::walk(Operation *op, function_ref callback) { +void detail::walk(Operation *op, function_ref callback, + WalkOrder order) { + if (order == WalkOrder::PreOrder) + callback(op); + // TODO: This walk should be iterative over the operations. for (auto ®ion : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) - walk(&nestedOp, callback); + walk(&nestedOp, callback, order); } } - callback(op); + + if (order == WalkOrder::PostOrder) + callback(op); } -/// Walk all of the regions/blocks/operations nested under and including the -/// given operation. These functions walk operations until an interrupt result -/// is returned by the callback. +/// Walk all of the regions, blocks, or operations nested under (and including) +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order'. A callback on a block or operation is allowed to erase +/// that block or operation if either the walk is in post-order or the walk is +/// in pre-order and the walk is pruned after the erasure. + WalkResult detail::walk(Operation *op, - function_ref callback) { + function_ref callback, + WalkOrder order) { + // We don't use early increment for regions because they can't be erased from + // a callback. for (auto ®ion : op->getRegions()) { - if (callback(®ion).wasInterrupted()) - return WalkResult::interrupt(); + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(®ion); + if (result.wasPruned()) + continue; + if (result.wasInterrupted()) + return WalkResult::interrupt(); + } + for (auto &block : region) { for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback, order); + } + if (order == WalkOrder::PostOrder) { + if (callback(®ion).wasInterrupted()) + return WalkResult::interrupt(); + // We don't check if this region was pruned because its walk already + // finished and the walk will continue with the next region. 
} } return WalkResult::advance(); } WalkResult detail::walk(Operation *op, - function_ref callback) { + function_ref callback, + WalkOrder order) { for (auto ®ion : op->getRegions()) { - for (auto &block : region) { - if (callback(&block).wasInterrupted()) - return WalkResult::interrupt(); + // Early increment here in the case where the block is erased. + for (auto &block : llvm::make_early_inc_range(region)) { + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(&block); + if (result.wasPruned()) + continue; + if (result.wasInterrupted()) + return WalkResult::interrupt(); + } for (auto &nestedOp : block) - walk(&nestedOp, callback); + walk(&nestedOp, callback, order); + if (order == WalkOrder::PostOrder) { + if (callback(&block).wasInterrupted()) + return WalkResult::interrupt(); + // We don't check if this block was pruned because its walk already + // finished and the walk will continue with the next block. + } } } return WalkResult::advance(); } WalkResult detail::walk(Operation *op, - function_ref callback) { + function_ref callback, + WalkOrder order) { + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(op); + if (result.wasPruned()) + return WalkResult::prune(); + if (result.wasInterrupted()) + return WalkResult::interrupt(); + } + // TODO: This walk should be iterative over the operations. for (auto ®ion : op->getRegions()) { for (auto &block : region) { // Early increment here in the case where the operation is erased. for (auto &nestedOp : llvm::make_early_inc_range(block)) { - if (walk(&nestedOp, callback).wasInterrupted()) + if (walk(&nestedOp, callback, order).wasInterrupted()) return WalkResult::interrupt(); + // We don't check if this op was pruned because its walk already + // finished and the walk will continue with the next op. } } } - return callback(op); + + if (order == WalkOrder::PostOrder) + return callback(op); + return WalkResult::advance(); } diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/visitors.mlir @@ -0,0 +1,178 @@ +// RUN: mlir-opt -test-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s + +// Verify the different configurations of IR visitors. +// Constant, yield and other terminator ops are not matched for simplicity. +// Module and function op and their immediately nested blocks are skipped in +// erasure walks so that the output includes more cases in pre-order. 
+ +func @structured_cfg() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c10 = constant 10 : index + scf.for %i = %c1 to %c10 step %c1 { + %cond = "use0"(%i) : (index) -> (i1) + scf.if %cond { + "use1"(%i) : (index) -> () + } else { + "use2"(%i) : (index) -> () + } + "use3"(%i) : (index) -> () + } + return +} + +// CHECK-LABEL: Op pre-order visit +// CHECK: Visiting op 'module' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'scf.for' +// CHECK: Visiting op 'use0' +// CHECK: Visiting op 'scf.if' +// CHECK: Visiting op 'use1' +// CHECK: Visiting op 'use2' +// CHECK: Visiting op 'use3' +// CHECK: Visiting op 'std.return' + +// CHECK-LABEL: Block pre-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if' + +// CHECK-LABEL: Region pre-order visits +// CHECK: Visiting region 0 from operation 'module' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'scf.for' +// CHECK: Visiting region 0 from operation 'scf.if' +// CHECK: Visiting region 1 from operation 'scf.if' + +// CHECK-LABEL: Op post-order visits +// CHECK: Visiting op 'use0' +// CHECK: Visiting op 'use1' +// CHECK: Visiting op 'use2' +// CHECK: Visiting op 'scf.if' +// CHECK: Visiting op 'use3' +// CHECK: Visiting op 'scf.for' +// CHECK: Visiting op 'std.return' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'module' + +// CHECK-LABEL: Block post-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' + +// CHECK-LABEL: Region post-order visits +// CHECK: Visiting region 0 from operation 'scf.if' +// CHECK: Visiting region 1 from operation 'scf.if' +// CHECK: Visiting region 0 from operation 'scf.for' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'module' + +// CHECK-LABEL: Op pre-order erasures +// CHECK: Erasing op 'scf.for' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block pre-order erasures +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for' + +// CHECK-LABEL: Op post-order erasures +// CHECK: Erasing op 'use0' +// CHECK: Erasing op 'use1' +// CHECK: Erasing op 'use2' +// CHECK: Erasing op 'scf.if' +// CHECK: Erasing op 'use3' +// CHECK: Erasing op 'scf.for' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block post-order erasures +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for' + +// ----- + +func @unstructured_cfg() { + "regionOp0"() ({ + ^bb0: + "op0"() : () -> () + br ^bb2 + ^bb1: + "op1"() : () -> () + br ^bb2 + ^bb2: + "op2"() : () -> () + }) : () -> () + return +} + +// CHECK-LABEL: Op pre-order visits +// CHECK: Visiting op 'module' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'regionOp0' +// CHECK: Visiting op 'op0' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op1' +// CHECK: Visiting op 'std.br' +// CHECK: 
Visiting op 'op2' +// CHECK: Visiting op 'std.return' + +// CHECK-LABEL: Block pre-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0' + +// CHECK-LABEL: Region pre-order visits +// CHECK: Visiting region 0 from operation 'module' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'regionOp0' + +// CHECK-LABEL: Op post-order visits +// CHECK: Visiting op 'op0' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op1' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op2' +// CHECK: Visiting op 'regionOp0' +// CHECK: Visiting op 'std.return' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'module' + +// CHECK-LABEL: Block post-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' + +// CHECK-LABEL: Region post-order visits +// CHECK: Visiting region 0 from operation 'regionOp0' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'module' + +// CHECK-LABEL: Op pre-order erasures +// CHECK: Erasing op 'regionOp0' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block pre-order erasures +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' + +// CHECK-LABEL: Op post-order erasures +// CHECK: Erasing op 'op0' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op1' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op2' +// CHECK: Erasing op 'regionOp0' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block post-order erasures +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -9,6 +9,7 @@ TestSlicing.cpp TestSymbolUses.cpp TestTypes.cpp + TestVisitors.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestVisitors.cpp @@ -0,0 +1,134 @@ +//===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +static void printRegion(Region *region) { + llvm::outs() << "region " << region->getRegionNumber() << " from operation '" + << region->getParentOp()->getName() << "'"; +} + +static void printBlock(Block *block) { + llvm::outs() << "block "; + block->printAsOperand(llvm::outs(), /*printType=*/false); + llvm::outs() << " from "; + printRegion(block->getParent()); +} + +static void printOperation(Operation *op) { + llvm::outs() << "op '" << op->getName() << "'"; +} + +/// This pass exercises the different configurations of the IR visitors. +struct TestIRVisitorsPass + : public PassWrapper> { + void runOnOperation() override { + Operation *op = getOperation(); + + // Pure visits. + auto op_pure = [](Operation *op) { + llvm::outs() << "Visiting "; + printOperation(op); + llvm::outs() << "\n"; + }; + auto block_pure = [](Block *block) { + llvm::outs() << "Visiting "; + printBlock(block); + llvm::outs() << "\n"; + }; + auto region_pure = [](Region *region) { + llvm::outs() << "Visiting "; + printRegion(region); + llvm::outs() << "\n"; + }; + + llvm::outs() << "Op pre-order visits" + << "\n"; + op->walk(op_pure); + llvm::outs() << "Block pre-order visits" + << "\n"; + op->walk(block_pure); + llvm::outs() << "Region pre-order visits" + << "\n"; + op->walk(region_pure); + + llvm::outs() << "Op post-order visits" + << "\n"; + op->walk(op_pure); + llvm::outs() << "Block post-order visits" + << "\n"; + op->walk(block_pure); + llvm::outs() << "Region post-order visits" + << "\n"; + op->walk(region_pure); + + // Erasure visits. + auto op_erase = [](Operation *op) { + // Do not erase module and function op. Otherwise there wouldn't be too + // much to test in pre-order. + if (isa(op) || isa(op)) + return WalkResult::advance(); + + llvm::outs() << "Erasing "; + printOperation(op); + llvm::outs() << "\n"; + op->dropAllUses(); + op->erase(); + return WalkResult::prune(); + }; + auto block_erase = [](Block *block) { + // Do not erase module and function blocks. Otherwise there wouldn't be + // too much to test in pre-order. + Operation *parentOp = block->getParentOp(); + if (isa(parentOp) || isa(parentOp)) + return WalkResult::advance(); + + llvm::outs() << "Erasing "; + printBlock(block); + llvm::outs() << "\n"; + block->erase(); + return WalkResult::prune(); + }; + + llvm::outs() << "Op pre-order erasures" + << "\n"; + Operation *cloned = op->clone(); + cloned->walk(op_erase); + cloned->erase(); + + llvm::outs() << "Block pre-order erasures" + << "\n"; + cloned = op->clone(); + cloned->walk(block_erase); + cloned->erase(); + + // Post-order erasures. 
+ llvm::outs() << "Op post-order erasures" + << "\n"; + cloned = op->clone(); + cloned->walk(op_erase); + cloned->erase(); + + llvm::outs() << "Block post-order erasures" + << "\n"; + cloned = op->clone(); + cloned->walk(block_erase); + cloned->erase(); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestIRVisitorsPass() { + PassRegistration("test-ir-visitors", + "Test various visitors."); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -41,6 +41,7 @@ void registerTestAllReduceLoweringPass(); void registerTestFunc(); void registerTestGpuMemoryPromotionPass(); +void registerTestIRVisitorsPass(); void registerTestLoopPermutationPass(); void registerTestMatchers(); void registerTestPrintDefUsePass(); @@ -113,6 +114,7 @@ registerTestAllReduceLoweringPass(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); + registerTestIRVisitorsPass(); registerTestLoopPermutationPass(); registerTestMatchers(); registerTestPrintDefUsePass();
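
For reviewers, a minimal usage sketch of the walk API this patch introduces: traversal order is selected through the new WalkOrder template parameter, and WalkResult::prune() skips the nested IR of the currently visited element while the rest of the walk continues. The attribute names "skip_analysis" and "abort_analysis" below are illustrative assumptions, not anything defined by the patch.

#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"

using namespace mlir;

// Pre-order walk that skips ("prunes") the regions of some ops and can stop
// the whole traversal early. Returning advance() keeps descending as usual.
static WalkResult analyze(Operation *root) {
  return root->walk<WalkOrder::PreOrder>([](Operation *op) {
    if (op->getAttr("skip_analysis"))  // hypothetical marker attribute
      return WalkResult::prune();      // do not visit op's regions
    if (op->getAttr("abort_analysis")) // hypothetical marker attribute
      return WalkResult::interrupt();  // stop the entire walk
    return WalkResult::advance();
  });
}

Existing post-order callers are unaffected: op->walk(callback) without the template argument keeps the current post-order behavior, since WalkOrder::PostOrder is the default.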
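
The erasure rule spelled out in the new doc comments (a pre-order callback may erase the block or operation it is visiting only if it prunes the walk afterwards) would look roughly like this in client code; the "erase_me" attribute is again an assumption for illustration.

// Erase marked ops during a pre-order walk. Pruning after the erasure is what
// keeps the walk from descending into the just-freed nested IR; in a
// post-order walk the erasure alone would already be legal.
static void eraseMarkedOps(Operation *root) {
  root->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == root || !op->getAttr("erase_me"))
      return WalkResult::advance();
    op->dropAllUses();
    op->erase();
    return WalkResult::prune(); // 'op' and its regions no longer exist
  });
}

This mirrors the op_erase callback in the new TestVisitors.cpp pass, which additionally skips the module and function ops so that the pre-order erasure output in visitors.mlir still has something left to visit.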