diff --git a/mlir/cmake/modules/AddMLIRPythonExtension.cmake b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
--- a/mlir/cmake/modules/AddMLIRPythonExtension.cmake
+++ b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
@@ -73,7 +73,7 @@
     # project separation perspective and a discussion on how to better
     # segment MLIR libraries needs to happen.
     LIBRARY_OUTPUT_DIRECTORY ${LLVM_BINARY_DIR}/python
-    OUTPUT_NAME "_mlirTransforms"
+    OUTPUT_NAME "${extname}"
     PREFIX "${PYTHON_MODULE_PREFIX}"
     SUFFIX "${PYTHON_MODULE_SUFFIX}${PYTHON_MODULE_EXTENSION}"
   )
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -11,8 +11,10 @@
 set(PY_SRC_FILES
   mlir/__init__.py
   mlir/ir.py
+  mlir/passmanager.py
   mlir/dialects/__init__.py
   mlir/dialects/std.py
+  mlir/transforms/__init__.py
 )
 
 add_custom_target(MLIRBindingsPythonSources ALL
@@ -60,3 +62,5 @@
     DEPENDS MLIRBindingsPythonSources
     COMPONENT MLIRBindingsPythonSources)
 endif()
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 
 #include "Globals.h"
 #include "IRModules.h"
+#include "Pass.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -210,4 +211,9 @@
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRSubmodule(irModule);
+
+  // Define and populate PassManager submodule.
+  auto passModule =
+      m.def_submodule("passmanager", "MLIR Pass Management Bindings");
+  populatePassManagerSubmodule(passModule);
 }
diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -0,0 +1,23 @@
+//===- Pass.h - PassManager Submodules of pybind module -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PASS_H
+#define MLIR_BINDINGS_PYTHON_PASS_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+/// Populates the given module with the pass management bindings.
+void populatePassManagerSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PASS_H
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -0,0 +1,82 @@
+//===- Pass.cpp - Pass Management -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Pass.h"
+
+#include "IRModules.h"
+#include "mlir-c/Pass.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace {
+
+/// Owning Wrapper around a PassManager.
+/// The wrapped MlirPassManager is destroyed with the wrapper, so copies are
+/// disallowed to prevent the same handle from being destroyed twice.
+class PyPassManager {
+public:
+  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
+  PyPassManager(const PyPassManager &) = delete;
+  PyPassManager &operator=(const PyPassManager &) = delete;
+  ~PyPassManager() { mlirPassManagerDestroy(passManager); }
+  MlirPassManager get() { return passManager; }
+
+private:
+  MlirPassManager passManager;
+};
+
+} // anonymous namespace
+
+/// Create the `mlir.passmanager` here.
+void mlir::python::populatePassManagerSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the top-level PassManager
+  //----------------------------------------------------------------------------
+  py::class_<PyPassManager>(m, "PassManager")
+      .def(py::init<>([](DefaultingPyMlirContext context) {
+             MlirPassManager passManager =
+                 mlirPassManagerCreate(context->get());
+             return new PyPassManager(passManager);
+           }),
+           py::arg("context") = py::none(),
+           "Create a new PassManager for the current (or provided) Context.")
+      .def_static(
+          "parse",
+          [](const std::string &pipeline, DefaultingPyMlirContext context) {
+            MlirPassManager passManager = mlirPassManagerCreate(context->get());
+            MlirLogicalResult status = mlirParsePassPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager),
+                mlirStringRefCreate(pipeline.data(), pipeline.size()));
+            if (mlirLogicalResultIsFailure(status)) {
+              // Nothing takes ownership on failure: release it before throwing.
+              mlirPassManagerDestroy(passManager);
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("invalid pass pipeline '") +
+                                   pipeline + "'.");
+            }
+            return new PyPassManager(passManager);
+          },
+          py::arg("pipeline"), py::arg("context") = py::none(),
+          "Parse a textual pass-pipeline and return a top-level PassManager "
+          "that can be applied on a Module. Throw a ValueError if the pipeline "
+          "can't be parsed")
+      .def(
+          "__str__",
+          [](PyPassManager &self) {
+            MlirPassManager passManager = self.get();
+            PyPrintAccumulator printAccum;
+            mlirPrintPassPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager),
+                printAccum.getCallback(), printAccum.getUserData());
+            return printAccum.join();
+          },
+          "Print the textual representation for this PassManager, suitable to "
+          "be passed to `parse` for round-tripping.");
+}
diff --git a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt
@@ -0,0 +1,10 @@
+################################################################################
+# Build python extension
+################################################################################
+
+add_mlir_python_extension(MLIRTransformsBindingsPythonExtension _mlirTransforms
+  INSTALL_DIR
+    python
+  SOURCES
+    Transforms.cpp
+)
diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp
@@ -0,0 +1,24 @@
+//===- Transforms.cpp - Pybind module for the Transforms library ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Transforms.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirTransforms, m) {
+  m.doc() = "MLIR Python Native Extension";
+
+  // Register all the passes in the Transforms library on load.
+  mlirRegisterTransformsPasses();
+}
diff --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -10,6 +10,7 @@
 
 __all__ = [
   "ir",
+  "passmanager",
 ]
 
 # Expose the corresponding C-Extension module with a well-known name at this
@@ -38,7 +39,7 @@
 
 # Import sub-modules. Since these may import from here, this must come after
 # any exported definitions.
-from . import ir
+from . import ir, passmanager
 
 # Add our 'dialects' parent module to the search path for implementations.
 _cext.globals.append_dialect_search_prefix("mlir.dialects")
diff --git a/mlir/lib/Bindings/Python/mlir/passmanager.py b/mlir/lib/Bindings/Python/mlir/passmanager.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/passmanager.py
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from . import _reexport_cext
+_reexport_cext("passmanager", __name__)
+del _reexport_cext
diff --git a/mlir/lib/Bindings/Python/mlir/transforms/__init__.py b/mlir/lib/Bindings/Python/mlir/transforms/__init__.py
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/transforms/__init__.py
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Expose the corresponding C-Extension module with a well-known name at this
+# level.
+import _mlirTransforms as _cextTransforms
+
diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors emitted by MLIR to stderr. TODO: this shouldn't be needed when
+# everything is plumbed.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+def run(f):
+  log("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+# Verify successful round-trip.
+# CHECK-LABEL: TEST: testParseSuccess
+def testParseSuccess():
+  with Context():
+    # A first parse is expected to fail because the pass isn't registered
+    # until we import mlir.transforms
+    try:
+      pm = PassManager.parse("module(func(print-op-stats))")
+      # TODO: this error should be propagated to Python but the C API does not help right now.
+      # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
+    except ValueError as e:
+      # CHECK: ValueError exception: invalid pass pipeline 'module(func(print-op-stats))'.
+      log("ValueError exception:", e)
+    else:
+      log("Exception not produced")
+
+    # This will register the pass and round-trip should be possible now.
+    import mlir.transforms
+    pm = PassManager.parse("module(func(print-op-stats))")
+# CHECK: Roundtrip: module(func(print-op-stats))
+    log("Roundtrip: ", pm)
+run(testParseSuccess)
+
+# Verify failure on unregistered pass.
+# CHECK-LABEL: TEST: testParseFail
+def testParseFail():
+  with Context():
+    try:
+      pm = PassManager.parse("unknown-pass")
+    except ValueError as e:
+      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
+      log("ValueError exception:", e)
+    else:
+      log("Exception not produced")
+run(testParseFail)
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -100,6 +100,7 @@
 if(MLIR_BINDINGS_PYTHON_ENABLED)
   list(APPEND MLIR_TEST_DEPENDS
     MLIRBindingsPythonExtension
+    MLIRTransformsBindingsPythonExtension
   )
 endif()
 
diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -58,6 +58,8 @@
 /// Emit TODO
 static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
+  os << "// Registration for the entire group\n";
+  os << "void mlirRegister" << groupName << "Passes();\n\n";
   for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
     Pass pass(def);
     StringRef defName = pass.getDef()->getName();
@@ -77,8 +79,21 @@
 
 )";
 
+/// {0}: The name of the pass group.
+const char *const passGroupRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Group Registration
+//===----------------------------------------------------------------------===//
+
+void mlirRegister{0}Passes() {{
+  register{0}Passes();
+}
+)";
+
 static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
   os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
+  os << llvm::formatv(passGroupRegistrationCode, groupName);
+
   for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
     Pass pass(def);
     StringRef defName = pass.getDef()->getName();