Index: mlir/examples/standalone/CMakeLists.txt
===================================================================
--- mlir/examples/standalone/CMakeLists.txt
+++ mlir/examples/standalone/CMakeLists.txt
@@ -52,4 +52,5 @@
 endif()
 add_subdirectory(test)
 add_subdirectory(standalone-opt)
+add_subdirectory(standalone-plugin)
 add_subdirectory(standalone-translate)
Index: mlir/examples/standalone/include/Standalone/CMakeLists.txt
===================================================================
--- mlir/examples/standalone/include/Standalone/CMakeLists.txt
+++ mlir/examples/standalone/include/Standalone/CMakeLists.txt
@@ -1,3 +1,7 @@
 add_mlir_dialect(StandaloneOps standalone)
 add_mlir_doc(StandaloneDialect StandaloneDialect Standalone/ -gen-dialect-doc)
 add_mlir_doc(StandaloneOps StandaloneOps Standalone/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS StandalonePasses.td)
+mlir_tablegen(StandalonePasses.h.inc --gen-pass-decls)
+add_public_tablegen_target(MLIRStandalonePassesIncGen)
Index: mlir/examples/standalone/include/Standalone/StandalonePasses.h
===================================================================
--- /dev/null
+++ mlir/examples/standalone/include/Standalone/StandalonePasses.h
@@ -0,0 +1,26 @@
+//===- StandalonePasses.h - Standalone passes ------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef STANDALONE_STANDALONEPASSES_H
+#define STANDALONE_STANDALONEPASSES_H
+
+#include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandaloneOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace standalone {
+#define GEN_PASS_DECL
+#include "Standalone/StandalonePasses.h.inc"
+
+#define GEN_PASS_REGISTRATION
+#include "Standalone/StandalonePasses.h.inc"
+} // namespace standalone
+} // namespace mlir
+
+#endif // STANDALONE_STANDALONEPASSES_H
Index: mlir/examples/standalone/include/Standalone/StandalonePasses.td
===================================================================
--- /dev/null
+++ mlir/examples/standalone/include/Standalone/StandalonePasses.td
@@ -0,0 +1,30 @@
+//===- StandalonePasses.td - Standalone dialect passes -----*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef STANDALONE_PASS
+#define STANDALONE_PASS
+
+include "mlir/Pass/PassBase.td"
+
+def StandaloneSwitchBarFoo: Pass<"standalone-switch-bar-foo", "::mlir::ModuleOp"> {
+  let summary = "Switches the name of a FuncOp named `bar` to `foo` and folds.";
+  let description = [{
+    Switches the name of a FuncOp named `bar` to `foo` and folds.
+    ```
+    func.func @bar() {
+      return
+    }
+    // Gets transformed to:
+    func.func @foo() {
+      return
+    }
+    ```
+  }];
+}
+
+#endif // STANDALONE_PASS
Index: mlir/examples/standalone/lib/Standalone/CMakeLists.txt
===================================================================
--- mlir/examples/standalone/lib/Standalone/CMakeLists.txt
+++ mlir/examples/standalone/lib/Standalone/CMakeLists.txt
@@ -2,14 +2,17 @@
         StandaloneTypes.cpp
         StandaloneDialect.cpp
         StandaloneOps.cpp
+        StandalonePasses.cpp
 
         ADDITIONAL_HEADER_DIRS
         ${PROJECT_SOURCE_DIR}/include/Standalone
 
         DEPENDS
         MLIRStandaloneOpsIncGen
+        MLIRStandalonePassesIncGen
 
         LINK_LIBS PUBLIC
         MLIRIR
         MLIRInferTypeOpInterface
+        MLIRFuncDialect
         )
Index: mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
===================================================================
--- /dev/null
+++ mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
@@ -0,0 +1,49 @@
+//===- StandalonePasses.cpp - Standalone passes -----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "Standalone/StandalonePasses.h"
+
+namespace mlir {
+namespace standalone {
+#define GEN_PASS_DEF_STANDALONESWITCHBARFOO
+#include "Standalone/StandalonePasses.h.inc"
+
+class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
+public:
+  using OpRewritePattern<func::FuncOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::FuncOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.getSymName() == "bar") {
+      rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
+      return success();
+    }
+    return failure();
+  }
+};
+
+class StandaloneSwitchBarFoo
+    : public impl::StandaloneSwitchBarFooBase<StandaloneSwitchBarFoo> {
+public:
+  using impl::StandaloneSwitchBarFooBase<
+      StandaloneSwitchBarFoo>::StandaloneSwitchBarFooBase;
+  void runOnOperation() final {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+    // Report non-convergence as a pass failure; an assert vanishes in release.
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+      signalPassFailure();
+  }
+};
+} // namespace standalone
+} // namespace mlir
Index: mlir/examples/standalone/standalone-opt/standalone-opt.cpp
===================================================================
--- mlir/examples/standalone/standalone-opt/standalone-opt.cpp
+++ mlir/examples/standalone/standalone-opt/standalone-opt.cpp
@@ -21,9 +21,10 @@
 #include "llvm/Support/ToolOutputFile.h"
 
 #include "Standalone/StandaloneDialect.h"
+#include "Standalone/StandalonePasses.h"
 
 int main(int argc, char **argv) {
   mlir::registerAllPasses();
-  // TODO: Register standalone passes here.
+  mlir::standalone::registerPasses();
mlir::DialectRegistry registry; Index: mlir/examples/standalone/standalone-plugin/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/examples/standalone/standalone-plugin/CMakeLists.txt @@ -0,0 +1,21 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +set(LIBS + MLIRIR + MLIRPass + MLIRStandalone + MLIRTransformUtils + ) + +add_mlir_dialect_library(StandalonePlugin + SHARED + standalone-plugin.cpp + + DEPENDS + MLIRStandalone + ) + +llvm_update_compile_flags(StandalonePlugin) +target_link_libraries(StandalonePlugin PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(StandalonePlugin) Index: mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp =================================================================== --- /dev/null +++ mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp @@ -0,0 +1,47 @@ +//===- standalone-plugin.cpp ------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectPlugin.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/Pass.h" + +#include "Standalone/StandaloneDialect.h" +#include "Standalone/StandalonePasses.h" + +using namespace mlir; + +/// Dialect plugin registration mechanism. +/// Observe that it also allows to register passes. 
+DialectPluginLibraryInfo getStandaloneDialect() {
+  return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING,
+          [](DialectRegistry &registry) {
+            registry.insert<mlir::standalone::StandaloneDialect>();
+            mlir::standalone::registerPasses();
+          }};
+}
+
+/// Necessary symbol to register the dialect plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo
+mlirGetDialectPluginInfo() {
+  return getStandaloneDialect();
+}
+
+/// Pass plugin registration mechanism.
+PassPluginLibraryInfo getStandalonePasses() {
+  return {MLIR_PLUGIN_API_VERSION, "StandalonePasses", LLVM_VERSION_STRING,
+          []() { mlir::standalone::registerPasses(); }};
+}
+
+/// Necessary symbol to register the pass plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() {
+  return getStandalonePasses();
+}
Index: mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir
===================================================================
--- /dev/null
+++ mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --load-pass-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s
+
+module {
+  // CHECK-LABEL: func @foo()
+  func.func @bar() {
+    return
+  }
+
+  // CHECK-LABEL: func @abar()
+  func.func @abar() {
+    return
+  }
+}
Index: mlir/examples/standalone/test/Standalone/standalone-plugin.mlir
===================================================================
--- /dev/null
+++ mlir/examples/standalone/test/Standalone/standalone-plugin.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --load-dialect-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s
+
+module {
+  // CHECK-LABEL: func @foo()
+  func.func @bar() {
+    return
+  }
+
+  // CHECK-LABEL: func @standalone_types(%arg0: !standalone.custom<"10">)
+  func.func @standalone_types(%arg0: !standalone.custom<"10">) {
+    return
+  }
+}
Index: mlir/examples/standalone/test/lit.cfg.py
===================================================================
--- mlir/examples/standalone/test/lit.cfg.py
+++ mlir/examples/standalone/test/lit.cfg.py
@@ -44,12 +44,16 @@
 # test_exec_root: The root path where tests should be run.
 config.test_exec_root = os.path.join(config.standalone_obj_root, 'test')
 config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin')
+config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib')
+
+config.substitutions.append(('%standalone_libs', config.standalone_libs_dir))
 
 # Tweak the PATH to include the tools dir.
 llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
 
 tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir]
 tools = [
+    'mlir-opt',
     'standalone-capi-test',
     'standalone-opt',
     'standalone-translate',
Index: mlir/include/mlir/IR/DialectPlugin.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/IR/DialectPlugin.h
@@ -0,0 +1,106 @@
+//===- mlir/IR/DialectPlugin.h - Public Plugin API -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the public entry point for dialect plugins.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_DIALECTPLUGIN_H
+#define MLIR_IR_DIALECTPLUGIN_H
+
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Pass/PassPlugin.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+#include <cstdint>
+#include <string>
+
+namespace mlir {
+extern "C" {
+/// Information about the plugin required to load its dialects & passes
+///
+/// This struct defines the core interface for dialect plugins and is supposed
+/// to be filled out by plugin implementors. MLIR-side users of a plugin are
+/// expected to use the \c DialectPlugin class below to interface with it.
+struct DialectPluginLibraryInfo {
+  /// The API version understood by this plugin, usually \c
+  /// MLIR_PLUGIN_API_VERSION
+  uint32_t APIVersion;
+  /// A meaningful name of the plugin.
+  const char *PluginName;
+  /// The version of the plugin.
+  const char *PluginVersion;
+
+  /// The callback for registering dialect plugin with a \c DialectRegistry
+  /// instance
+  void (*RegisterDialectRegistryCallbacks)(DialectRegistry &);
+};
+}
+
+/// A loaded dialect plugin.
+///
+/// An instance of this class wraps a loaded dialect plugin and gives access to
+/// its interface defined by the \c DialectPluginLibraryInfo it exposes.
+class DialectPlugin {
+public:
+  /// Attempts to load a dialect plugin from a given file.
+  ///
+  /// \returns Returns an error if either the library cannot be found or loaded,
+  /// there is no public entry point, or the plugin implements the wrong API
+  /// version.
+  static llvm::Expected<DialectPlugin> Load(const std::string &Filename);
+
+  /// Get the filename of the loaded plugin.
+  StringRef getFilename() const { return Filename; }
+
+  /// Get the plugin name
+  StringRef getPluginName() const { return Info.PluginName; }
+
+  /// Get the plugin version
+  StringRef getPluginVersion() const { return Info.PluginVersion; }
+
+  /// Get the plugin API version
+  uint32_t getAPIVersion() const { return Info.APIVersion; }
+
+  /// Invoke the DialectRegistry callback registration
+  void
+  RegisterDialectRegistryCallbacks(DialectRegistry &dialectRegistry) const {
+    Info.RegisterDialectRegistryCallbacks(dialectRegistry);
+  }
+
+private:
+  DialectPlugin(const std::string &Filename,
+                const llvm::sys::DynamicLibrary &Library)
+      : Filename(Filename), Library(Library), Info() {}
+
+  std::string Filename;
+  llvm::sys::DynamicLibrary Library;
+  DialectPluginLibraryInfo Info;
+};
+} // namespace mlir
+
+/// The public entry point for a dialect plugin.
+///
+/// When a plugin is loaded by the driver, it will call this entry point to
+/// obtain information about this plugin and about how to register its dialects.
+/// This function needs to be implemented by the plugin, see the example below:
+///
+/// ```
+/// extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+/// mlirGetDialectPluginInfo() {
+///   return {
+///     MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", [](DialectRegistry &) {...}
+///   };
+/// }
+/// ```
+extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+mlirGetDialectPluginInfo();
+
+#endif /* MLIR_IR_DIALECTPLUGIN_H */
Index: mlir/include/mlir/Pass/PassPlugin.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Pass/PassPlugin.h
@@ -0,0 +1,112 @@
+//===- mlir/Pass/PassPlugin.h - Public Plugin API -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This defines the public entry point for pass plugins.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PASS_PASSPLUGIN_H
+#define MLIR_PASS_PASSPLUGIN_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+#include <cstdint>
+#include <string>
+
+namespace mlir {
+/// \macro MLIR_PLUGIN_API_VERSION
+/// Identifies the API version understood by this plugin.
+///
+/// When a plugin is loaded, the driver will check its supported plugin version
+/// against that of the plugin. A mismatch is an error. The supported version
+/// will be incremented for ABI-breaking changes to the \c PassPluginLibraryInfo
+/// struct, i.e. when callbacks are added, removed, or reordered.
+#define MLIR_PLUGIN_API_VERSION 1
+
+extern "C" {
+/// Information about the plugin required to load its passes
+///
+/// This struct defines the core interface for pass plugins and is supposed to
+/// be filled out by plugin implementors. MLIR-side users of a plugin are
+/// expected to use the \c PassPlugin class below to interface with it.
+struct PassPluginLibraryInfo {
+  /// The API version understood by this plugin, usually \c
+  /// MLIR_PLUGIN_API_VERSION
+  uint32_t APIVersion;
+  /// A meaningful name of the plugin.
+  const char *PluginName;
+  /// The version of the plugin.
+  const char *PluginVersion;
+
+  /// The callback for registering plugin passes.
+  void (*RegisterPassRegistryCallbacks)();
+};
+}
+
+/// A loaded pass plugin.
+///
+/// An instance of this class wraps a loaded pass plugin and gives access to
+/// its interface defined by the \c PassPluginLibraryInfo it exposes.
+class PassPlugin {
+public:
+  /// Attempts to load a pass plugin from a given file.
+  ///
+  /// \returns Returns an error if either the library cannot be found or loaded,
+  /// there is no public entry point, or the plugin implements the wrong API
+  /// version.
+  static llvm::Expected<PassPlugin> Load(const std::string &Filename);
+
+  /// Get the filename of the loaded plugin.
+  StringRef getFilename() const { return Filename; }
+
+  /// Get the plugin name
+  StringRef getPluginName() const { return Info.PluginName; }
+
+  /// Get the plugin version
+  StringRef getPluginVersion() const { return Info.PluginVersion; }
+
+  /// Get the plugin API version
+  uint32_t getAPIVersion() const { return Info.APIVersion; }
+
+  /// Invoke the PassRegistry callback registration
+  void RegisterPassRegistryCallbacks() const {
+    Info.RegisterPassRegistryCallbacks();
+  }
+
+private:
+  PassPlugin(const std::string &Filename,
+             const llvm::sys::DynamicLibrary &Library)
+      : Filename(Filename), Library(Library), Info() {}
+
+  std::string Filename;
+  llvm::sys::DynamicLibrary Library;
+  PassPluginLibraryInfo Info;
+};
+} // namespace mlir
+
+/// The public entry point for a pass plugin.
+///
+/// When a plugin is loaded by the driver, it will call this entry point to
+/// obtain information about this plugin and about how to register its passes.
+/// This function needs to be implemented by the plugin, see the example below:
+///
+/// ```
+/// extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+/// mlirGetPassPluginInfo() {
+///   return {
+///     MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", []() { ... }
+///   };
+/// }
+/// ```
+extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
+mlirGetPassPluginInfo();
+
+#endif /* MLIR_PASS_PASSPLUGIN_H */
Index: mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
===================================================================
--- mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -35,12 +35,12 @@
 /// supported options.
/// The API is fluent, and the options are sorted in alphabetical order below. /// The options can be exposed to the LLVM command line by registering them -/// with `MlirOptMainConfig::registerCLOptions();` and creating a config using -/// `auto config = MlirOptMainConfig::createFromCLOptions();`. +/// with `MlirOptMainConfig::registerCLOptions(DialectRegistry &);` and creating +/// a config using `auto config = MlirOptMainConfig::createFromCLOptions();`. class MlirOptMainConfig { public: /// Register the options as global LLVM command line options. - static void registerCLOptions(); + static void registerCLOptions(DialectRegistry &dialectRegistry); /// Create a new config with the default set from the CL options. static MlirOptMainConfig createFromCLOptions(); Index: mlir/lib/IR/CMakeLists.txt =================================================================== --- mlir/lib/IR/CMakeLists.txt +++ mlir/lib/IR/CMakeLists.txt @@ -14,6 +14,7 @@ BuiltinTypeInterfaces.cpp Diagnostics.cpp Dialect.cpp + DialectPlugin.cpp DialectResourceBlobManager.cpp Dominance.cpp ExtensibleDialect.cpp Index: mlir/lib/IR/DialectPlugin.cpp =================================================================== --- /dev/null +++ mlir/lib/IR/DialectPlugin.cpp @@ -0,0 +1,53 @@ +//===- lib/IR/DialectPlugin.cpp - Load Dialect Plugins --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectPlugin.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+
+using namespace mlir;
+
+llvm::Expected<DialectPlugin> DialectPlugin::Load(const std::string &Filename) {
+  std::string Error;
+  auto Library =
+      llvm::sys::DynamicLibrary::getPermanentLibrary(Filename.c_str(), &Error);
+  if (!Library.isValid())
+    return llvm::make_error<llvm::StringError>(
+        Twine("Could not load library '") + Filename + "': " + Error,
+        llvm::inconvertibleErrorCode());
+
+  DialectPlugin P{Filename, Library};
+
+  // mlirGetDialectPluginInfo should be resolved to the definition from the
+  // plugin we are currently loading.
+  intptr_t getDetailsFn =
+      (intptr_t)Library.getAddressOfSymbol("mlirGetDialectPluginInfo");
+
+  if (!getDetailsFn)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Plugin entry point not found in '") + Filename + "'.",
+        llvm::inconvertibleErrorCode());
+
+  P.Info =
+      reinterpret_cast<DialectPluginLibraryInfo (*)()>(getDetailsFn)();
+
+  if (P.Info.APIVersion != MLIR_PLUGIN_API_VERSION)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Wrong API version on plugin '") + Filename + "'. Got version " +
+            Twine(P.Info.APIVersion) + ", supported version is " +
+            Twine(MLIR_PLUGIN_API_VERSION) + ".",
+        llvm::inconvertibleErrorCode());
+
+  if (!P.Info.RegisterDialectRegistryCallbacks)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Empty entry callback in plugin '") + Filename + "'.",
+        llvm::inconvertibleErrorCode());
+
+  return P;
+}
Index: mlir/lib/Pass/CMakeLists.txt
===================================================================
--- mlir/lib/Pass/CMakeLists.txt
+++ mlir/lib/Pass/CMakeLists.txt
@@ -3,6 +3,7 @@
   Pass.cpp
   PassCrashRecovery.cpp
   PassManagerOptions.cpp
+  PassPlugin.cpp
   PassRegistry.cpp
   PassStatistics.cpp
   PassTiming.cpp
Index: mlir/lib/Pass/PassPlugin.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Pass/PassPlugin.cpp
@@ -0,0 +1,52 @@
+//===- lib/Pass/PassPlugin.cpp - Load Pass Plugins ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/PassPlugin.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+
+using namespace mlir;
+
+llvm::Expected<PassPlugin> PassPlugin::Load(const std::string &Filename) {
+  std::string Error;
+  auto Library =
+      llvm::sys::DynamicLibrary::getPermanentLibrary(Filename.c_str(), &Error);
+  if (!Library.isValid())
+    return llvm::make_error<llvm::StringError>(
+        Twine("Could not load library '") + Filename + "': " + Error,
+        llvm::inconvertibleErrorCode());
+
+  PassPlugin P{Filename, Library};
+
+  // mlirGetPassPluginInfo should be resolved to the definition from the plugin
+  // we are currently loading.
+  intptr_t getDetailsFn =
+      (intptr_t)Library.getAddressOfSymbol("mlirGetPassPluginInfo");
+
+  if (!getDetailsFn)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Plugin entry point not found in '") + Filename + "'.",
+        llvm::inconvertibleErrorCode());
+
+  P.Info = reinterpret_cast<PassPluginLibraryInfo (*)()>(getDetailsFn)();
+
+  if (P.Info.APIVersion != MLIR_PLUGIN_API_VERSION)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Wrong API version on plugin '") + Filename + "'. Got version " +
+            Twine(P.Info.APIVersion) + ", supported version is " +
+            Twine(MLIR_PLUGIN_API_VERSION) + ".",
+        llvm::inconvertibleErrorCode());
+
+  if (!P.Info.RegisterPassRegistryCallbacks)
+    return llvm::make_error<llvm::StringError>(
+        Twine("Empty entry callback in plugin '") + Filename + "'.",
+        llvm::inconvertibleErrorCode());
+
+  return P;
+}
Index: mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
===================================================================
--- mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -21,11 +21,13 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectPlugin.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassPlugin.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
@@ -101,15 +103,41 @@
         cl::desc("Run the verifier after each transformation pass"),
         cl::location(verifyPassesFlag), cl::init(true));
 
+    static cl::list<std::string> passPlugins(
+        "load-pass-plugin", cl::desc("Load passes from plugin library"));
+    /// Set the callback to load a pass plugin.
+    passPlugins.setCallback([&](const std::string &PluginPath) {
+      auto Plugin = PassPlugin::Load(PluginPath);
+      if (!Plugin) {
+        errs() << "Failed to load passes from '" << PluginPath << "': "
+               << llvm::toString(Plugin.takeError()) << ". Request ignored.\n";
+        return;
+      }
+      Plugin.get().RegisterPassRegistryCallbacks();
+    });
+
+    static cl::list<std::string> dialectPlugins(
+        "load-dialect-plugin", cl::desc("Load dialects from plugin library"));
+    this->dialectPlugins = std::addressof(dialectPlugins);
+
     static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
     setPassPipelineParser(passPipeline);
   }
+
+  /// Set the callback to load a dialect plugin.
+  void setDialectPluginsCallback(DialectRegistry &registry);
+
+  /// Pointer to static dialectPlugins variable in constructor, needed by
+  /// setDialectPluginsCallback(DialectRegistry&).
+  cl::list<std::string> *dialectPlugins = nullptr;
 };
 } // namespace
 
 ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
 
-void MlirOptMainConfig::registerCLOptions() { *clOptionsConfig; }
+void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) {
+  clOptionsConfig->setDialectPluginsCallback(registry);
+}
 
 MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
   return *clOptionsConfig;
@@ -134,6 +162,19 @@
   return *this;
 }
 
+void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
+    DialectRegistry &registry) {
+  dialectPlugins->setCallback([&registry](const std::string &PluginPath) {
+    auto Plugin = DialectPlugin::Load(PluginPath);
+    if (!Plugin) {
+      errs() << "Failed to load dialect plugin from '" << PluginPath << "': "
+             << llvm::toString(Plugin.takeError()) << ". Request ignored.\n";
+      return;
+    }
+    Plugin.get().RegisterDialectRegistryCallbacks(registry);
+  });
+}
+
 /// Set the ExecutionContext on the context and handle the observers.
 class InstallDebugHandler {
 public:
@@ -364,7 +405,7 @@
   InitLLVM y(argc, argv);
 
   // Register any command line options.
-  MlirOptMainConfig::registerCLOptions();
+  MlirOptMainConfig::registerCLOptions(registry);
   registerAsmPrinterCLOptions();
   registerMLIRContextCLOptions();
   registerPassManagerCLOptions();
Index: mlir/tools/mlir-opt/CMakeLists.txt
===================================================================
--- mlir/tools/mlir-opt/CMakeLists.txt
+++ mlir/tools/mlir-opt/CMakeLists.txt
@@ -79,8 +79,10 @@
 
   DEPENDS
   ${LIBS}
+  SUPPORT_PLUGINS
   )
 target_link_libraries(mlir-opt PRIVATE ${LIBS})
 llvm_update_compile_flags(mlir-opt)
 
 mlir_check_all_link_libraries(mlir-opt)
+export_executable_symbols_for_plugins(mlir-opt)