[mlir] Let a pass be registered from the pass instance itself

Add `Pass::getDescription()` next to `Pass::getArgument()`, and a
`registerPass(const PassAllocatorFunction &)` overload. The overload creates
one instance of the pass and registers it under that instance's argument and
description. Registering a pass whose `getArgument()` is empty is a fatal error,
and the message names the offending pass.

`PassRegistration<T>` gains a default constructor, plus a constructor that
takes only the allocator, both built on the new overload. The older forms that
take an argument and description are kept but documented as deprecated.

mlir-tblgen now emits `getDescription()` from the pass summary and registers
generated passes through the new overload. The `print-op-graph` entry is
dropped from Transforms/Passes.td, and the tests now spell the pass
`view-op-graph`.

diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -73,10 +73,14 @@
   /// register the Affine dialect but does not need to register Linalg.
   virtual void getDependentDialects(DialectRegistry &registry) const {}
 
-  /// Returns the command line argument used when registering this pass. Return
+  /// Return the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
   virtual StringRef getArgument() const { return ""; }
 
+  /// Return the command line description used when registering this pass.
+  /// Return an empty string if one does not exist.
+  virtual StringRef getDescription() const { return ""; }
+
   /// Returns the name of the operation that this pass operates on, or None if
   /// this is a generic OperationPass.
   Optional<StringRef> getOpName() const { return opName; }
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -125,20 +125,35 @@
 /// Register a specific dialect pass allocator function with the system,
 /// typically used through the PassRegistration template.
+/// Deprecated: please use the alternate version below.
 void registerPass(StringRef arg, StringRef description,
                   const PassAllocatorFunction &function);
 
+/// Register a specific dialect pass allocator function with the system,
+/// typically used through the PassRegistration template. The pass argument
+/// and description are read from an instance created by `function`; the pass
+/// must override `getArgument()`.
+void registerPass(const PassAllocatorFunction &function);
+
 /// PassRegistration provides a global initializer that registers a Pass
-/// allocation routine for a concrete pass instance. The third argument is
+/// allocation routine for a concrete pass instance. The allocator argument is
 /// optional and provides a callback to construct a pass that does not have
 /// a default constructor.
 ///
 /// Usage:
 ///
 ///   /// At namespace scope.
-///   static PassRegistration<MyPass> reg("my-pass", "My Pass Description.");
+///   static PassRegistration<MyPass> reg;
 ///
 template <typename ConcretePass> struct PassRegistration {
+  PassRegistration(const PassAllocatorFunction &constructor) {
+    registerPass(constructor);
+  }
+  PassRegistration()
+      : PassRegistration([] { return std::make_unique<ConcretePass>(); }) {}
+
+  /// Constructors below are deprecated.
+
 
   PassRegistration(StringRef arg, StringRef description,
                    const PassAllocatorFunction &constructor) {
     registerPass(arg, description, constructor);
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -622,11 +622,6 @@
   let constructor = "mlir::createPrintOpStatsPass()";
 }
 
-def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
-  let summary = "Print op graph per-Region";
-  let constructor = "mlir::createPrintOpGraphPass()";
-}
-
 def SCCP : Pass<"sccp"> {
   let summary = "Sparse Conditional Constant Propagation";
   let description = [{
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -122,6 +122,16 @@
   }
 }
 
+void mlir::registerPass(const PassAllocatorFunction &function) {
+  std::unique_ptr<Pass> pass = function();
+  StringRef arg = pass->getArgument();
+  if (arg.empty())
+    llvm::report_fatal_error(llvm::Twine("Trying to register '") +
+                             pass->getName() +
+                             "' pass that does not override `getArgument()`");
+  registerPass(arg, pass->getDescription(), function);
+}
+
 /// Returns the pass info for the specified pass argument or null if unknown.
 const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
   auto it = passRegistry->find(passArg);
diff --git a/mlir/test/Transforms/print-op-graph.mlir b/mlir/test/Transforms/print-op-graph.mlir
--- a/mlir/test/Transforms/print-op-graph.mlir
+++ b/mlir/test/Transforms/print-op-graph.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -print-op-graph %s -o %t 2>&1 | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-elide-elementsattrs-if-larger=2 -view-op-graph %s -o %t 2>&1 | FileCheck %s
 
 // CHECK-LABEL: digraph "merge_blocks"
 // CHECK{LITERAL}: value: [[...]] : tensor\<2x2xi32\>}
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -71,10 +71,10 @@
 def testInvalidNesting():
   with Context():
     try:
-      pm = PassManager.parse("func(print-op-graph)")
+      pm = PassManager.parse("func(view-op-graph)")
     except ValueError as e:
       # CHECK: Can't add pass 'ViewOpGraphPass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
-      # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
+      # CHECK: ValueError exception: invalid pass pipeline 'func(view-op-graph)'.
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -56,6 +56,8 @@
   }
   ::llvm::StringRef getArgument() const override { return "{2}"; }
 
+  ::llvm::StringRef getDescription() const override { return "{3}"; }
+
   /// Returns the derived pass name.
   static constexpr ::llvm::StringLiteral getPassName() {
     return ::llvm::StringLiteral("{0}");
@@ -74,7 +76,7 @@
 
   /// Return the dialect that must be loaded in the context before this pass.
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
-    {3}
+    {4}
   }
 
 protected:
@@ -122,7 +124,8 @@
                                   dependentDialect);
   }
   os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
-                      pass.getArgument(), dependentDialectRegistrations);
+                      pass.getArgument(), pass.getSummary(),
+                      dependentDialectRegistrations);
   emitPassOptionDecls(pass, os);
   emitPassStatisticDecls(pass, os);
   os << "};\n";
@@ -154,8 +157,8 @@
 //===----------------------------------------------------------------------===//
 
 inline void register{0}Pass() {{
-  ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
-    return {3};
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
+    return {1};
   });
 }
 )";
@@ -175,7 +178,6 @@
   os << "#ifdef GEN_PASS_REGISTRATION\n";
   for (const Pass &pass : passes) {
     os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
-                        pass.getArgument(), pass.getSummary(),
                         pass.getConstructor());
   }
 