diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -157,10 +157,40 @@
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
-  RetT walk(FnT &&callback) {
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
+  walk(FnT &&callback) {
     return state->walk<Order>(std::forward<FnT>(callback));
   }
 
+  /// Generic walker with a stage aware callback. Walk the operation by calling
+  /// the callback for each nested operation (including this one) N+1 times,
+  /// where N is the number of regions attached to that operation.
+  ///
+  /// The callback method can take any of the following forms:
+  ///   void(Operation *, const WalkStage &) : Walk all operations opaquely
+  ///     * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
+  ///   void(OpT, const WalkStage &) : Walk all operations of the given derived
+  ///                                  type.
+  ///     * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
+  ///   WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
+  ///          but allow for interruption/skipping.
+  ///     * op.walk([](... op, const WalkStage &stage) {
+  ///         // Skip the walk of this op based on some invariant.
+  ///         if (some_invariant)
+  ///           return WalkResult::skip();
+  ///         // Interrupt, i.e., cancel, the walk based on some invariant.
+  ///         if (another_invariant)
+  ///           return WalkResult::interrupt();
+  ///         return WalkResult::advance();
+  ///       });
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
+  walk(FnT &&callback) {
+    return state->walk(std::forward<FnT>(callback));
+  }
+
   // These are default implementations of customization hooks.
 public:
   /// This hook returns any canonicalization pattern rewrites that the operation