diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -249,9 +249,15 @@ // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the operations in this block, calling the callback for each - /// operation. The walk order for regions, blocks and operations is specified - /// by 'Order' (post-order by default). + /// Walk the operations in this block. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either: + /// * the walk is in post-order, or + /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. template > @@ -259,10 +265,15 @@ return walk(begin(), end(), std::forward(callback)); } - /// Walk the operations in the specified [begin, end) range of this block, - /// calling the callback for each operation. The walk order for regions, - /// blocks and operations is specified by 'Order' (post-order by default). - /// This method is invoked for void return callbacks. + /// Walk the operations in the specified [begin, end) range of this block. The + /// callback method is called for each nested region, block or operation, + /// depending on the callback provided. Regions, blocks and operations at the + /// same nesting level are visited in lexicographical order. The walk order + /// for enclosing regions, blocks and operations with respect to their nested + /// ones is specified by 'Order' (post-order by default).
This method is + /// invoked for void-returning callbacks. A callback on a block or operation + /// is allowed to erase that block or operation only if the walk is in + /// post-order. See non-void method for pre-order erasure. /// See Operation::walk for more details. template > @@ -272,10 +283,16 @@ detail::walk(&op, callback); } - /// Walk the operations in the specified [begin, end) range of this block, - /// calling the callback for each operation. The walk order for regions, - /// blocks and operations is specified by 'Order' (post-order by default). - /// This method is invoked for interruptible callbacks. + /// Walk the operations in the specified [begin, end) range of this block. The + /// callback method is called for each nested region, block or operation, + /// depending on the callback provided. Regions, blocks and operations at the + /// same nesting level are visited in lexicographical order. The walk order + /// for enclosing regions, blocks and operations with respect to their nested + /// ones is specified by 'Order' (post-order by default). This method is + /// invoked for skippable or interruptible callbacks. A callback on a block or + /// operation is allowed to erase that block or operation if either: + /// * the walk is in post-order, or + /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. template > diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -166,9 +166,15 @@ /// handlers that may be listening. InFlightDiagnostic emitRemark(const Twine &message = {}); - /// Walk the operation by calling the callback for each nested - /// operation(including this one). The walk order for regions, blocks and - /// operations is specified by 'Order' (post-order by default).
+ /// Walk the operation by calling the callback for each nested operation + /// (including this one), block or region, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either: + /// * the walk is in post-order, or + /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. template > diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -482,18 +482,28 @@ //===--------------------------------------------------------------------===// /// Walk the operation by calling the callback for each nested operation - /// (including this one). The walk order for regions, blocks and operations is - /// specified by 'Order' (post-order by default). The callback method can take - /// any of the following forms: + /// (including this one), block or region, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). A callback on a block or operation is allowed to + /// erase that block or operation if either: + /// * the walk is in post-order, or + /// * the walk is in pre-order and the walk is skipped after the erasure. + /// + /// The callback method can take any of the following forms: /// void(Operation*) : Walk all operations opaquely. /// * op->walk([](Operation *nestedOp) { ...}); /// void(OpT) : Walk all operations of the given derived type.
/// * op->walk([](ReturnOp returnOp) { ...}); /// WalkResult(Operation*|OpT) : Walk operations, but allow for - /// interruption/cancellation. + /// interruption/skipping. /// * op->walk([](... op) { - /// // Interrupt, i.e cancel, the walk based on some invariant. + /// // Prune, i.e. skip, the walk of this op based on some invariant. /// if (some_invariant) + /// return WalkResult::skip(); + /// // Interrupt, i.e. cancel, the walk based on some invariant. + /// if (another_invariant) /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -242,11 +242,15 @@ // Operation Walkers //===--------------------------------------------------------------------===// - /// Walk the operations in this region in postorder, calling the callback for - /// each operation. The walk order for regions, blocks and operations is - /// specified by 'Order' (post-order by default). This method is invoked for - /// void-returning callbacks. - /// See Operation::walk for more details. + /// Walk the operations in this region. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). This method is invoked for void-returning + /// callbacks. A callback on a block or operation is allowed to erase that + /// block or operation only if the walk is in post-order. See non-void method + /// for pre-order erasure. See Operation::walk for more details.
template > typename std::enable_if::value, RetT>::type @@ -255,10 +259,16 @@ block.walk(callback); } - /// Walk the operations in this region in postorder, calling the callback for - /// each operation. The walk order for regions, blocks and operations is - /// specified by 'Order' (post-order by default). This method is invoked for - /// interruptible callbacks. + /// Walk the operations in this region. The callback method is called for each + /// nested region, block or operation, depending on the callback provided. + /// Regions, blocks and operations at the same nesting level are visited in + /// lexicographical order. The walk order for enclosing regions, blocks and + /// operations with respect to their nested ones is specified by 'Order' + /// (post-order by default). This method is invoked for skippable or + /// interruptible callbacks. A callback on a block or operation is allowed to + /// erase that block or operation if either: + /// * the walk is in post-order, or + /// * the walk is in pre-order and the walk is skipped after the erasure. /// See Operation::walk for more details. template > diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -24,10 +24,15 @@ class Block; class Region; -/// A utility result that is used to signal if a walk method should be -/// interrupted or advance. +/// A utility result that is used to signal how to proceed with an ongoing walk: +/// * Interrupt: the walk will be interrupted and no more operations, regions +/// or blocks will be visited. +/// * Advance: the walk will continue. +/// * Skip: the walk of the current operation, region or block and their +/// nested elements that haven't been visited already will be skipped and will +/// continue with the next operation, region or block.
class WalkResult { - enum ResultEnum { Interrupt, Advance } result; + enum ResultEnum { Interrupt, Advance, Skip } result; public: WalkResult(ResultEnum result) : result(result) {} @@ -44,9 +49,13 @@ static WalkResult interrupt() { return {Interrupt}; } static WalkResult advance() { return {Advance}; } + static WalkResult skip() { return {Skip}; } /// Returns true if the walk was interrupted. bool wasInterrupted() const { return result == Interrupt; } + + /// Returns true if the walk was skipped. + bool wasSkipped() const { return result == Skip; } }; /// Traversal order for region, block and operation walk utilities. @@ -67,15 +76,27 @@ using first_argument = decltype(first_argument_type(std::declval())); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. The walk order is specified by 'order'. +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order'. These methods are invoked for void-returning +/// callbacks. A callback on a block or operation is allowed to erase that block +/// or operation only if the walk is in post-order. See non-void method for +/// pre-order erasure. void walk(Operation *op, function_ref callback, WalkOrder order); void walk(Operation *op, function_ref callback, WalkOrder order); void walk(Operation *op, function_ref callback, WalkOrder order); /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. The walk order is specified by 'order'. These functions -/// walk until an interrupt result is returned by the callback. +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order.
The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'order'. This method is invoked for skippable or interruptible +/// callbacks. A callback on a block or operation is allowed to erase that block +/// or operation if either: +/// * the walk is in post-order, or +/// * the walk is in pre-order and the walk is skipped after the erasure. WalkResult walk(Operation *op, function_ref callback, WalkOrder order); WalkResult walk(Operation *op, function_ref callback, @@ -89,9 +110,15 @@ // upon the type of the callback function. /// Walk all of the regions, blocks, or operations nested under (and including) -/// the given operation. The walk order is specified by 'Order' (post-order -/// by default). This method is selected for callbacks that operate on -/// Region*, Block*, and Operation*. +/// the given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order. The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'Order' (post-order by default). A callback on a block or +/// operation is allowed to erase that block or operation if either: +/// * the walk is in post-order, or +/// * the walk is in pre-order and the walk is skipped after the erasure. +/// This method is selected for callbacks that operate on Region*, Block*, and +/// Operation*. /// /// Example: /// op->walk([](Region *r) { ... }); @@ -108,9 +135,13 @@ } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. The walk order for regions, blocks and operations is -/// specified by 'Order' (post-order by default). This method is selected for -/// void returning callbacks that operate on a specific derived operation type. +/// given operation. Regions, blocks and operations at the same nesting +/// level are visited in lexicographical order.
The walk order for enclosing +/// regions, blocks and operations with respect to their nested ones is +/// specified by 'Order' (post-order by default). This method is selected for +/// void-returning callbacks that operate on a specific derived operation type. +/// A callback on an operation is allowed to erase that operation only if the +/// walk is in post-order. See non-void method for pre-order erasure. /// /// Example: /// op->walk([](ReturnOp op) { ... }); @@ -131,14 +162,21 @@ } /// Walk all of the operations of type 'ArgT' nested under and including the -/// given operation. The walk order for regions, blocks and operations is -/// specified by 'Order' (post-order by default). This method is selected for -/// WalkReturn returning interruptible callbacks that operate on a specific -/// derived operation type. +/// given operation. Regions, blocks and operations at the same nesting level +/// are visited in lexicographical order. The walk order for enclosing regions, +/// blocks and operations with respect to their nested ones is specified by +/// 'Order' (post-order by default). This method is selected for WalkResult +/// returning skippable or interruptible callbacks that operate on a specific +/// derived operation type. A callback on an operation is allowed to erase that +/// operation if either: +/// * the walk is in post-order, or +/// * the walk is in pre-order and the walk is skipped after the erasure. /// /// Example: /// op->walk([](ReturnOp op) { /// if (some_invariant) +/// return WalkResult::skip(); +/// if (another_invariant) /// return WalkResult::interrupt(); /// return WalkResult::advance(); /// }); diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -12,10 +12,16 @@ using namespace mlir; /// Walk all of the regions/blocks/operations nested under and including the -/// given operation. The walk order is specified by 'Order'. - +/// given operation.
Regions, blocks and operations at the same nesting level +/// are visited in lexicographical order. The walk order for enclosing regions, +/// blocks and operations with respect to their nested ones is specified by +/// 'order'. These methods are invoked for void-returning callbacks. A callback +/// on a block or operation is allowed to erase that block or operation only if +/// the walk is in post-order. See non-void method for pre-order erasure. void detail::walk(Operation *op, function_ref callback, WalkOrder order) { + // We don't use early increment for regions because they can't be erased from + // a callback. for (auto &region : op->getRegions()) { if (order == WalkOrder::PreOrder) callback(&region); @@ -31,7 +37,8 @@ void detail::walk(Operation *op, function_ref callback, WalkOrder order) { for (auto &region : op->getRegions()) { - for (auto &block : region) { + // Early increment here in the case where the block is erased. + for (auto &block : llvm::make_early_inc_range(region)) { if (order == WalkOrder::PreOrder) callback(&block); for (auto &nestedOp : block) @@ -61,22 +68,40 @@ } /// Walk all of the regions/blocks/operations nested under and including the -/// given operation. The walk order is specified by 'order'. These functions -/// walk operations until an interrupt result is returned by the callback. +/// given operation. These functions walk operations until an interrupt result +/// is returned by the callback. Walks on regions, blocks and operations may +/// also be skipped if the callback returns a skip result. Regions, blocks and +/// operations at the same nesting level are visited in lexicographical order. +/// The walk order for enclosing regions, blocks and operations with respect to +/// their nested ones is specified by 'order'. A callback on a block or +/// operation is allowed to erase that block or operation if either: +/// * the walk is in post-order, or +/// * the walk is in pre-order and the walk is skipped after the erasure.
WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { + // We don't use early increment for regions because they can't be erased from + // a callback. for (auto &region : op->getRegions()) { - if (order == WalkOrder::PreOrder) - if (callback(&region).wasInterrupted()) + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(&region); + if (result.wasSkipped()) + continue; + if (result.wasInterrupted()) return WalkResult::interrupt(); + } for (auto &block : region) { + // Propagate interrupts raised while walking nested operations. for (auto &nestedOp : block) - walk(&nestedOp, callback, order); + if (walk(&nestedOp, callback, order).wasInterrupted()) + return WalkResult::interrupt(); } - if (order == WalkOrder::PostOrder) + if (order == WalkOrder::PostOrder) { if (callback(&region).wasInterrupted()) return WalkResult::interrupt(); + // We don't check if this region was skipped because its walk already + // finished and the walk will continue with the next region. + } } return WalkResult::advance(); } @@ -85,15 +110,25 @@ function_ref callback, WalkOrder order) { for (auto &region : op->getRegions()) { - for (auto &block : region) { - if (order == WalkOrder::PreOrder) - if (callback(&block).wasInterrupted()) + // Early increment here in the case where the block is erased. + for (auto &block : llvm::make_early_inc_range(region)) { + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(&block); + if (result.wasSkipped()) + continue; + if (result.wasInterrupted()) return WalkResult::interrupt(); + } + // Propagate interrupts raised while walking nested operations. for (auto &nestedOp : block) - walk(&nestedOp, callback, order); - if (order == WalkOrder::PostOrder) + if (walk(&nestedOp, callback, order).wasInterrupted()) + return WalkResult::interrupt(); + if (order == WalkOrder::PostOrder) { if (callback(&block).wasInterrupted()) return WalkResult::interrupt(); + // We don't check if this block was skipped because its walk already + // finished and the walk will continue with the next block.
+ } } } return WalkResult::advance(); @@ -102,9 +137,14 @@ WalkResult detail::walk(Operation *op, function_ref callback, WalkOrder order) { - if (order == WalkOrder::PreOrder) - if (callback(op).wasInterrupted()) + if (order == WalkOrder::PreOrder) { + WalkResult result = callback(op); + if (result.wasSkipped()) + // Caller will continue the walk on the next operation. + return WalkResult::advance(); + if (result.wasInterrupted()) return WalkResult::interrupt(); + } // TODO: This walk should be iterative over the operations. for (auto &region : op->getRegions()) { diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/visitors.mlir @@ -0,0 +1,212 @@ +// RUN: mlir-opt -test-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s + +// Verify the different configurations of IR visitors. +// Constant, yield and other terminator ops are not matched for simplicity. +// Module and function op and their immediately nested blocks are not erased in +// callbacks with return so that the output includes more cases in pre-order.
+ +func @structured_cfg() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c10 = constant 10 : index + scf.for %i = %c1 to %c10 step %c1 { + %cond = "use0"(%i) : (index) -> (i1) + scf.if %cond { + "use1"(%i) : (index) -> () + } else { + "use2"(%i) : (index) -> () + } + "use3"(%i) : (index) -> () + } + return +} + +// CHECK-LABEL: Op pre-order visits +// CHECK: Visiting op 'module' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'scf.for' +// CHECK: Visiting op 'use0' +// CHECK: Visiting op 'scf.if' +// CHECK: Visiting op 'use1' +// CHECK: Visiting op 'use2' +// CHECK: Visiting op 'use3' +// CHECK: Visiting op 'std.return' + +// CHECK-LABEL: Block pre-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if' + +// CHECK-LABEL: Region pre-order visits +// CHECK: Visiting region 0 from operation 'module' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'scf.for' +// CHECK: Visiting region 0 from operation 'scf.if' +// CHECK: Visiting region 1 from operation 'scf.if' + +// CHECK-LABEL: Op post-order visits +// CHECK: Visiting op 'use0' +// CHECK: Visiting op 'use1' +// CHECK: Visiting op 'use2' +// CHECK: Visiting op 'scf.if' +// CHECK: Visiting op 'use3' +// CHECK: Visiting op 'scf.for' +// CHECK: Visiting op 'std.return' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'module' + +// CHECK-LABEL: Block post-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if' +// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting
block ^bb0 from region 0 from operation 'module' + +// CHECK-LABEL: Region post-order visits +// CHECK: Visiting region 0 from operation 'scf.if' +// CHECK: Visiting region 1 from operation 'scf.if' +// CHECK: Visiting region 0 from operation 'scf.for' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'module' + +// CHECK-LABEL: Op pre-order erasures (skip) +// CHECK: Erasing op 'scf.for' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block pre-order erasures (skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for' + +// CHECK-LABEL: Op post-order erasures (skip) +// CHECK: Erasing op 'use0' +// CHECK: Erasing op 'use1' +// CHECK: Erasing op 'use2' +// CHECK: Erasing op 'scf.if' +// CHECK: Erasing op 'use3' +// CHECK: Erasing op 'scf.for' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block post-order erasures (skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for' + +// CHECK-LABEL: Op post-order erasures (no skip) +// CHECK: Erasing op 'use0' +// CHECK: Erasing op 'use1' +// CHECK: Erasing op 'use2' +// CHECK: Erasing op 'scf.if' +// CHECK: Erasing op 'use3' +// CHECK: Erasing op 'scf.for' +// CHECK: Erasing op 'std.return' +// CHECK: Erasing op 'func' +// CHECK: Erasing op 'module' + +// CHECK-LABEL: Block post-order erasures (no skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if' +// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for' +// CHECK: Erasing block ^bb0 from region 0 from operation 'func' +// CHECK: Erasing block ^bb0 from region 0 from operation 'module' + +// ----- + +func @unstructured_cfg() { + "regionOp0"() ({ + ^bb0: + "op0"() : () -> () + br ^bb2 + ^bb1: + "op1"() : () -> () + br ^bb2 + ^bb2: + "op2"() : () -> () + }) : () ->
() + return +} + +// CHECK-LABEL: Op pre-order visits +// CHECK: Visiting op 'module' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'regionOp0' +// CHECK: Visiting op 'op0' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op1' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op2' +// CHECK: Visiting op 'std.return' + +// CHECK-LABEL: Block pre-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0' + +// CHECK-LABEL: Region pre-order visits +// CHECK: Visiting region 0 from operation 'module' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'regionOp0' + +// CHECK-LABEL: Op post-order visits +// CHECK: Visiting op 'op0' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op1' +// CHECK: Visiting op 'std.br' +// CHECK: Visiting op 'op2' +// CHECK: Visiting op 'regionOp0' +// CHECK: Visiting op 'std.return' +// CHECK: Visiting op 'func' +// CHECK: Visiting op 'module' + +// CHECK-LABEL: Block post-order visits +// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0' +// CHECK: Visiting block ^bb0 from region 0 from operation 'func' +// CHECK: Visiting block ^bb0 from region 0 from operation 'module' + +// CHECK-LABEL: Region post-order visits +// CHECK: Visiting region 0 from operation 'regionOp0' +// CHECK: Visiting region 0 from operation 'func' +// CHECK: Visiting region 0 from operation 'module' + +// CHECK-LABEL: Op pre-order erasures (skip) +// CHECK: Erasing op 'regionOp0' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block
pre-order erasures (skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' + +// CHECK-LABEL: Op post-order erasures (skip) +// CHECK: Erasing op 'op0' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op1' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op2' +// CHECK: Erasing op 'regionOp0' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block post-order erasures (skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' + +// CHECK-LABEL: Op post-order erasures (no skip) +// CHECK: Erasing op 'op0' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op1' +// CHECK: Erasing op 'std.br' +// CHECK: Erasing op 'op2' +// CHECK: Erasing op 'regionOp0' +// CHECK: Erasing op 'std.return' + +// CHECK-LABEL: Block post-order erasures (no skip) +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0' +// CHECK: Erasing block ^bb0 from region 0 from operation 'func' +// CHECK: Erasing block ^bb0 from region 0 from operation 'module' diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -9,6 +9,7 @@ TestSlicing.cpp TestSymbolUses.cpp TestTypes.cpp + TestVisitors.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestVisitors.cpp @@ -0,0 +1,169 @@ +//===- TestVisitors.cpp - Pass to test the IR visitors --------------------===// +// +// Part of the LLVM Project, under the
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +static void printRegion(Region *region) { + llvm::outs() << "region " << region->getRegionNumber() << " from operation '" + << region->getParentOp()->getName() << "'"; +} + +static void printBlock(Block *block) { + llvm::outs() << "block "; + block->printAsOperand(llvm::outs(), /*printType=*/false); + llvm::outs() << " from "; + printRegion(block->getParent()); +} + +static void printOperation(Operation *op) { + llvm::outs() << "op '" << op->getName() << "'"; +} + +/// Tests callbacks that only print the visited ops, blocks and regions. +static void test_pure_callbacks(Operation *op) { + auto op_pure = [](Operation *op) { + llvm::outs() << "Visiting "; + printOperation(op); + llvm::outs() << "\n"; + }; + auto block_pure = [](Block *block) { + llvm::outs() << "Visiting "; + printBlock(block); + llvm::outs() << "\n"; + }; + auto region_pure = [](Region *region) { + llvm::outs() << "Visiting "; + printRegion(region); + llvm::outs() << "\n"; + }; + + llvm::outs() << "Op pre-order visits" + << "\n"; + op->walk(op_pure); + llvm::outs() << "Block pre-order visits" + << "\n"; + op->walk(block_pure); + llvm::outs() << "Region pre-order visits" + << "\n"; + op->walk(region_pure); + + llvm::outs() << "Op post-order visits" + << "\n"; + op->walk(op_pure); + llvm::outs() << "Block post-order visits" + << "\n"; + op->walk(block_pure); + llvm::outs() << "Region post-order visits" + << "\n"; + op->walk(region_pure); +} + +/// Tests erasure callbacks that skip the walk. +static void test_skip_erasure_callbacks(Operation *op) { + auto skip_op_erasure = [](Operation *op) { + // Do not erase module and function op. Otherwise there wouldn't be too + // much to test in pre-order.
+ if (isa(op) || isa(op)) + return WalkResult::advance(); + + llvm::outs() << "Erasing "; + printOperation(op); + llvm::outs() << "\n"; + op->dropAllUses(); + op->erase(); + return WalkResult::skip(); + }; + auto skip_block_erasure = [](Block *block) { + // Do not erase module and function blocks. Otherwise there wouldn't be + // too much to test in pre-order. + Operation *parentOp = block->getParentOp(); + if (isa(parentOp) || isa(parentOp)) + return WalkResult::advance(); + + llvm::outs() << "Erasing "; + printBlock(block); + llvm::outs() << "\n"; + block->erase(); + return WalkResult::skip(); + }; + + llvm::outs() << "Op pre-order erasures (skip)" + << "\n"; + Operation *cloned = op->clone(); + cloned->walk(skip_op_erasure); + cloned->erase(); + + llvm::outs() << "Block pre-order erasures (skip)" + << "\n"; + cloned = op->clone(); + cloned->walk(skip_block_erasure); + cloned->erase(); + + llvm::outs() << "Op post-order erasures (skip)" + << "\n"; + cloned = op->clone(); + cloned->walk(skip_op_erasure); + cloned->erase(); + + llvm::outs() << "Block post-order erasures (skip)" + << "\n"; + cloned = op->clone(); + cloned->walk(skip_block_erasure); + cloned->erase(); +} + +/// Tests erasure callbacks that don't skip the walk. These callbacks are only +/// valid in post-order. The op walk also erases the cloned root op itself.
+static void test_noskip_erasure_callbacks(Operation *op) { + auto noskip_op_erasure = [](Operation *op) { + llvm::outs() << "Erasing "; + printOperation(op); + llvm::outs() << "\n"; + op->dropAllUses(); + op->erase(); + }; + auto noskip_block_erasure = [](Block *block) { + llvm::outs() << "Erasing "; + printBlock(block); + llvm::outs() << "\n"; + block->erase(); + }; + + llvm::outs() << "Op post-order erasures (no skip)\n"; + Operation *cloned = op->clone(); + cloned->walk(noskip_op_erasure); + + llvm::outs() << "Block post-order erasures (no skip)\n"; + cloned = op->clone(); + cloned->walk(noskip_block_erasure); + // Only nested blocks were erased; the cloned root op itself must be freed. + cloned->erase(); +} + +namespace { +/// This pass exercises the different configurations of the IR visitors. +struct TestIRVisitorsPass + : public PassWrapper> { + void runOnOperation() override { + Operation *op = getOperation(); + + test_pure_callbacks(op); + test_skip_erasure_callbacks(op); + test_noskip_erasure_callbacks(op); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestIRVisitorsPass() { + PassRegistration("test-ir-visitors", + "Test various visitors."); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -41,6 +41,7 @@ void registerTestAllReduceLoweringPass(); void registerTestFunc(); void registerTestGpuMemoryPromotionPass(); +void registerTestIRVisitorsPass(); void registerTestLoopPermutationPass(); void registerTestMatchers(); void registerTestPrintDefUsePass(); @@ -113,6 +114,7 @@ registerTestAllReduceLoweringPass(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); + registerTestIRVisitorsPass(); registerTestLoopPermutationPass(); registerTestMatchers(); registerTestPrintDefUsePass();