diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -74,4 +74,6 @@
 }
 #endif
 
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
+
 #endif // MLIR_C_DIALECT_SPARSE_TENSOR_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name SparseTensor)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix SparseTensor)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix SparseTensor)
 add_public_tablegen_target(MLIRSparseTensorPassIncGen)
 
 add_mlir_doc(Passes SparseTensorPasses ./ -gen-pass-doc)
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -33,6 +33,14 @@
 )
 add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension)
 
+add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSparseTensorPasses
+  INSTALL_DIR
+    python
+  SOURCES
+  SparseTensorPasses.cpp
+)
+add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension)
+
 add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses
   INSTALL_DIR
     python
diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
@@ -0,0 +1,22 @@
+//===- SparseTensorPasses.cpp - Pybind module for the SparseTensor passes -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/SparseTensor.h"
+
+#include <pybind11/pybind11.h>
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+  m.doc() = "MLIR SparseTensor Dialect Passes";
+
+  // Register all SparseTensor passes on load.
+  mlirRegisterSparseTensorPasses();
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -62,11 +62,13 @@
 
 add_mlir_public_c_api_library(MLIRCAPISparseTensor
   SparseTensor.cpp
+  SparseTensorPasses.cpp
 
   PARTIAL_SOURCES_INTENDED
   LINK_LIBS PUBLIC
   MLIRCAPIIR
   MLIRSparseTensor
+  MLIRSparseTensorTransforms
 )
 
 add_mlir_public_c_api_library(MLIRCAPIStandard
diff --git a/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp
@@ -0,0 +1,26 @@
+//===- SparseTensorPasses.cpp - C API for SparseTensor Dialect Passes -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Pass.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+// Must include the declarations as they carry important visibility attributes.
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc"
+
+using namespace mlir;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.cpp.inc"
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py
--- a/mlir/python/mlir/dialects/sparse_tensor.py
+++ b/mlir/python/mlir/dialects/sparse_tensor.py
@@ -3,5 +3,11 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .._cext_loader import _reexport_cext
+from .._cext_loader import _load_extension
+
 _reexport_cext("dialects.sparse_tensor", __name__)
+# Loading this extension registers all SparseTensor passes as a side effect.
+_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses")
+
 del _reexport_cext
+del _load_extension
diff --git a/mlir/test/python/dialects/sparse_tensor/passes.py b/mlir/test/python/dialects/sparse_tensor/passes.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/sparse_tensor/passes.py
@@ -0,0 +1,24 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+
+# Imported only for its side effect: loading the dialect registers its passes.
+from mlir.dialects import sparse_tensor
+
+
+def run(f):
+  print('\nTEST:', f.__name__)
+  f()
+
+
+def testSparseTensorPass():
+  with Context():
+    PassManager.parse('sparsification')
+    PassManager.parse('sparse-tensor-conversion')
+  print('SUCCESS')
+
+
+# CHECK-LABEL: testSparseTensorPass
+# CHECK: SUCCESS
+run(testSparseTensorPass)