diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -514,6 +514,12 @@ return detail::walk(this, std::forward(callback)); } + template > + RetT walk(StringRef name, FnT &&callback) { + return detail::walk(this, name, std::forward(callback)); + } + //===--------------------------------------------------------------------===// // Uses //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h --- a/mlir/include/mlir/IR/Visitors.h +++ b/mlir/include/mlir/IR/Visitors.h @@ -16,12 +16,13 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" namespace mlir { +class Block; class Diagnostic; class InFlightDiagnostic; class Operation; -class Block; class Region; /// A utility result that is used to signal how to proceed with an ongoing walk: @@ -88,6 +89,8 @@ void walk(Operation *op, function_ref callback, WalkOrder order); void walk(Operation *op, function_ref callback, WalkOrder order); +void walk(Operation *op, StringRef name, + function_ref callback, WalkOrder order); /// Walk all of the regions, blocks, or operations nested under (and including) /// the given operation. Regions, blocks and operations at the same nesting /// level are visited in lexicographical order. The walk order for enclosing @@ -103,6 +106,9 @@ WalkOrder order); WalkResult walk(Operation *op, function_ref callback, WalkOrder order); +WalkResult walk(Operation *op, StringRef name, + function_ref callback, + WalkOrder order); // Below are a set of functions to walk nested operations. 
Users should favor // the direct `walk` methods on the IR classes(Operation/Block/etc) over these @@ -134,6 +140,20 @@ return detail::walk(op, function_ref(callback), Order); } +/// This method walks only the operations with the specified name. +/// It's useful for visiting unregistered operations. +/// +/// Example: +/// op->walk("test.foo", [](Operation *op) { ... }); +template < + WalkOrder Order = WalkOrder::PostOrder, typename FuncTy, + typename ArgT = detail::first_argument, + typename RetT = decltype(std::declval()(std::declval()))> +typename std::enable_if::value, RetT>::type +walk(Operation *op, StringRef name, FuncTy &&callback) { + return detail::walk(op, name, function_ref(callback), Order); +} + /// Walk all of the operations of type 'ArgT' nested under and including the /// given operation. Regions, blocks and operations at the same nesting /// level are visited in lexicographical order. The walk order for enclosing diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp --- a/mlir/lib/IR/Visitors.cpp +++ b/mlir/lib/IR/Visitors.cpp @@ -67,6 +67,16 @@ callback(op); } +void detail::walk(Operation *op, StringRef name, + function_ref callback, WalkOrder order) { + OperationName opName(name, op->getContext()); + auto wrapperFn = [&](Operation *op) { + if (op->getName() == opName) + callback(op); + }; + walk(op, wrapperFn, order); +} + /// Walk all of the regions/blocks/operations nested under and including the /// given operation. These functions walk operations until an interrupt result /// is returned by the callback. 
Walks on regions, blocks and operations may @@ -157,3 +167,15 @@ return callback(op); return WalkResult::advance(); } + +WalkResult detail::walk(Operation *op, StringRef name, + function_ref callback, + WalkOrder order) { + OperationName opName(name, op->getContext()); + auto wrapperFn = [&](Operation *op) -> WalkResult { + if (op->getName() == opName) + return callback(op); + return WalkResult::advance(); + }; + return walk(op, function_ref(wrapperFn), order); +} diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir --- a/mlir/test/IR/visitors.mlir +++ b/mlir/test/IR/visitors.mlir @@ -18,10 +18,13 @@ } "use3"(%i) : (index) -> () } + "dummy"(%c0) ({ + "dummy"() : () -> () + }) : (index) -> () return } -// CHECK-LABEL: Op pre-order visit +// CHECK-LABEL: Op pre-order visits // CHECK: Visiting op 'builtin.module' // CHECK: Visiting op 'builtin.func' // CHECK: Visiting op 'scf.for' @@ -32,6 +35,10 @@ // CHECK: Visiting op 'use3' // CHECK: Visiting op 'std.return' +// CHECK-LABEL: Op name pre-order visits +// CHECK: Visiting op 'dummy' with 1 operand +// CHECK: Visiting op 'dummy' with 0 operand + // CHECK-LABEL: Block pre-order visits // CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.module' // CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.func' @@ -57,6 +64,10 @@ // CHECK: Visiting op 'builtin.func' // CHECK: Visiting op 'builtin.module' +// CHECK-LABEL: Op name post-order visits +// CHECK: Visiting op 'dummy' with 0 operand +// CHECK: Visiting op 'dummy' with 1 operand + // CHECK-LABEL: Block post-order visits // CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if' // CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if' @@ -75,6 +86,10 @@ // CHECK: Erasing op 'scf.for' // CHECK: Erasing op 'std.return' +// CHECK-LABEL: Op name pre-order erasures +// CHECK: Erasing op 'dummy' with 1 operand +// CHECK-NOT: Erasing op 'dummy' with 0 operand + // CHECK-LABEL: Block pre-order erasures // CHECK: 
Erasing block ^bb0 from region 0 from operation 'scf.for' @@ -87,6 +102,10 @@ // CHECK: Erasing op 'scf.for' // CHECK: Erasing op 'std.return' +// CHECK-LABEL: Op name post-order erasures (skip) +// CHECK: Erasing op 'dummy' with 0 operand +// CHECK: Erasing op 'dummy' with 1 operand + // CHECK-LABEL: Block post-order erasures (skip) // CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if' // CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if' @@ -103,6 +122,10 @@ // CHECK: Erasing op 'builtin.func' // CHECK: Erasing op 'builtin.module' +// CHECK-LABEL: Op name post-order erasures (no skip) +// CHECK: Erasing op 'dummy' with 0 operand +// CHECK: Erasing op 'dummy' with 1 operand + // CHECK-LABEL: Block post-order erasures (no skip) // CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if' // CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if' diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp --- a/mlir/test/lib/IR/TestVisitors.cpp +++ b/mlir/test/lib/IR/TestVisitors.cpp @@ -33,6 +33,11 @@ printOperation(op); llvm::outs() << "\n"; }; + auto opName = [](Operation *op) { + llvm::outs() << "Visiting "; + printOperation(op); + llvm::outs() << " with " << op->getNumOperands() << " operand\n"; + }; auto blockPure = [](Block *block) { llvm::outs() << "Visiting "; printBlock(block); @@ -47,6 +52,9 @@ llvm::outs() << "Op pre-order visits" << "\n"; op->walk(opPure); + llvm::outs() << "Op name pre-order visits" + << "\n"; + op->walk("dummy", opName); llvm::outs() << "Block pre-order visits" << "\n"; op->walk(blockPure); @@ -57,6 +65,9 @@ llvm::outs() << "Op post-order visits" << "\n"; op->walk(opPure); + llvm::outs() << "Op name post-order visits" + << "\n"; + op->walk("dummy", opName); llvm::outs() << "Block post-order visits" << "\n"; op->walk(blockPure); @@ -80,6 +91,19 @@ op->erase(); return WalkResult::skip(); }; + auto skipOpNameErasure = [](Operation *op) { + // Do not erase module and function 
op. Otherwise there wouldn't be too + // much to test in pre-order. + if (isa(op) || isa(op)) + return WalkResult::advance(); + + llvm::outs() << "Erasing "; + printOperation(op); + llvm::outs() << " with " << op->getNumOperands() << " operand\n"; + op->dropAllUses(); + op->erase(); + return WalkResult::skip(); + }; auto skipBlockErasure = [](Block *block) { // Do not erase module and function blocks. Otherwise there wouldn't be // too much to test in pre-order. @@ -100,6 +124,12 @@ cloned->walk(skipOpErasure); cloned->erase(); + llvm::outs() << "Op name pre-order erasures (skip)" + << "\n"; + cloned = op->clone(); + cloned->walk("dummy", skipOpNameErasure); + cloned->erase(); + llvm::outs() << "Block pre-order erasures (skip)" << "\n"; cloned = op->clone(); @@ -112,6 +142,12 @@ cloned->walk(skipOpErasure); cloned->erase(); + llvm::outs() << "Op name post-order erasures (skip)" + << "\n"; + cloned = op->clone(); + cloned->walk("dummy", skipOpNameErasure); + cloned->erase(); + llvm::outs() << "Block post-order erasures (skip)" << "\n"; cloned = op->clone(); @@ -129,6 +165,13 @@ op->dropAllUses(); op->erase(); }; + auto noSkipOpNameErasure = [](Operation *op) { + llvm::outs() << "Erasing "; + printOperation(op); + llvm::outs() << " with " << op->getNumOperands() << " operand\n"; + op->dropAllUses(); + op->erase(); + }; auto noSkipBlockErasure = [](Block *block) { llvm::outs() << "Erasing "; printBlock(block); @@ -141,6 +184,11 @@ Operation *cloned = op->clone(); cloned->walk(noSkipOpErasure); + llvm::outs() << "Op name post-order erasures (no skip)" + << "\n"; + cloned = op->clone(); + cloned->walk("dummy", noSkipOpNameErasure); + llvm::outs() << "Block post-order erasures (no skip)" << "\n"; cloned = op->clone();