diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -117,6 +117,7 @@ add_subdirectory(lib) # C API needs all dialects for registration, but should be built before tests. add_subdirectory(lib/CAPI) + if (MLIR_INCLUDE_TESTS) add_definitions(-DMLIR_INCLUDE_TESTS) add_custom_target(MLIRUnitTests) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -536,6 +536,68 @@ concrete = OpResult(value) ``` +#### Interfaces + +MLIR interfaces are a mechanism to interact with the IR without needing to know +specific types of operations but only some of their aspects. Operation +interfaces are available as Python classes with the same name as their C++ +counterparts. Objects of these classes can be constructed from either: + +- an object of the `Operation` class or of any `OpView` subclass; in this + case, all interface methods are available; +- a subclass of `OpView` and a context; in this case, only the *static* + interface methods are available as there is no associated operation. + +In both cases, construction of the interface raises a `ValueError` if the +operation class does not implement the interface in the given context (or, for +operations, in the context that the operation is defined in). Similarly to +attributes and types, the MLIR context may be set up by a surrounding context +manager. + +```python +from mlir.ir import Context, InferTypeOpInterface + +with Context(): + op = <...> + + # Attempt to cast the operation into an interface. + try: + iface = InferTypeOpInterface(op) + except ValueError: + print("Operation does not implement InferTypeOpInterface.") + raise + + # All methods are available on interface objects constructed from an Operation + # or an OpView. + iface.someInstanceMethod() + + # An interface object can also be constructed given an OpView subclass. 
It + # also needs a context in which the interface will be looked up. The context + # can be provided explicitly or set up by the surrounding context manager. + try: + iface = InferTypeOpInterface(some_dialect.SomeOp) + except ValueError: + print("SomeOp does not implement InferTypeOpInterface.") + raise + + # Calling an instance method on an interface object constructed from a class + # will raise TypeError. + try: + iface.someInstanceMethod() + except TypeError: + pass + + # One can still call static interface methods though. + iface.inferReturnTypes(<...>) +``` + +If an interface object was constructed from an `Operation` or an `OpView`, the +operation and its view are available as the `.operation` and `.opview` properties +of the interface object, respectively. + +Only a subset of operation interfaces is currently provided in Python bindings. +Attribute and type interfaces are not yet available in Python bindings. + ### Creating IR Objects Python bindings also support IR creation and manipulation. diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md --- a/mlir/docs/CAPI.md +++ b/mlir/docs/CAPI.md @@ -194,3 +194,23 @@ the inverse conversion. Once the C++ object is available, the API implementation should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast` inside other API calls. + +### Extensions for Interfaces + +Interfaces can follow the example of IR interfaces and should be placed in the +appropriate library (e.g., common interfaces in `mlir-c/Interfaces` and +dialect-specific interfaces in their dialect library). Similarly to other type +hierarchies, interfaces are not expected to have objects of their own type and +instead operate on top-level objects: `MlirAttribute`, `MlirOperation` and +`MlirType`. Static interface methods are expected to take as leading argument a +canonical identifier of the class, `MlirStringRef` with the name for operations +and `MlirTypeID` for attributes and types, followed by `MlirContext` in which +the interfaces are registered.
+ +Individual interfaces are expected to provide a `mlir<InterfaceName>TypeID()` +function that can be used to check whether an object or a class implements this +interface using `mlir<Attribute/Operation/Type>ImplementsInterface` or +`mlir<Attribute/Operation/Type>ImplementsInterfaceStatic` functions, +respectively. Rationale: C++ `isa` only works when an object exists, static +methods are usually dispatched to using templates; lookup by `TypeID` in +`MLIRContext` works even without an object. diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/Interfaces.h @@ -0,0 +1,67 @@ +//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to MLIR interface classes. It is +// intended to contain interfaces defined in lib/Interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_INTERFACES_H +#define MLIR_C_INTERFACES_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Returns `true` if the given operation implements an interface identified by +/// its TypeID. +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID); + +/// Returns `true` if the operation identified by its canonical string name +/// implements the interface identified by its TypeID in the given context. +/// Note that interfaces may be attached to operations in some contexts and not +/// others.
+MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID); + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface. +//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferTypeOpInterface. +MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void); + +/// These callbacks are used to return multiple types from functions while +/// transferring ownership to the caller. The first argument is the number of +/// consecutive elements pointed to by the second argument. The third argument +/// is an opaque pointer forwarded to the callback by the caller. +typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); + +/// Infers the return types of the operation identified by its canonical name +/// given the arguments that will be supplied to its generic builder. Calls +/// `callback` with the inferred return types, potentially several times, on +/// success. Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_INTERFACES_H diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/Interfaces.h @@ -0,0 +1,18 @@ +//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR interface classes. This file should not be included from C++ code other +// than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_INTERFACES_H +#define MLIR_CAPI_INTERFACES_H + +#endif // MLIR_CAPI_INTERFACES_H diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -0,0 +1,240 @@ +//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Interfaces.h" + +namespace py = pybind11; + +namespace mlir { +namespace python { + +constexpr static const char *constructorDoc = + R"(Creates an interface from a given operation/opview object or from a +subclass of OpView. Raises ValueError if the operation does not implement the +interface.)"; + +constexpr static const char *operationDoc = + R"(Returns an Operation for which the interface was constructed.)"; + +constexpr static const char *opviewDoc = + R"(Returns an OpView subclass _instance_ for which the interface was +constructed)"; + +constexpr static const char *inferReturnTypesDoc = + R"(Given the arguments required to build an operation, attempts to infer +its return types. 
Raises ValueError on faliure.)"; + +/// CRTP base class for Python classes representing MLIR Op interfaces. +/// Interface hierarchies are flat so no base class is expected here. The +/// derived class is expected to define the following static fields: +/// - `const char *pyClassName` - the name of the Python class to create; +/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID +/// of the interface. +/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind +/// interface-specific methods. +/// +/// An interface class may be constructed from either an Operation/OpView object +/// or from a subclass of OpView. In the latter case, only the static interface +/// methods are available, similarly to calling ConcereteOp::staticMethod on the +/// C++ side. Implementations of concrete interfaces can use the `isStatic` +/// method to check whether the interface object was constructed from a class or +/// an operation/opview instance. The `getOpName` always succeeds and returns a +/// canonical name of the operation suitable for lookups. +template +class PyConcreteOpInterface { +protected: + using ClassTy = py::class_; + using GetTypeIDFunctionTy = MlirTypeID (*)(); + +public: + /// Constructs an interface instance from an object that is either an + /// operation or a subclass of OpView. In the latter case, only the static + /// methods of the interface are accessible to the caller. + PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) + : obj(object) { + try { + operation = &py::cast(obj); + } catch (py::cast_error &err) { + // Do nothing. + } + + try { + operation = &py::cast(obj).getOperation(); + } catch (py::cast_error &err) { + // Do nothing. 
+ } + + if (operation != nullptr) { + if (!mlirOperationImplementsInterface(*operation, + ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + + MlirIdentifier identifier = mlirOperationGetName(*operation); + MlirStringRef stringRef = mlirIdentifierStr(identifier); + opName = std::string(stringRef.data, stringRef.length); + } else { + try { + opName = obj.attr("OPERATION_NAME").template cast(); + } catch (py::cast_error &err) { + throw py::type_error( + "Op interface does not refer to an operation or OpView class"); + } + + if (!mlirOperationImplementsInterfaceStatic( + mlirStringRefCreate(opName.data(), opName.length()), + context.resolve().get(), ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + } + } + + /// Creates the Python bindings for this class in the given module. + static void bind(py::module &m) { + py::class_ cls(m, "InferTypeOpInterface", + py::module_local()); + cls.def(py::init(), py::arg("object"), + py::arg("context") = py::none(), constructorDoc) + .def_property_readonly("operation", + &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, + opviewDoc); + ConcreteIface::bindDerived(cls); + } + + /// Hook for derived classes to add class-specific bindings. + static void bindDerived(ClassTy &cls) {} + + /// Returns `true` if this object was constructed from a subclass of OpView + /// rather than from an operation instance. + bool isStatic() { return operation == nullptr; } + + /// Returns the operation instance from which this object was constructed. + /// Throws a type error if this object was constructed from a subclass of + /// OpView. 
+ py::object getOperationObject() { + if (operation == nullptr) { + throw py::type_error("Cannot get an operation from a static interface"); + } + + return operation->getRef().releaseObject(); + } + + /// Returns the opview of the operation instance from which this object was + /// constructed. Throws a type error if this object was constructed form a + /// subclass of OpView. + py::object getOpView() { + if (operation == nullptr) { + throw py::type_error("Cannot get an opview from a static interface"); + } + + return operation->createOpView(); + } + + /// Returns the canonical name of the operation this interface is constructed + /// from. + const std::string &getOpName() { return opName; } + +private: + PyOperation *operation = nullptr; + std::string opName; + py::object obj; +}; + +/// Python wrapper for InterTypeOpInterface. This interface has only static +/// methods. +class PyInferTypeOpInterface + : public PyConcreteOpInterface { +public: + using PyConcreteOpInterface::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector &inferredTypes; + PyMlirContext &pyMlirContext; + }; + + /// Appends the types provided as the two first arguments to the user-data + /// structure (expects AppendResultsCallbackData). + static void appendResultsCallback(intptr_t nTypes, MlirType *types, + void *userData) { + auto *data = static_cast(userData); + data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); + for (intptr_t i = 0; i < nTypes; ++i) { + data->inferredTypes.push_back( + PyType(data->pyMlirContext.getRef(), types[i])); + } + } + + /// Given the arguments required to build an operation, attempts to infer its + /// return types. Throws value_error on faliure. 
+ std::vector + inferReturnTypes(llvm::Optional> operands, + llvm::Optional attributes, + llvm::Optional> regions, + DefaultingPyMlirContext context, + DefaultingPyLocation location) { + llvm::SmallVector mlirOperands; + llvm::SmallVector mlirRegions; + + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue &value : *operands) { + mlirOperands.push_back(value); + } + } + + if (regions) { + mlirRegions.reserve(regions->size()); + for (PyRegion ®ion : *regions) { + mlirRegions.push_back(region); + } + } + + std::vector inferredTypes; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredTypes, pyContext}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error("Failed to infer result types"); + } + + return inferredTypes; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("context") = py::none(), py::arg("loc") = py::none(), + inferReturnTypesDoc); + } +}; + +void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -859,6 +859,7 @@ void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void 
populateIRCore(pybind11::module &m); +void populateIRInterfaces(pybind11::module &m); void populateIRTypes(pybind11::module &m); } // namespace python diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -85,6 +85,7 @@ populateIRCore(irModule); populateIRAffine(irModule); populateIRAttributes(irModule); + populateIRInterfaces(irModule); populateIRTypes(irModule); // Define and populate PassManager submodule. diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(ExecutionEngine) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Transforms) diff --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_public_c_api_library(MLIRCAPIInterfaces + Interfaces.cpp + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -0,0 +1,89 @@ +//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Interfaces.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; + +bool mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = + unwrap(operation)->getAbstractOperation(); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = AbstractOperation::lookup( + StringRef(operationName.data, operationName.length), unwrap(context)); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +MlirTypeID mlirInferTypeOpInterfaceTypeID() { + return wrap(InferTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData) { + StringRef name(opName.data, opName.length); + const AbstractOperation *abstractOp = + AbstractOperation::lookup(name, unwrap(context)); + if (!abstractOp) + return mlirLogicalResultFailure(); + + // The operation may be registered without implementing the interface; fail + // instead of calling through a null interface concept. + const auto *inferTypeConcept = + abstractOp->getInterface<InferTypeOpInterface>(); + if (!inferTypeConcept) + return mlirLogicalResultFailure(); + + llvm::Optional<Location> maybeLocation = llvm::None; + if (!mlirLocationIsNull(location)) + maybeLocation = unwrap(location); + SmallVector<Value> unwrappedOperands; + (void)unwrapList(nOperands, operands, unwrappedOperands); + DictionaryAttr attributeDict; + if (!mlirAttributeIsNull(attributes)) + attributeDict = unwrap(attributes).cast<DictionaryAttr>(); + + // Create a vector of unique pointers to regions and make sure they are not + // deleted when exiting the scope.
This is a hack caused by C++ API expecting + // a list of unique pointers to regions (without ownership transfer + // semantics) and C API making ownership transfer explicit. + SmallVector<std::unique_ptr<Region>> unwrappedRegions; + unwrappedRegions.reserve(nRegions); + for (intptr_t i = 0; i < nRegions; ++i) + unwrappedRegions.emplace_back(unwrap(regions[i])); + auto cleaner = llvm::make_scope_exit([&]() { + for (auto &region : unwrappedRegions) + region.release(); + }); + + SmallVector<Type> inferredTypes; + if (failed(inferTypeConcept->inferReturnTypes( + unwrap(context), maybeLocation, unwrappedOperands, attributeDict, + unwrappedRegions, inferredTypes))) + return mlirLogicalResultFailure(); + + SmallVector<MlirType> wrappedInferredTypes; + wrappedInferredTypes.reserve(inferredTypes.size()); + for (Type t : inferredTypes) + wrappedInferredTypes.push_back(wrap(t)); + callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); + return mlirLogicalResultSuccess(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -112,7 +112,7 @@ DIALECT_NAME memref) declare_mlir_dialect_python_bindings( - ADD_TO_PARENT MLIRPythonTestSources.Dialects + ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/PythonTest.td SOURCES dialects/python_test.py @@ -190,6 +190,7 @@ ${PYTHON_SOURCE_DIR}/IRAffine.cpp ${PYTHON_SOURCE_DIR}/IRAttributes.cpp ${PYTHON_SOURCE_DIR}/IRCore.cpp + ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp ${PYTHON_SOURCE_DIR}/IRModule.cpp ${PYTHON_SOURCE_DIR}/IRTypes.cpp ${PYTHON_SOURCE_DIR}/PybindUtils.cpp @@ -199,6 +200,7 @@ EMBED_CAPI_LINK_LIBS MLIRCAPIDebug MLIRCAPIIR + MLIRCAPIInterfaces MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects @@ -295,6 +297,20 @@ MLIRCAPITransforms ) +# TODO: This should not be included in the main Python extension.
However, +# putting it into MLIRPythonTestSources along with the dialect declaration +# above confuses the Python module loader when running under lit. +declare_mlir_python_extension(MLIRPythonExtension.PythonTest + MODULE_NAME _mlirPythonTest + ADD_TO_PARENT MLIRPythonSources.Dialects + SOURCES + ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect +) + ################################################################################ # Common CAPI dependency DSO. # All python extensions must link through one DSO which exports the CAPI, and @@ -334,7 +350,6 @@ MLIRPythonCAPI ) - add_mlir_python_modules(MLIRPythonTestModules ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir" INSTALL_PREFIX "python_packages/mlir_test/mlir" diff --git a/mlir/python/mlir/dialects/PythonTest.td b/mlir/python/mlir/dialects/PythonTest.td deleted file mode 100644 --- a/mlir/python/mlir/dialects/PythonTest.td +++ /dev/null @@ -1,33 +0,0 @@ -//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef PYTHON_TEST_OPS -#define PYTHON_TEST_OPS - -include "mlir/Bindings/Python/Attributes.td" -include "mlir/IR/OpBase.td" - -def Python_Test_Dialect : Dialect { - let name = "python_test"; - let cppNamespace = "PythonTest"; -} -class TestOp traits = []> - : Op; - -def AttributedOp : TestOp<"attributed_op"> { - let arguments = (ins I32Attr:$mandatory_i32, - OptionalAttr:$optional_i32, - UnitAttr:$unit); -} - -def PropertyOp : TestOp<"property_op"> { - let arguments = (ins I32Attr:$property, - I32:$idx); -} - -#endif // PYTHON_TEST_OPS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,3 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * + + +def register_python_test_dialect(context, load=True): + from .._mlir_libs import _mlirPythonTest + _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,6 +1,10 @@ add_subdirectory(CAPI) add_subdirectory(lib) +if (MLIR_ENABLE_BINDINGS_PYTHON) + add_subdirectory(python) +endif() + # Passed to lit.site.cfg.py.so that the out of tree Standalone dialect test # can find MLIR's CMake configuration set(MLIR_CMAKE_DIR diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/python/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS python_test_ops.td) +mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls) +mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs) +mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls) +mlir_tablegen(lib/PythonTestOps.cpp.inc 
-gen-op-defs) +add_public_tablegen_target(MLIRPythonTestIncGen) + +add_subdirectory(lib) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -6,8 +6,10 @@ def run(f): print("\nTEST:", f.__name__) f() + return f # CHECK-LABEL: TEST: testAttributes +@run def testAttributes(): with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True @@ -127,4 +129,47 @@ del op.unit print(f"Unit: {op.unit}") -run(testAttributes) + +# CHECK-LABEL: TEST: inferReturnTypes +@run +def inferReturnTypes(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + op = test.InferResultsOp( + IntegerType.get_signless(32), IntegerType.get_signless(64)) + dummy = test.DummyOp() + + # CHECK: [Type(i32), Type(i64)] + iface = InferTypeOpInterface(op) + print(iface.inferReturnTypes()) + + # CHECK: [Type(i32), Type(i64)] + iface_static = InferTypeOpInterface(test.InferResultsOp) + print(iface.inferReturnTypes()) + + assert isinstance(iface.opview, test.InferResultsOp) + assert iface.opview == iface.operation.opview + + try: + iface_static.opview + except TypeError: + pass + else: + assert False, ("not expected to be able to obtain an opview from a static" + " interface") + + try: + InferTypeOpInterface(dummy) + except ValueError: + pass + else: + assert False, "not expected dummy op to implement the interface" + + try: + InferTypeOpInterface(test.DummyOp) + except ValueError: + pass + else: + assert False, "not expected dummy op class to implement the interface" diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/CMakeLists.txt @@ -0,0 +1,33 @@ +set(LLVM_OPTIONAL_SOURCES + PythonTestCAPI.cpp + PythonTestDialect.cpp + PythonTestModule.cpp +) + 
+add_mlir_library(MLIRPythonTestDialect + PythonTestDialect.cpp + + EXCLUDE_FROM_LIBMLIR + + DEPENDS + MLIRPythonTestIncGen + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface + MLIRIR + MLIRSupport +) + +add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect + PythonTestCAPI.cpp + + DEPENDS + MLIRPythonTestIncGen + + LINK_LIBS PUBLIC + MLIRCAPIInterfaces + MLIRCAPIIR + MLIRCAPIRegistration + MLIRPythonTestDialect +) + diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -0,0 +1,24 @@ +//===- PythonTestCAPI.h - C API for the PythonTest dialect ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H +#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H + +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -0,0 +1,14 @@ +//===- PythonTestCAPI.cpp - C API for the PythonTest dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "PythonTestDialect.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test, + python_test::PythonTestDialect); diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestDialect.h @@ -0,0 +1,21 @@ +//===- PythonTestDialect.h - PythonTest dialect definition ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H +#define MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" + +#include "PythonTestDialect.h.inc" + +#define GET_OP_CLASSES +#include "PythonTestOps.h.inc" + +#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestDialect.cpp @@ -0,0 +1,25 @@ +//===- PythonTestDialect.cpp - PythonTest dialect definition --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestDialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +#include "PythonTestDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "PythonTestOps.cpp.inc" + +namespace python_test { +void PythonTestDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "PythonTestOps.cpp.inc" + >(); +} +} // namespace python_test diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -0,0 +1,26 @@ +//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_mlirPythonTest, m) { + m.def( + "register_python_test_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleRegisterDialect(pythonTestDialect, context); + if (load) { + mlirDialectHandleLoadDialect(pythonTestDialect, context); + } + }, + py::arg("context"), py::arg("load") = true); +} diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -11,10 +11,11 @@ include "mlir/Bindings/Python/Attributes.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def Python_Test_Dialect : Dialect { let name = 
"python_test"; - let cppNamespace = "PythonTest"; + let cppNamespace = "python_test"; } class TestOp traits = []> : Op; @@ -30,4 +31,25 @@ I32:$idx); } +def DummyOp : TestOp<"dummy_op"> { +} + +def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> { + let arguments = (ins); + let results = (outs AnyInteger:$single, AnyInteger:$doubled); + + let extraClassDeclaration = [{ + static ::mlir::LogicalResult inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + ::mlir::Builder b(context); + inferredReturnTypes.push_back(b.getI32Type()); + inferredReturnTypes.push_back(b.getI64Type()); + return ::mlir::success(); + } + }]; +} + #endif // PYTHON_TEST_OPS diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -333,6 +333,7 @@ "include/mlir-c/ExecutionEngine.h", "include/mlir-c/IR.h", "include/mlir-c/IntegerSet.h", + "include/mlir-c/Interfaces.h", "include/mlir-c/Pass.h", "include/mlir-c/Registration.h", "include/mlir-c/Support.h", @@ -360,6 +361,20 @@ ], ) +cc_library( + name = "CAPIInterfaces", + srcs = [ + "lib/CAPI/Interfaces/Interfaces.cpp", + ], + includes = ["include"], + deps = [ + ":CAPIIR", + ":IR", + ":InferTypeOpInterface", + "//llvm:Support", + ], +) + cc_library( name = "CAPIAsync", srcs = [ @@ -558,6 +573,7 @@ "lib/Bindings/Python/IRAffine.cpp", "lib/Bindings/Python/IRAttributes.cpp", "lib/Bindings/Python/IRCore.cpp", + "lib/Bindings/Python/IRInterfaces.cpp", "lib/Bindings/Python/IRModule.cpp", "lib/Bindings/Python/IRTypes.cpp", "lib/Bindings/Python/Pass.cpp", @@ -581,6 +597,7 @@ ":CAPIDebug", ":CAPIGPU", ":CAPIIR", + ":CAPIInterfaces", ":CAPILinalg", ":CAPIRegistration", 
":CAPISparseTensor",