diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -362,6 +362,74 @@
   emitPassDefs(pass, os);
 }
 
+// TODO: Drop old pass declarations once all passes use the new decls/defs.
+// Legacy pass base class template, expanded with llvm::formatv: {0} pass def
+// name, {1} base class, {2} argument, {3} summary, {4} dialect registrations.
+const char *const oldPassDeclBegin = R"(
+template <typename DerivedT>
+class {0}Base : public {1} {
+public:
+  using Base = {0}Base;
+
+  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+  {0}Base(const {0}Base &other) : {1}(other) {{}
+
+  /// Returns the command-line argument attached to this pass.
+  static constexpr ::llvm::StringLiteral getArgumentName() {
+    return ::llvm::StringLiteral("{2}");
+  }
+  ::llvm::StringRef getArgument() const override { return "{2}"; }
+
+  ::llvm::StringRef getDescription() const override { return "{3}"; }
+
+  /// Returns the derived pass name.
+  static constexpr ::llvm::StringLiteral getPassName() {
+    return ::llvm::StringLiteral("{0}");
+  }
+  ::llvm::StringRef getName() const override { return "{0}"; }
+
+  /// Support isa/dyn_cast functionality for the derived pass class.
+  static bool classof(const ::mlir::Pass *pass) {{
+    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
+  }
+
+  /// A clone method to create a copy of this pass.
+  std::unique_ptr<::mlir::Pass> clonePass() const override {{
+    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
+  }
+
+  /// Return the dialect that must be loaded in the context before this pass.
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    {4}
+  }
+
+  /// Explicitly declare the TypeID for this class. We declare an explicit private
+  /// instantiation because Pass classes should only be visible by the current
+  /// library.
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
+
+protected:
+)";
+
+// TODO: Drop old pass declarations.
+/// Emit a backward-compatible declaration of the pass base class.
+static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
+  StringRef defName = pass.getDef()->getName();
+  std::string dependentDialectRegistrations;
+  { // Scoped: dialectsOs is flushed and destroyed before the string is used.
+    llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
+    for (StringRef dependentDialect : pass.getDependentDialects())
+      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+                                  dependentDialect);
+  }
+  os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
+                      pass.getArgument(), pass.getSummary(),
+                      dependentDialectRegistrations);
+  emitPassOptionDecls(pass, os);
+  emitPassStatisticDecls(pass, os);
+  os << "};\n";
+}
+
 static void emitPasses(const llvm::RecordKeeper &recordKeeper,
                        raw_ostream &os) {
   std::vector<Pass> passes = getPasses(recordKeeper);
@@ -371,6 +439,15 @@
     emitPass(pass, os);
 
   emitRegistrations(passes, os);
+
+  // TODO: Drop old pass declarations.
+  // Emit the old code until all the passes have switched to the new design.
+  os << "// Deprecated. Please use the new per-pass macros.\n";
+  os << "#ifdef GEN_PASS_CLASSES\n";
+  for (const Pass &pass : passes)
+    emitOldPassDecl(pass, os);
+  os << "#undef GEN_PASS_CLASSES\n";
+  os << "#endif // GEN_PASS_CLASSES\n";
 }
 
 static mlir::GenRegistration