diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -325,6 +325,31 @@
 MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
                                                       intptr_t pos);
 
+//===----------------------------------------------------------------------===//
+// Opaque type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is an opaque type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type);
+
+/// Creates an opaque type in the given context associated with the dialect
+/// identified by its namespace. The type contains opaque byte data of the
+/// specified length (data need not be null-terminated). Unless the dialect is
+/// loaded, the context must allow unregistered dialects (see
+/// mlirContextSetAllowUnregisteredDialects), otherwise it fails verification.
+MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx,
+                                              MlirStringRef dialectNamespace,
+                                              MlirStringRef typeData);
+
+/// Returns the namespace of the dialect with which the given opaque type
+/// is associated. The namespace string is owned by the context.
+MLIR_CAPI_EXPORTED MlirStringRef
+mlirOpaqueTypeGetDialectNamespace(MlirType type);
+
+/// Returns the raw data as a string reference. The data remains live as long as
+/// the context in which the type lives.
+MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -608,6 +608,47 @@
   }
 };
 
+static MlirStringRef toMlirStringRef(const std::string &s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+
+/// Opaque Type subclass - OpaqueType.
+class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+  static constexpr const char *pyClassName = "OpaqueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::string dialectNamespace, std::string typeData,
+           DefaultingPyMlirContext context) {
+          MlirType type = mlirOpaqueTypeGet(context->get(),
+                                            toMlirStringRef(dialectNamespace),
+                                            toMlirStringRef(typeData));
+          return PyOpaqueType(context->getRef(), type);
+        },
+        py::arg("dialect_namespace"), py::arg("buffer"),
+        py::arg("context") = py::none(),
+        "Create an unregistered (opaque) dialect type.");
+    c.def_property_readonly(
+        "dialect_namespace",
+        [](PyOpaqueType &self) {
+          MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the dialect namespace for the Opaque type as a string.");
+    c.def_property_readonly(
+        "data",
+        [](PyOpaqueType &self) {
+          MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the data for the Opaque type as a string.");
+  }
+};
+
 } // namespace
 
 void mlir::python::populateIRTypes(py::module &m) {
@@ -627,4 +668,5 @@
   PyUnrankedMemRefType::bind(m);
   PyTupleType::bind(m);
   PyFunctionType::bind(m);
+  PyOpaqueType::bind(m);
 }
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -11,6 +11,7 @@
 #include "mlir-c/IR.h"
 #include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
@@ -357,3 +358,24 @@
   return wrap(
       unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
 }
+
+//===----------------------------------------------------------------------===//
+// Opaque type.
+//===----------------------------------------------------------------------===//
+
+bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa<OpaqueType>(); }
+
+MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
+                           MlirStringRef typeData) {
+  return wrap(
+      OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
+                      unwrap(typeData)));
+}
+
+MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
+  return wrap(unwrap(type).cast<OpaqueType>().getDialectNamespace().strref());
+}
+
+MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
+  return wrap(unwrap(type).cast<OpaqueType>().getTypeData());
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -69,6 +69,7 @@
     "Module",
     "NamedAttribute",
     "NoneType",
+    "OpaqueType",
     "OpAttributeMap",
     "OpOperandList",
     "OpResult",
@@ -820,6 +821,17 @@
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+class OpaqueType(Type):
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @staticmethod
+    def get(*args, **kwargs) -> OpaqueType: ...
+    @staticmethod
+    def isinstance(arg: Any) -> bool: ...
+    @property
+    def dialect_namespace(self) -> str: ...
+    @property
+    def data(self) -> str: ...
+
 class OpAttributeMap:
     def __contains__(self, arg0: str) -> bool: ...
     def __delitem__(self, arg0: str) -> None: ...
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -795,6 +795,21 @@
   fprintf(stderr, "\n");
   // CHECK: (index, i1) -> (i16, i32, i64)
 
+  // Opaque type.
+  MlirStringRef namespace = mlirStringRefCreate("dialect", 7);
+  MlirStringRef data = mlirStringRefCreate("type", 4);
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
+  MlirType opaque = mlirOpaqueTypeGet(ctx, namespace, data);
+  mlirContextSetAllowUnregisteredDialects(ctx, false);
+  if (!mlirTypeIsAOpaque(opaque) ||
+      !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
+                          namespace) ||
+      !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
+    return 25;
+  mlirTypeDump(opaque);
+  fprintf(stderr, "\n");
+  // CHECK: !dialect.type
+
   return 0;
 }
 
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -473,3 +473,19 @@
     print("INPUTS:", func.inputs)
     # CHECK: RESULTS: [Type(index)]
     print("RESULTS:", func.results)
+
+
+# CHECK-LABEL: TEST: testOpaqueType
+@run
+def testOpaqueType():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    opaque = OpaqueType.get("dialect", "type")
+    # CHECK: opaque type: !dialect.type
+    print("opaque type:", opaque)
+    # CHECK: isinstance: True
+    print("isinstance:", OpaqueType.isinstance(opaque))
+    # CHECK: dialect namespace: dialect
+    print("dialect namespace:", opaque.dialect_namespace)
+    # CHECK: data: type
+    print("data:", opaque.data)