diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt --- a/mlir/examples/standalone/CMakeLists.txt +++ b/mlir/examples/standalone/CMakeLists.txt @@ -52,4 +52,5 @@ endif() add_subdirectory(test) add_subdirectory(standalone-opt) +add_subdirectory(standalone-plugin) add_subdirectory(standalone-translate) diff --git a/mlir/examples/standalone/include/Standalone/CMakeLists.txt b/mlir/examples/standalone/include/Standalone/CMakeLists.txt --- a/mlir/examples/standalone/include/Standalone/CMakeLists.txt +++ b/mlir/examples/standalone/include/Standalone/CMakeLists.txt @@ -1,3 +1,7 @@ add_mlir_dialect(StandaloneOps standalone) add_mlir_doc(StandaloneDialect StandaloneDialect Standalone/ -gen-dialect-doc) add_mlir_doc(StandaloneOps StandaloneOps Standalone/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS StandalonePasses.td) +mlir_tablegen(StandalonePasses.h.inc --gen-pass-decls) +add_public_tablegen_target(MLIRStandalonePassesIncGen) diff --git a/mlir/examples/standalone/include/Standalone/StandalonePasses.h b/mlir/examples/standalone/include/Standalone/StandalonePasses.h new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/include/Standalone/StandalonePasses.h @@ -0,0 +1,26 @@ +//===- StandalonePasses.h - Standalone passes ------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef STANDALONE_STANDALONEPASSES_H +#define STANDALONE_STANDALONEPASSES_H + +#include "Standalone/StandaloneDialect.h" +#include "Standalone/StandaloneOps.h" +#include "mlir/Pass/Pass.h" +#include <memory> + +namespace mlir { +namespace standalone { +#define GEN_PASS_DECL +#include "Standalone/StandalonePasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "Standalone/StandalonePasses.h.inc" +} // namespace standalone +} // namespace mlir + +#endif // STANDALONE_STANDALONEPASSES_H diff --git a/mlir/examples/standalone/include/Standalone/StandalonePasses.td b/mlir/examples/standalone/include/Standalone/StandalonePasses.td new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/include/Standalone/StandalonePasses.td @@ -0,0 +1,30 @@ +//===- StandalonePasses.td - Standalone dialect passes -----*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef STANDALONE_PASS +#define STANDALONE_PASS + +include "mlir/Pass/PassBase.td" + +def StandaloneSwitchBarFoo: Pass<"standalone-switch-bar-foo", "::mlir::ModuleOp"> { + let summary = "Switches the name of a FuncOp named `bar` to `foo` and folds."; + let description = [{ + Switches the name of a FuncOp named `bar` to `foo` and folds.
+ ``` + func.func @bar() { + return + } + // Gets transformed to: + func.func @foo() { + return + } + ``` + }]; +} + +#endif // STANDALONE_PASS diff --git a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt --- a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt +++ b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt @@ -2,14 +2,17 @@ StandaloneTypes.cpp StandaloneDialect.cpp StandaloneOps.cpp + StandalonePasses.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/Standalone DEPENDS MLIRStandaloneOpsIncGen + MLIRStandalonePassesIncGen LINK_LIBS PUBLIC MLIRIR MLIRInferTypeOpInterface + MLIRFuncDialect ) diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp @@ -0,0 +1,48 @@ +//===- StandalonePasses.cpp - Standalone passes -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Standalone/StandalonePasses.h" + +namespace mlir::standalone { +#define GEN_PASS_DEF_STANDALONESWITCHBARFOO +#include "Standalone/StandalonePasses.h.inc" + +namespace { +class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> { +public: + using OpRewritePattern<func::FuncOp>::OpRewritePattern; + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const final { + if (op.getSymName() == "bar") { + rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); }); + return success(); + } + return failure(); + } +}; + +class StandaloneSwitchBarFoo + : public impl::StandaloneSwitchBarFooBase<StandaloneSwitchBarFoo> { +public: + using impl::StandaloneSwitchBarFooBase< + StandaloneSwitchBarFoo>::StandaloneSwitchBarFooBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.add<StandaloneSwitchBarFooRewriter>(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + signalPassFailure(); + } +}; +} // namespace +} // namespace mlir::standalone diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -21,9 +21,11 @@ #include "llvm/Support/ToolOutputFile.h" #include "Standalone/StandaloneDialect.h" +#include "Standalone/StandalonePasses.h" int main(int argc, char **argv) { mlir::registerAllPasses(); + mlir::standalone::registerPasses(); // TODO: Register standalone passes here.
mlir::DialectRegistry registry; diff --git a/mlir/examples/standalone/standalone-plugin/CMakeLists.txt b/mlir/examples/standalone/standalone-plugin/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/standalone-plugin/CMakeLists.txt @@ -0,0 +1,22 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +set(LIBS + MLIRIR + MLIRPass + MLIRPluginsLib + MLIRStandalone + MLIRTransformUtils + ) + +add_mlir_dialect_library(StandalonePlugin + SHARED + standalone-plugin.cpp + + DEPENDS + MLIRStandalone + ) + +llvm_update_compile_flags(StandalonePlugin) +target_link_libraries(StandalonePlugin PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(StandalonePlugin) diff --git a/mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp b/mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/standalone-plugin/standalone-plugin.cpp @@ -0,0 +1,39 @@ +//===- standalone-plugin.cpp ------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Tools/Plugins/DialectPlugin.h" + +#include "Standalone/StandaloneDialect.h" +#include "Standalone/StandalonePasses.h" + +using namespace mlir; + +/// Dialect plugin registration mechanism. +/// Note that the registration callback can also register passes. +/// Necessary symbol to register the dialect plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo +mlirGetDialectPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING, + [](DialectRegistry *registry) { + registry->insert<mlir::standalone::StandaloneDialect>(); + mlir::standalone::registerPasses(); + }}; +} + +/// Pass plugin registration mechanism. +/// Necessary symbol to register the pass plugin. +extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "StandalonePasses", LLVM_VERSION_STRING, + []() { mlir::standalone::registerPasses(); }}; +} diff --git a/mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir b/mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/test/Standalone/standalone-pass-plugin.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s --load-pass-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s + +module { + // CHECK-LABEL: func @foo() + func.func @bar() { + return + } + + // CHECK-LABEL: func @abar() + func.func @abar() { + return + } +} diff --git a/mlir/examples/standalone/test/Standalone/standalone-plugin.mlir b/mlir/examples/standalone/test/Standalone/standalone-plugin.mlir new file mode 100644 --- /dev/null +++ b/mlir/examples/standalone/test/Standalone/standalone-plugin.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s --load-dialect-plugin=%standalone_libs/libStandalonePlugin.so --pass-pipeline="builtin.module(standalone-switch-bar-foo)" | FileCheck %s + +module { + // CHECK-LABEL: func @foo() + func.func @bar() { + return + } + + // CHECK-LABEL: func @standalone_types(%arg0: !standalone.custom<"10">) + func.func @standalone_types(%arg0: !standalone.custom<"10">) { + return + } +} diff --git a/mlir/examples/standalone/test/lit.cfg.py b/mlir/examples/standalone/test/lit.cfg.py --- a/mlir/examples/standalone/test/lit.cfg.py +++ b/mlir/examples/standalone/test/lit.cfg.py @@
-44,12 +44,16 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.standalone_obj_root, 'test') config.standalone_tools_dir = os.path.join(config.standalone_obj_root, 'bin') +config.standalone_libs_dir = os.path.join(config.standalone_obj_root, 'lib') + +config.substitutions.append(('%standalone_libs', config.standalone_libs_dir)) # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) tool_dirs = [config.standalone_tools_dir, config.llvm_tools_dir] tools = [ + 'mlir-opt', 'standalone-capi-test', 'standalone-opt', 'standalone-translate', diff --git a/mlir/include/mlir/Tools/Plugins/DialectPlugin.h b/mlir/include/mlir/Tools/Plugins/DialectPlugin.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/Plugins/DialectPlugin.h @@ -0,0 +1,106 @@ +//===- mlir/Tools/Plugins/DialectPlugin.h - Public Plugin API -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This defines the public entry point for dialect plugins. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H +#define MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H + +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Tools/Plugins/PassPlugin.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <string> + +namespace mlir { +extern "C" { +/// Information about the plugin required to load its dialects & passes +/// +/// This struct defines the core interface for dialect plugins and is supposed +/// to be filled out by plugin implementors.
MLIR-side users of a plugin are +/// expected to use the \c DialectPlugin class below to interface with it. +struct DialectPluginLibraryInfo { + /// The API version understood by this plugin, usually + /// \c MLIR_PLUGIN_API_VERSION + uint32_t apiVersion; + /// A meaningful name of the plugin. + const char *pluginName; + /// The version of the plugin. + const char *pluginVersion; + + /// The callback for registering the plugin's dialects with a \c DialectRegistry + /// instance + void (*registerDialectRegistryCallbacks)(DialectRegistry *); +}; +} + +/// A loaded dialect plugin. +/// +/// An instance of this class wraps a loaded dialect plugin and gives access to +/// its interface defined by the \c DialectPluginLibraryInfo it exposes. +class DialectPlugin { +public: + /// Attempts to load a dialect plugin from a given file. + /// + /// \returns Returns an error if either the library cannot be found or loaded, + /// there is no public entry point, or the plugin implements the wrong API + /// version. + static llvm::Expected<DialectPlugin> load(const std::string &filename); + + /// Get the filename of the loaded plugin. + StringRef getFilename() const { return filename; } + + /// Get the plugin name + StringRef getPluginName() const { return info.pluginName; } + + /// Get the plugin version + StringRef getPluginVersion() const { return info.pluginVersion; } + + /// Get the plugin API version + uint32_t getAPIVersion() const { return info.apiVersion; } + + /// Invoke the DialectRegistry callback registration + void + registerDialectRegistryCallbacks(DialectRegistry &dialectRegistry) const { + info.registerDialectRegistryCallbacks(&dialectRegistry); + } + +private: + DialectPlugin(const std::string &filename, + const llvm::sys::DynamicLibrary &library) + : filename(filename), library(library), info() {} + + std::string filename; + llvm::sys::DynamicLibrary library; + DialectPluginLibraryInfo info; +}; +} // namespace mlir + +/// The public entry point for a dialect plugin.
+/// +/// When a plugin is loaded by the driver, it will call this entry point to +/// obtain information about this plugin and about how to register its dialects. +/// This function needs to be implemented by the plugin, see the example below: +/// +/// ``` +/// extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +/// mlirGetDialectPluginInfo() { +/// return { +/// MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", [](DialectRegistry *registry) { ... } +/// }; +/// } +/// ``` +extern "C" ::mlir::DialectPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +mlirGetDialectPluginInfo(); + +#endif /* MLIR_TOOLS_PLUGINS_DIALECTPLUGIN_H */ diff --git a/mlir/include/mlir/Tools/Plugins/PassPlugin.h b/mlir/include/mlir/Tools/Plugins/PassPlugin.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/Plugins/PassPlugin.h @@ -0,0 +1,112 @@ +//===- mlir/Tools/Plugins/PassPlugin.h - Public Plugin API ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This defines the public entry point for pass plugins. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PLUGINS_PASSPLUGIN_H +#define MLIR_TOOLS_PLUGINS_PASSPLUGIN_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <string> + +namespace mlir { +/// \macro MLIR_PLUGIN_API_VERSION +/// Identifies the API version understood by this plugin. +/// +/// When a plugin is loaded, the driver will check its supported plugin version +/// against that of the plugin. A mismatch is an error.
The supported version +/// will be incremented for ABI-breaking changes to the \c PassPluginLibraryInfo +/// struct, i.e. when callbacks are added, removed, or reordered. +#define MLIR_PLUGIN_API_VERSION 1 + +extern "C" { +/// Information about the plugin required to load its passes +/// +/// This struct defines the core interface for pass plugins and is supposed to +/// be filled out by plugin implementors. MLIR-side users of a plugin are +/// expected to use the \c PassPlugin class below to interface with it. +struct PassPluginLibraryInfo { + /// The API version understood by this plugin, usually \c + /// MLIR_PLUGIN_API_VERSION + uint32_t apiVersion; + /// A meaningful name of the plugin. + const char *pluginName; + /// The version of the plugin. + const char *pluginVersion; + + /// The callback for registering plugin passes. + void (*registerPassRegistryCallbacks)(); +}; +} + +/// A loaded pass plugin. +/// +/// An instance of this class wraps a loaded pass plugin and gives access to +/// its interface defined by the \c PassPluginLibraryInfo it exposes. +class PassPlugin { +public: + /// Attempts to load a pass plugin from a given file. + /// + /// \returns Returns an error if either the library cannot be found or loaded, + /// there is no public entry point, or the plugin implements the wrong API + /// version. + static llvm::Expected<PassPlugin> load(const std::string &filename); + + /// Get the filename of the loaded plugin.
+ StringRef getFilename() const { return filename; } + + /// Get the plugin name + StringRef getPluginName() const { return info.pluginName; } + + /// Get the plugin version + StringRef getPluginVersion() const { return info.pluginVersion; } + + /// Get the plugin API version + uint32_t getAPIVersion() const { return info.apiVersion; } + + /// Invoke the PassRegistry callback registration + void registerPassRegistryCallbacks() const { + info.registerPassRegistryCallbacks(); + } + +private: + PassPlugin(const std::string &filename, + const llvm::sys::DynamicLibrary &library) + : filename(filename), library(library), info() {} + + std::string filename; + llvm::sys::DynamicLibrary library; + PassPluginLibraryInfo info; +}; +} // namespace mlir + +/// The public entry point for a pass plugin. +/// +/// When a plugin is loaded by the driver, it will call this entry point to +/// obtain information about this plugin and about how to register its passes. +/// This function needs to be implemented by the plugin, see the example below: +/// +/// ``` +/// extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +/// mlirGetPassPluginInfo() { +/// return { +/// MLIR_PLUGIN_API_VERSION, "MyPlugin", "v0.1", []() { ... } +/// }; +/// } +/// ``` +extern "C" ::mlir::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK +mlirGetPassPluginInfo(); + +#endif /* MLIR_TOOLS_PLUGINS_PASSPLUGIN_H */ diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -35,12 +35,12 @@ /// supported options. /// The API is fluent, and the options are sorted in alphabetical order below. /// The options can be exposed to the LLVM command line by registering them -/// with `MlirOptMainConfig::registerCLOptions();` and creating a config using -/// `auto config = MlirOptMainConfig::createFromCLOptions();`.
+/// with `MlirOptMainConfig::registerCLOptions(DialectRegistry &);` and creating +/// a config using `auto config = MlirOptMainConfig::createFromCLOptions();`. class MlirOptMainConfig { public: /// Register the options as global LLVM command line options. - static void registerCLOptions(); + static void registerCLOptions(DialectRegistry &dialectRegistry); /// Create a new config with the default set from the CL options. static MlirOptMainConfig createFromCLOptions(); diff --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt --- a/mlir/lib/Tools/CMakeLists.txt +++ b/mlir/lib/Tools/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(mlir-tblgen) add_subdirectory(mlir-translate) add_subdirectory(PDLL) +add_subdirectory(Plugins) add_subdirectory(tblgen-lsp-server) diff --git a/mlir/lib/Tools/Plugins/CMakeLists.txt b/mlir/lib/Tools/Plugins/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/Plugins/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library(MLIRPluginsLib + DialectPlugin.cpp + PassPlugin.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/Plugins + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRSupport + ) diff --git a/mlir/lib/Tools/Plugins/DialectPlugin.cpp b/mlir/lib/Tools/Plugins/DialectPlugin.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/Plugins/DialectPlugin.cpp @@ -0,0 +1,53 @@ +//===- lib/Tools/Plugins/DialectPlugin.cpp - Load Dialect Plugins ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/Plugins/DialectPlugin.h" +#include "llvm/Support/raw_ostream.h" + +#include <cstdint> + +using namespace mlir; + +llvm::Expected<DialectPlugin> DialectPlugin::load(const std::string &filename) { + std::string error; + auto library = + llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); + if (!library.isValid()) + return llvm::make_error<llvm::StringError>( + Twine("Could not load library '") + filename + "': " + error, + llvm::inconvertibleErrorCode()); + + DialectPlugin plugin{filename, library}; + + // mlirGetDialectPluginInfo should be resolved to the definition from the + // plugin we are currently loading. + intptr_t getDetailsFn = + (intptr_t)library.getAddressOfSymbol("mlirGetDialectPluginInfo"); + + if (!getDetailsFn) + return llvm::make_error<llvm::StringError>( + Twine("Plugin entry point not found in '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + plugin.info = + reinterpret_cast<DialectPluginLibraryInfo (*)()>(getDetailsFn)(); + + if (plugin.info.apiVersion != MLIR_PLUGIN_API_VERSION) + return llvm::make_error<llvm::StringError>( + Twine("Wrong API version on plugin '") + filename + "'. Got version " + + Twine(plugin.info.apiVersion) + ", supported version is " + + Twine(MLIR_PLUGIN_API_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + if (!plugin.info.registerDialectRegistryCallbacks) + return llvm::make_error<llvm::StringError>( + Twine("Empty entry callback in plugin '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + return plugin; +} diff --git a/mlir/lib/Tools/Plugins/PassPlugin.cpp b/mlir/lib/Tools/Plugins/PassPlugin.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/Plugins/PassPlugin.cpp @@ -0,0 +1,53 @@ +//===- lib/Tools/Plugins/PassPlugin.cpp - Load Pass Plugins ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/Plugins/PassPlugin.h" +#include "llvm/Support/raw_ostream.h" + +#include <cstdint> + +using namespace mlir; + +llvm::Expected<PassPlugin> PassPlugin::load(const std::string &filename) { + std::string error; + auto library = + llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); + if (!library.isValid()) + return llvm::make_error<llvm::StringError>( + Twine("Could not load library '") + filename + "': " + error, + llvm::inconvertibleErrorCode()); + + PassPlugin plugin{filename, library}; + + // mlirGetPassPluginInfo should be resolved to the definition from the plugin + // we are currently loading. + intptr_t getDetailsFn = + (intptr_t)library.getAddressOfSymbol("mlirGetPassPluginInfo"); + + if (!getDetailsFn) + return llvm::make_error<llvm::StringError>( + Twine("Plugin entry point not found in '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + plugin.info = + reinterpret_cast<PassPluginLibraryInfo (*)()>(getDetailsFn)(); + + if (plugin.info.apiVersion != MLIR_PLUGIN_API_VERSION) + return llvm::make_error<llvm::StringError>( + Twine("Wrong API version on plugin '") + filename + "'.
Got version " + + Twine(plugin.info.apiVersion) + ", supported version is " + + Twine(MLIR_PLUGIN_API_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + if (!plugin.info.registerPassRegistryCallbacks) + return llvm::make_error<llvm::StringError>( + Twine("Empty entry callback in plugin '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + return plugin; +} diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt --- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt @@ -10,5 +10,6 @@ MLIRObservers MLIRPass MLIRParser + MLIRPluginsLib MLIRSupport ) diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -30,6 +30,8 @@ #include "mlir/Support/Timing.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/ParseUtilities.h" +#include "mlir/Tools/Plugins/DialectPlugin.h" +#include "mlir/Tools/Plugins/PassPlugin.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/InitLLVM.h" @@ -101,15 +103,41 @@ cl::desc("Run the verifier after each transformation pass"), cl::location(verifyPassesFlag), cl::init(true)); + static cl::list<std::string> passPlugins( + "load-pass-plugin", cl::desc("Load passes from plugin library")); + /// Set the callback to load a pass plugin. + passPlugins.setCallback([&](const std::string &pluginPath) { + auto plugin = PassPlugin::load(pluginPath); + if (!plugin) { + errs() << "Failed to load passes from '" << pluginPath + << "'.
Request ignored: " << llvm::toString(plugin.takeError()) << "\n"; + return; + } + plugin.get().registerPassRegistryCallbacks(); + }); + + static cl::list<std::string> dialectPlugins( + "load-dialect-plugin", cl::desc("Load dialects from plugin library")); + this->dialectPlugins = std::addressof(dialectPlugins); + static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p"); setPassPipelineParser(passPipeline); } + + /// Set the callback to load a dialect plugin. + void setDialectPluginsCallback(DialectRegistry &registry); + + /// Pointer to static dialectPlugins variable in constructor, needed by + /// setDialectPluginsCallback(DialectRegistry&). + cl::list<std::string> *dialectPlugins = nullptr; }; } // namespace ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig; -void MlirOptMainConfig::registerCLOptions() { *clOptionsConfig; } +void MlirOptMainConfig::registerCLOptions(DialectRegistry &registry) { + clOptionsConfig->setDialectPluginsCallback(registry); +} MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() { return *clOptionsConfig; @@ -134,6 +162,19 @@ return *this; } +void MlirOptMainConfigCLOptions::setDialectPluginsCallback( + DialectRegistry &registry) { + dialectPlugins->setCallback([&](const std::string &pluginPath) { + auto plugin = DialectPlugin::load(pluginPath); + if (!plugin) { + errs() << "Failed to load dialect plugin from '" << pluginPath + << "'. Request ignored: " << llvm::toString(plugin.takeError()) << "\n"; + return; + } + plugin.get().registerDialectRegistryCallbacks(registry); + }); +} + /// Set the ExecutionContext on the context and handle the observers. class InstallDebugHandler { public: @@ -365,7 +406,7 @@ InitLLVM y(argc, argv); // Register any command line options.
- MlirOptMainConfig::registerCLOptions(); + MlirOptMainConfig::registerCLOptions(registry); registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); registerPassManagerCLOptions(); diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -79,8 +79,10 @@ DEPENDS ${LIBS} + SUPPORT_PLUGINS ) target_link_libraries(mlir-opt PRIVATE ${LIBS}) llvm_update_compile_flags(mlir-opt) mlir_check_all_link_libraries(mlir-opt) +export_executable_symbols_for_plugins(mlir-opt)