diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1827,6 +1827,49 @@
   friend RegisteredOperationName;
 };
 
+/// A helper class that lets unregistered operations be used with MLIR
+/// utilities in the same way as registered op classes. For example, when
+/// traversing IR, unregistered operations can be filtered by name with
+///   op->walk([](UnregisteredOp<opName> op) { ... });
+///
+/// instead of
+///   op->walk([](Operation *op) {
+///     if (op->getName() ...) ...;
+///   });
+///
+/// A string literal cannot be used as a template argument, so a form like
+/// `UnregisteredOp<"FooOp">` is ill-formed. Instead, declare the operation
+/// name as a named character array with static storage duration. Example:
+///   namespace {
+///   const char opName[] = "FooOp";
+///   } // namespace
+///
+///   ...
+///   // The use of UnregisteredOp<opName> is valid.
+///   op->walk([](UnregisteredOp<opName> op) { ... });
+///
+/// Registered operations never match, even if their name equals `name`.
+/// Matching compares the operation name string, so the same instantiation
+/// works with operations from different MLIRContexts.
+///
+/// TODO: C++20 allows class types as non-type template parameters, which
+/// could make it possible to pass the name as a string literal directly.
+template <const char *name>
+class UnregisteredOp : public Op<UnregisteredOp<name>> {
+public:
+  using Op<UnregisteredOp<name>>::Op;
+
+  /// Returns true if `op` is an unregistered operation named `name`.
+  static bool classof(Operation *op) {
+    if (op->getRegisteredInfo())
+      return false;
+    // Compare the name string instead of caching an `OperationName`: an
+    // `OperationName` is uniqued per MLIRContext, so a function-local static
+    // would only be valid for the context of the first operation queried.
+    return op->getName().getStringRef() == name;
+  }
+};
+
 /// This class represents the base of an operation interface. See the definition
 /// of `detail::Interface` for requirements on the `Traits` type.
 template <typename ConcreteType, typename Traits>
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -18,10 +18,10 @@
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
+class Block;
 class Diagnostic;
 class InFlightDiagnostic;
 class Operation;
-class Block;
 class Region;
 
 /// A utility result that is used to signal how to proceed with an ongoing walk:
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -18,10 +18,13 @@
     }
     "use3"(%i) : (index) -> ()
   }
+  "dummy"(%c0) ({
+    "dummy"() : () -> ()
+  }) : (index) -> ()
   return
 }
 
-// CHECK-LABEL: Op pre-order visit
+// CHECK-LABEL: Op pre-order visits
 // CHECK: Visiting op 'builtin.module'
 // CHECK: Visiting op 'builtin.func'
 // CHECK: Visiting op 'scf.for'
@@ -32,6 +35,10 @@
 // CHECK: Visiting op 'use3'
 // CHECK: Visiting op 'std.return'
 
+// CHECK-LABEL: Op name pre-order visits
+// CHECK: Visiting op 'dummy' with 1 operand
+// CHECK: Visiting op 'dummy' with 0 operand
+
 // CHECK-LABEL: Block pre-order visits
 // CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.module'
 // CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.func'
@@ -57,6 +64,10 @@
 // CHECK: Visiting op 'builtin.func'
 // CHECK: Visiting op 'builtin.module'
 
+// CHECK-LABEL: Op name post-order visits
+// CHECK: Visiting op 'dummy' with 0 operand
+// CHECK: Visiting op 'dummy' with 1 operand
+
 // CHECK-LABEL: Block post-order visits
 // CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
 // CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
@@ -75,6 +86,10 @@
 // CHECK: Erasing op 'scf.for'
 // CHECK: Erasing op 'std.return'
 
+// CHECK-LABEL: Op name pre-order erasures
+// CHECK: Erasing op 'dummy' with 1 operand
+// CHECK-NOT: Erasing op 'dummy' with 0 operand
+
 // CHECK-LABEL: Block pre-order erasures
 // CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'
 
@@ -87,6 +102,10 @@
 // CHECK: Erasing op 'scf.for'
 // CHECK: Erasing op 'std.return'
 
+// CHECK-LABEL: Op name post-order erasures (skip)
+// CHECK: Erasing op 'dummy' with 0 operand
+// CHECK: Erasing op 'dummy' with 1 operand
+
 // CHECK-LABEL: Block post-order erasures (skip)
 // CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
 // CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
@@ -103,6 +122,10 @@
 // CHECK: Erasing op 'builtin.func'
 // CHECK: Erasing op 'builtin.module'
 
+// CHECK-LABEL: Op name post-order erasures (no skip)
+// CHECK: Erasing op 'dummy' with 0 operand
+// CHECK: Erasing op 'dummy' with 1 operand
+
 // CHECK-LABEL: Block post-order erasures (no skip)
 // CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
 // CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -10,6 +10,8 @@
 
 using namespace mlir;
 
+static const char dummyOpName[] = "dummy";
+
 static void printRegion(Region *region) {
   llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
                << region->getParentOp()->getName() << "'";
@@ -33,6 +35,11 @@
     printOperation(op);
     llvm::outs() << "\n";
   };
+  auto opName = [](UnregisteredOp<dummyOpName> op) {
+    llvm::outs() << "Visiting ";
+    printOperation(op);
+    llvm::outs() << " with " << op->getNumOperands() << " operand\n";
+  };
   auto blockPure = [](Block *block) {
     llvm::outs() << "Visiting ";
     printBlock(block);
@@ -47,6 +54,9 @@
   llvm::outs() << "Op pre-order visits"
                << "\n";
   op->walk<WalkOrder::PreOrder>(opPure);
+  llvm::outs() << "Op name pre-order visits"
+               << "\n";
+  op->walk<WalkOrder::PreOrder>(opName);
   llvm::outs() << "Block pre-order visits"
                << "\n";
   op->walk<WalkOrder::PreOrder>(blockPure);
@@ -57,6 +67,9 @@
   llvm::outs() << "Op post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder>(opPure);
+  llvm::outs() << "Op name post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder>(opName);
   llvm::outs() << "Block post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder>(blockPure);
@@ -80,6 +93,14 @@
     op->erase();
     return WalkResult::skip();
   };
+  auto skipOpNameErasure = [](UnregisteredOp<dummyOpName> op) {
+    llvm::outs() << "Erasing ";
+    printOperation(op);
+    llvm::outs() << " with " << op->getNumOperands() << " operand\n";
+    op->dropAllUses();
+    op->erase();
+    return WalkResult::skip();
+  };
   auto skipBlockErasure = [](Block *block) {
     // Do not erase module and function blocks. Otherwise there wouldn't be
     // too much to test in pre-order.
@@ -100,6 +121,12 @@
   cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
   cloned->erase();
 
+  llvm::outs() << "Op name pre-order erasures (skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PreOrder>(skipOpNameErasure);
+  cloned->erase();
+
   llvm::outs() << "Block pre-order erasures (skip)"
                << "\n";
   cloned = op->clone();
@@ -112,6 +139,12 @@
   cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
   cloned->erase();
 
+  llvm::outs() << "Op name post-order erasures (skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PostOrder>(skipOpNameErasure);
+  cloned->erase();
+
   llvm::outs() << "Block post-order erasures (skip)"
                << "\n";
   cloned = op->clone();
@@ -129,6 +162,13 @@
     op->dropAllUses();
     op->erase();
   };
+  auto noSkipOpNameErasure = [](UnregisteredOp<dummyOpName> op) {
+    llvm::outs() << "Erasing ";
+    printOperation(op);
+    llvm::outs() << " with " << op->getNumOperands() << " operand\n";
+    op->dropAllUses();
+    op->erase();
+  };
   auto noSkipBlockErasure = [](Block *block) {
     llvm::outs() << "Erasing ";
     printBlock(block);
@@ -141,6 +181,14 @@
   Operation *cloned = op->clone();
   cloned->walk(noSkipOpErasure);
 
+  llvm::outs() << "Op name post-order erasures (no skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk(noSkipOpNameErasure);
+  // Unlike noSkipOpErasure, this callback only erases the "dummy" ops, so the
+  // cloned root op survives the walk and has to be erased explicitly.
+  cloned->erase();
+
   llvm::outs() << "Block post-order erasures (no skip)"
                << "\n";
   cloned = op->clone();