diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -138,6 +138,12 @@ add_subdirectory(lib) # C API needs all dialects for registration, but should be built before tests. add_subdirectory(lib/CAPI) + +if(MLIR_ENABLE_BINDINGS_PYTHON) + # Python sources: built extensions come in via lib/Bindings/Python + add_subdirectory(python) +endif() + if (MLIR_INCLUDE_TESTS) add_definitions(-DMLIR_INCLUDE_TESTS) add_custom_target(MLIRUnitTests) @@ -152,11 +158,6 @@ # Generally things after this point may depend on MLIR_ALL_LIBS or libMLIR.so. add_subdirectory(tools) -if(MLIR_ENABLE_BINDINGS_PYTHON) - # Python sources: built extensions come in via lib/Bindings/Python - add_subdirectory(python) -endif() - if( LLVM_INCLUDE_EXAMPLES ) add_subdirectory(examples) endif() diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/Interfaces.h @@ -0,0 +1,63 @@ +//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to MLIR interface classes. It is +// intended to contain interfaces defined in lib/Interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_INTERFACES_H +#define MLIR_C_INTERFACES_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Returns `true` if the given operation implements an interface identified by +/// its TypeID.
+MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID); + +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID); + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface. +//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferTypeOpInterface. +MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); + +/// These callbacks are used to return multiple types from functions while +/// transferring ownership to the caller. The first argument is the number of +/// consecutive elements pointed to by the second argument. The third argument +/// is an opaque pointer forwarded to the callback by the caller. +typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); + +/// Infers the return types of the operation identified by its canonical name +/// given the arguments that will be supplied to its generic builder. Calls +/// `callback` with the types of inferred arguments, potentially several times, +/// on success. Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_INTERFACES_H diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/Interfaces.h @@ -0,0 +1,18 @@ +//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR interface classes. This file should not be included from C++ code other +// than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_INTERFACES_H +#define MLIR_CAPI_INTERFACES_H + +#endif // MLIR_CAPI_INTERFACES_H diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -0,0 +1,239 @@ +//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Interfaces.h" + +namespace py = pybind11; + +namespace mlir { +namespace python { + +constexpr static const char *constructorDoc = + R"(Creates an interface from a given operation/opview object or from a +subclass of OpView. Raises ValueError if the operation does not implement the +interface.)"; + +constexpr static const char *operationDoc = + R"(Returns an Operation for which the interface was constructed.)"; + +constexpr static const char *opviewDoc = + R"(Returns an OpView subclass _instance_ for which the interface was +constructed)"; + +constexpr static const char *inferReturnTypesDoc = + R"(Given the arguments required to build an operation, attempts to infer +its return types. 
Raises ValueError on failure.)"; + +/// CRTP base class for Python classes representing MLIR Op interfaces. +/// Interface hierarchies are flat so no base class is expected here. The +/// derived class is expected to define the following static fields: +/// - `const char *pyClassName` - the name of the Python class to create; +/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID +/// of the interface. +/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind +/// interface-specific methods. +/// +/// An interface class may be constructed from either an Operation/OpView object +/// or from a subclass of OpView. In the latter case, only the static interface +/// methods are available, similarly to calling ConcreteOp::staticMethod on the +/// C++ side. Implementations of concrete interfaces can use the `isStatic` +/// method to check whether the interface object was constructed from a class or +/// an operation/opview instance. The `getOpName` always succeeds and returns a +/// canonical name of the operation suitable for lookups. +template <typename ConcreteIface> +class PyConcreteOpInterface { +protected: + using ClassTy = py::class_<ConcreteIface>; + using GetTypeIDFunctionTy = MlirTypeID (*)(); + +public: + /// Constructs an interface instance from an object that is either an + /// operation or a subclass of OpView. In the latter case, only the static + /// methods of the interface are accessible to the caller. + PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) + : obj(object) { + try { + operation = &py::cast<PyOperation &>(obj); + } catch (py::cast_error &err) { + // Do nothing. + } + + try { + operation = &py::cast<PyOpView &>(obj).getOperation(); + } catch (py::cast_error &err) { + // Do nothing.
+ } + + if (operation != nullptr) { + if (!mlirOperationImplementsInterface(*operation, + ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + + MlirIdentifier identifier = mlirOperationGetName(*operation); + MlirStringRef stringRef = mlirIdentifierStr(identifier); + opName = std::string(stringRef.data, stringRef.length); + } else { + try { + opName = obj.attr("OPERATION_NAME").template cast<std::string>(); + } catch (py::cast_error &err) { + throw py::type_error( + "Op interface does not refer to an operation or OpView class"); + } + + if (!mlirOperationImplementsInterfaceStatic( + mlirStringRefCreate(opName.data(), opName.length()), + context.resolve().get(), ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + } + } + + /// Creates the Python bindings for this class in the given module. + static void bind(py::module &m) { + py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName, + py::module_local()); + cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"), + py::arg("context") = py::none(), constructorDoc) + .def_property_readonly( + "operation", &PyConcreteOpInterface::getOperation, operationDoc) + .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, + opviewDoc); + ConcreteIface::bindDerived(cls); + } + + /// Hook for derived classes to add class-specific bindings. + static void bindDerived(ClassTy &cls) {} + + /// Returns `true` if this object was constructed from a subclass of OpView + /// rather than from an operation instance. + bool isStatic() { return operation == nullptr; } + + /// Returns the operation instance from which this object was constructed. + /// Throws a type error if this object was constructed from a subclass of + /// OpView.
+ PyOperation getOperation() { + if (operation == nullptr) { + throw py::type_error("Cannot get an operation from a static interface"); + } + + return *operation; + } + + /// Returns the opview of the operation instance from which this object was + /// constructed. Throws a type error if this object was constructed from a + /// subclass of OpView. + py::object getOpView() { + if (operation == nullptr) { + throw py::type_error("Cannot get an opview from a static interface"); + } + + return operation->createOpView(); + } + + /// Returns the canonical name of the operation this interface is constructed + /// from. + const std::string &getOpName() { return opName; } + +private: + PyOperation *operation = nullptr; + std::string opName; + py::object obj; +}; + +/// Python wrapper for InferTypeOpInterface. This interface has only static +/// methods. +class PyInferTypeOpInterface + : public PyConcreteOpInterface<PyInferTypeOpInterface> { +public: + using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector<PyType> &inferredTypes; + PyMlirContext &pyMlirContext; + }; + + /// Appends the types provided as the two first arguments to the user-data + /// structure (expects AppendResultsCallbackData). + static void appendResultsCallback(intptr_t nTypes, MlirType *types, + void *userData) { + auto *data = static_cast<AppendResultsCallbackData *>(userData); + data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); + for (intptr_t i = 0; i < nTypes; ++i) { + data->inferredTypes.push_back( + PyType(data->pyMlirContext.getRef(), types[i])); + } + } + + /// Given the arguments required to build an operation, attempts to infer its + /// return types. Throws value_error on failure.
+ std::vector<PyType> + inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands, + llvm::Optional<PyAttribute> attributes, + llvm::Optional<std::vector<PyRegion>> regions, + DefaultingPyMlirContext context, + DefaultingPyLocation location) { + llvm::SmallVector<MlirValue> mlirOperands; + llvm::SmallVector<MlirRegion> mlirRegions; + + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue &value : *operands) { + mlirOperands.push_back(value); + } + } + + if (regions) { + mlirRegions.reserve(regions->size()); + for (PyRegion &region : *regions) { + mlirRegions.push_back(region); + } + } + + std::vector<PyType> inferredTypes; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredTypes, pyContext}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error("Failed to infer result types"); + } + + return inferredTypes; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("context") = py::none(), py::arg("loc") = py::none(), + inferReturnTypesDoc); + } +}; + +void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -859,6 +859,7 @@ void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void
populateIRCore(pybind11::module &m); +void populateIRInterfaces(pybind11::module &m); void populateIRTypes(pybind11::module &m); } // namespace python diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -85,6 +85,7 @@ populateIRCore(irModule); populateIRAffine(irModule); populateIRAttributes(irModule); + populateIRInterfaces(irModule); populateIRTypes(irModule); // Define and populate PassManager submodule. diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(ExecutionEngine) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Transforms) diff --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_public_c_api_library(MLIRCAPIInterfaces + Interfaces.cpp + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -0,0 +1,82 @@ +//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Interfaces.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; + +bool mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = + unwrap(operation)->getAbstractOperation(); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = AbstractOperation::lookup( + StringRef(operationName.data, operationName.length), unwrap(context)); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +MlirTypeID mlirInferTypeOpInterfaceTypeID() { + return wrap(InferTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData) { + StringRef name(opName.data, opName.length); + const AbstractOperation *abstractOp = + AbstractOperation::lookup(name, unwrap(context)); + if (!abstractOp) + return mlirLogicalResultFailure(); + + llvm::Optional<Location> maybeLocation = llvm::None; + if (!mlirLocationIsNull(location)) + maybeLocation = unwrap(location); + SmallVector<Value> unwrappedOperands; + (void)unwrapList(nOperands, operands, unwrappedOperands); + DictionaryAttr attributeDict; + if (!mlirAttributeIsNull(attributes)) + attributeDict = unwrap(attributes).cast<DictionaryAttr>(); + + // Create a vector of unique pointers to regions and make sure they are not + // deleted when exiting the scope.
This is a hack caused by C++ API expecting + // a list of unique pointers to regions (without ownership transfer + // semantics) and C API making ownership transfer explicit. + SmallVector<std::unique_ptr<Region>> unwrappedRegions; + unwrappedRegions.reserve(nRegions); + for (intptr_t i = 0; i < nRegions; ++i) + unwrappedRegions.emplace_back(unwrap(*(regions + i))); + auto cleaner = llvm::make_scope_exit([&]() { + for (auto &region : unwrappedRegions) + region.release(); + }); + + SmallVector<Type> inferredTypes; + if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes( + unwrap(context), maybeLocation, unwrappedOperands, attributeDict, + unwrappedRegions, inferredTypes))) + return mlirLogicalResultFailure(); + + SmallVector<MlirType> wrappedInferredTypes; + wrappedInferredTypes.reserve(inferredTypes.size()); + for (Type t : inferredTypes) + wrappedInferredTypes.push_back(wrap(t)); + callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); + return mlirLogicalResultSuccess(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -181,6 +181,7 @@ ${PYTHON_SOURCE_DIR}/IRAffine.cpp ${PYTHON_SOURCE_DIR}/IRAttributes.cpp ${PYTHON_SOURCE_DIR}/IRCore.cpp + ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp ${PYTHON_SOURCE_DIR}/IRModule.cpp ${PYTHON_SOURCE_DIR}/IRTypes.cpp ${PYTHON_SOURCE_DIR}/PybindUtils.cpp @@ -190,6 +191,7 @@ EMBED_CAPI_LINK_LIBS MLIRCAPIDebug MLIRCAPIIR + MLIRCAPIInterfaces MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects diff --git a/mlir/python/mlir/dialects/PythonTest.td b/mlir/python/mlir/dialects/PythonTest.td --- a/mlir/python/mlir/dialects/PythonTest.td +++ b/mlir/python/mlir/dialects/PythonTest.td @@ -11,6 +11,7 @@ include "mlir/Bindings/Python/Attributes.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def Python_Test_Dialect : Dialect { let name = "python_test"; @@ -30,4 +31,22 @@ I32:$idx); } +def InferResultsOp :
TestOp<"infer_results_op", [InferTypeOpInterface]> { + let arguments = (ins); + let results = (outs AnyInteger:$single, AnyInteger:$doubled); + + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::mlir::Builder b(context); + inferredReturnTypes.push_back(b.getI32Type()); + inferredReturnTypes.push_back(b.getI64Type()); + return ::mlir::success(); + } + }]; +} + #endif // PYTHON_TEST_OPS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,3 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * + + +def register_python_test_dialect(context, load=True): + from .._mlir_libs import _mlirPythonTest + _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,6 +1,10 @@ add_subdirectory(CAPI) add_subdirectory(lib) +if (MLIR_ENABLE_BINDINGS_PYTHON) + add_subdirectory(python) +endif() + # Passed to lit.site.cfg.py.so that the out of tree Standalone dialect test # can find MLIR's CMake configuration set(MLIR_CMAKE_DIR diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/python/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lib) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,4 +1,4 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s from mlir.ir import * import 
mlir.dialects.python_test as test @@ -6,8 +6,10 @@ def run(f): print("\nTEST:", f.__name__) f() + return f # CHECK-LABEL: TEST: testAttributes +@run def testAttributes(): with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True @@ -127,4 +129,18 @@ del op.unit print(f"Unit: {op.unit}") -run(testAttributes) + +@run +def inferReturnTypes(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + op = test.InferResultsOp( + IntegerType.get_signless(32), IntegerType.get_signless(64)) + + iface = InferTypeOpInterface(op) + print(iface.inferReturnTypes()) + + iface_static = InferTypeOpInterface(test.InferResultsOp) + print(iface_static.inferReturnTypes()) diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/CMakeLists.txt @@ -0,0 +1,57 @@ +set(LLVM_OPTIONAL_SOURCES + PythonTestCAPI.cpp + PythonTestDialect.cpp + PythonTestModule.cpp +) + +set(LLVM_TARGET_DEFINITIONS python_test_ops.td) +mlir_tablegen(PythonTestDialect.h.inc -gen-dialect-decls) +mlir_tablegen(PythonTestDialect.cpp.inc -gen-dialect-defs) +mlir_tablegen(PythonTestOps.h.inc -gen-op-decls) +mlir_tablegen(PythonTestOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRPythonTestIncGen) + +add_mlir_library(MLIRPythonTestDialect + PythonTestDialect.cpp + + EXCLUDE_FROM_LIBMLIR + + DEPENDS + MLIRPythonTestIncGen + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface + MLIRIR + MLIRSupport +) + +add_mlir_public_c_api_library(MLIRCAPIPythonTest + PythonTestCAPI.cpp + + DEPENDS + MLIRPythonTestIncGen + + LINK_LIBS PUBLIC + MLIRCAPIInterfaces + MLIRCAPIIR + MLIRCAPIRegistration +) + +declare_mlir_python_extension(MLIRPythonExtension.PythonTest + MODULE_NAME + _mlirPythonTest + + ADD_TO_PARENT + MLIRPythonTestSources + + SOURCES + PythonTestModule.cpp + + PRIVATE_LINK_LIBS + LLVMSupport + + EMBED_CAPI_LINK_LIBS +
MLIRCAPIIR + MLIRCAPIRegistration + MLIRCAPIPythonTest +) diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -0,0 +1,24 @@ +//===- PythonTestCAPI.h - C API for the PythonTest dialect ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H +#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H + +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -0,0 +1,14 @@ +//===- PythonTestCAPI.cpp - C API for the PythonTest dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "PythonTestDialect.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test, + python_test::PythonTestDialect); diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestDialect.h @@ -0,0 +1,21 @@ +//===- PythonTestDialect.h - PythonTest dialect definition ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H +#define MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +#include "PythonTestDialect.h.inc" + +#define GET_OP_CLASSES +#include "PythonTestOps.h.inc" + +#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestDialect.cpp @@ -0,0 +1,25 @@ +//===- PythonTestDialect.cpp - PythonTest dialect definition --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +#include "PythonTestDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "PythonTestOps.cpp.inc" + +namespace python_test { +void PythonTestDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "PythonTestOps.cpp.inc" + >(); +} +} // namespace python_test diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -0,0 +1,26 @@ +//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_mlirPythonTest, m) { + m.def( + "register_python_test_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleRegisterDialect(pythonTestDialect, context); + if (load) { + mlirDialectHandleLoadDialect(pythonTestDialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -11,10 +11,11 @@ include "mlir/Bindings/Python/Attributes.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def Python_Test_Dialect : Dialect { let name = 
"python_test"; - let cppNamespace = "PythonTest"; + let cppNamespace = "python_test"; } class TestOp traits = []> : Op; @@ -30,4 +31,22 @@ I32:$idx); } +def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> { + let arguments = (ins); + let results = (outs AnyInteger:$single, AnyInteger:$doubled); + + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::mlir::Builder b(context); + inferredReturnTypes.push_back(b.getI32Type()); + inferredReturnTypes.push_back(b.getI64Type()); + return ::mlir::success(); + } + }]; +} + #endif // PYTHON_TEST_OPS diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -333,6 +333,7 @@ "include/mlir-c/ExecutionEngine.h", "include/mlir-c/IR.h", "include/mlir-c/IntegerSet.h", + "include/mlir-c/Interfaces.h", "include/mlir-c/Pass.h", "include/mlir-c/Registration.h", "include/mlir-c/Support.h", @@ -360,6 +361,20 @@ ], ) +cc_library( + name = "CAPIInterfaces", + srcs = [ + "lib/CAPI/Interfaces/Interfaces.cpp", + ], + includes = ["include"], + deps = [ + ":CAPIIR", + ":IR", + ":InferTypeOpInterface", + "//llvm:Support", + ], +) + cc_library( name = "CAPIAsync", srcs = [ @@ -558,6 +573,7 @@ "lib/Bindings/Python/IRAffine.cpp", "lib/Bindings/Python/IRAttributes.cpp", "lib/Bindings/Python/IRCore.cpp", + "lib/Bindings/Python/IRInterfaces.cpp", "lib/Bindings/Python/IRModule.cpp", "lib/Bindings/Python/IRTypes.cpp", "lib/Bindings/Python/Pass.cpp", @@ -581,6 +597,7 @@ ":CAPIDebug", ":CAPIGPU", ":CAPIIR", + ":CAPIInterfaces", ":CAPILinalg", ":CAPIRegistration", ":CAPISparseTensor",