diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -45,6 +45,9 @@ /// Returns the affine map wrapped in the given affine map attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an AffineMap attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an Array attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -89,6 +95,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name); +/// Returns the typeID of a Dictionary attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -115,6 +124,9 @@ /// the value as double. MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr); +/// Returns the typeID of a Float attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -142,6 +154,9 @@ /// is of unsigned type and fits into an unsigned 64-bit integer.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); +/// Returns the typeID of an Integer attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -162,6 +177,9 @@ /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Returns the typeID of an IntegerSet attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -185,6 +203,9 @@ /// the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr); +/// Returns the typeID of an Opaque attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -206,6 +227,9 @@ /// long as the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a String attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -239,6 +263,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a SymbolRef attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a FlatSymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -270,6 +300,9 @@ /// Returns the type stored in the given type attribute. MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a Type attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -280,6 +313,9 @@ /// Creates a unit attribute in the given context. MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx); +/// Returns the typeID of a Unit attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -306,6 +342,9 @@ // Dense array attribute. //===----------------------------------------------------------------------===// +/// Returns the typeID of a DenseArray attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void); + /// Checks whether the given attribute is a dense array attribute.
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); @@ -370,6 +409,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); +/// Returns the typeID of a DenseIntOrFPElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void); + /// Creates a dense elements attribute with the given Shaped type and elements /// in the same context as the type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( @@ -612,6 +654,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +/// Returns the typeID of a SparseElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Strided layout attribute. //===----------------------------------------------------------------------===// @@ -635,6 +680,9 @@ MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a StridedLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -860,6 +860,9 @@ /// Gets the type id of the attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); +/// Gets the dialect of the attribute. +MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute); + /// Checks whether an attribute is null.
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -97,6 +97,7 @@ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); } }; @@ -370,21 +371,25 @@ class mlir_attribute_subclass : public pure_subclass { public: using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); /// Subclasses by looking up the super-class dynamically. mlir_attribute_subclass(py::handle scope, const char *attrClassName, - IsAFunctionTy isaFunction) + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : mlir_attribute_subclass( scope, attrClassName, isaFunction, py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute")) {} + .attr("Attribute"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must /// be used if the subclass is being defined in the same extension module /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superCls) + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor.
Note that it hard, if not impossible, to properly // call chain to parent `__init__` in pybind11 due to its special handling @@ -418,6 +423,15 @@ "isinstance", [isaFunction](MlirAttribute other) { return isaFunction(other); }, py::arg("other_attribute")); + if (getTypeIDFunction) { // Register thisClass as the caster for attributes with this TypeID. + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction(), + pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } } }; diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -80,6 +80,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; static constexpr const char *pyClassName = "AffineMapAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -259,6 +261,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; static constexpr const char *pyClassName = "ArrayAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; class PyArrayAttributeIterator { public: @@ -266,12 +270,11 @@ PyArrayAttributeIterator &dunderIter() { return *this; } - PyAttribute dunderNext() { + MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw py::stop_iteration(); - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); } static void bind(py::module &m) { @@ -286,8 +289,8 @@ int nextIndex = 0; }; - PyAttribute getItem(intptr_t i) { - return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); + MlirAttribute getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); } static void bindDerived(ClassTy &c) { @@ -339,6 +342,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; static constexpr const char *pyClassName = "FloatAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -406,6 +411,10 @@ return mlirIntegerAttrGetValueUInt(self); }, "Returns the value of the integer attribute"); + c.def_property_readonly_static("static_typeid", + [](py::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } }; @@ -438,6 +447,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + nullptr; // TypeID is shared with SymbolRefAttr. static void bindDerived(ClassTy &c) { c.def_static( @@ -464,6 +475,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; static constexpr const char *pyClassName = "OpaqueAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -501,6 +514,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; static constexpr const char *pyClassName = "StringAttr"; using
PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -782,13 +797,12 @@ return mlirDenseElementsAttrIsSplat(self); }) .def("get_splat_value", - [](PyDenseElementsAttribute &self) -> PyAttribute { + [](PyDenseElementsAttribute &self) -> MlirAttribute { if (!mlirDenseElementsAttrIsSplat(self)) { throw py::value_error( "get_splat_value called on a non-splat attribute"); } - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)); + return mlirDenseElementsAttrGetSplatValue(self); }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } @@ -921,6 +935,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; static constexpr const char *pyClassName = "DictAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } @@ -958,7 +974,7 @@ if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(self.getContext(), attr); + return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { @@ -1013,6 +1029,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; static constexpr const char *pyClassName = "TypeAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; static void bindDerived(ClassTy &c) {
c.def_static( @@ -1035,6 +1053,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; static constexpr const char *pyClassName = "UnitAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1054,6 +1074,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; static constexpr const char *pyClassName = "StridedLayoutAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1099,6 +1121,50 @@ } }; +py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + if
(PyDenseIntElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + std::string msg = + std::string( + "Can't cast unknown element type DenseIntOrFPElementsAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { + if (PyBoolAttribute::isaFunction(pyAttribute)) + return py::cast(PyBoolAttribute(pyAttribute)); + if (PyIntegerAttribute::isaFunction(pyAttribute)) + return py::cast(PyIntegerAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown IntegerAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1118,6 +1184,9 @@ PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); PyDenseF64ArrayAttribute::bind(m); PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseArrayAttrGetTypeID(), + pybind11::cpp_function(denseArrayAttributeCaster)); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); @@ -1125,6 +1194,10 @@ PyDenseElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseIntOrFPElementsAttrGetTypeID(), + pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDictAttribute::bind(m); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1132,6 +1205,9 @@ PyIntegerAttribute::bind(m); PyStringAttribute::bind(m); PyTypeAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirIntegerAttrGetTypeID(), + pybind11::cpp_function(integerOrBoolAttributeCaster)); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++
b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1907,19 +1907,17 @@ erase(py::cast(operation)); } -PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { operation->checkValid(); symbol.getOperation().checkValid(); MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute( - symbol.getOperation().getContext(), - mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); + return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } -PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { // Op must already be a symbol. PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -1928,7 +1926,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); + return existingNameAttr; } void PySymbolTable::setSymbolName(PyOperationBase &symbol, @@ -1946,7 +1944,7 @@ mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); } -PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { PyOperation &operation = symbol.getOperation(); operation.checkValid(); MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); @@ -1954,7 +1952,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) throw py::value_error("Expected operation to have a symbol visibility."); - return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); + return existingVisAttr; }
void PySymbolTable::setVisibility(PyOperationBase &symbol, @@ -2286,13 +2284,13 @@ PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - PyAttribute dunderGetItemNamed(const std::string &name) { + MlirAttribute dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(operation->getContext(), attr); + return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { @@ -2640,10 +2638,7 @@ "Context that owns the Location") .def_property_readonly( "attr", - [](PyLocation &self) { - return PyAttribute(self.getContext(), - mlirLocationGetAttribute(self)); - }, + [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") .def( "emit_error", @@ -3139,7 +3134,7 @@ context->get(), toMlirStringRef(attrSpec)); if (mlirAttributeIsNull(type)) throw MLIRError("Unable to parse attribute", errors.take()); - return PyAttribute(context->getRef(), type); + return type; }, py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " @@ -3175,18 +3170,41 @@ return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and are - // printed. This may need to be re-evaluated if debug dumps end up - // being excessive.
- PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. + PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = + pybind11::repr(pybind11::cast(self)).cast(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirAttributeGetDialect(self)); + if (!typeCaster) // No caster registered: return the generic Attribute. + return py::cast(self); + return typeCaster.value()(self); }); //---------------------------------------------------------------------------- @@ -3216,13 +3234,7 @@ "The name of the NamedAttribute binding") .def_property_readonly( "attr", - [](PyNamedAttribute &self) { - // TODO: When named attribute is removed/refactored, also remove - // this constructor (it does an inefficient table lookup).
- auto contextRef = PyMlirContext::forContext( - mlirAttributeGetContext(self.namedAttr.attribute)); - return PyAttribute(std::move(contextRef), self.namedAttr.attribute); - }, + [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -986,6 +986,8 @@ // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) @@ -1017,6 +1019,34 @@ pybind11::arg("other")); cls.def_property_readonly( "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + }); + cls.def_property_readonly("typeid", [](PyAttribute &self) { + return py::cast(self).attr("typeid").cast(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { // Let maybe_downcast map this TypeID to DerivedTy. + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + })); + } + DerivedTy::bindDerived(cls); } @@ -1144,14 +1174,14 @@ /// Inserts the given
operation into the symbol table. The operation must have /// the symbol trait. - PyAttribute insert(PyOperationBase &symbol); + MlirAttribute insert(PyOperationBase &symbol); /// Gets and sets the name of a symbol op. - static PyAttribute getSymbolName(PyOperationBase &symbol); + static MlirAttribute getSymbolName(PyOperationBase &symbol); static void setSymbolName(PyOperationBase &symbol, const std::string &name); /// Gets and sets the visibility of a symbol op. - static PyAttribute getVisibility(PyOperationBase &symbol); + static MlirAttribute getVisibility(PyOperationBase &symbol); static void setVisibility(PyOperationBase &symbol, const std::string &visibility); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -485,11 +485,12 @@ py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional { + "encoding", + [](PyRankedTensorType &self) -> std::optional { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; - return PyAttribute(self.getContext(), encoding); + return encoding; }); } }; @@ -550,9 +551,8 @@ py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly( "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); }, "The layout of the MemRef type.") .def_property_readonly( @@ -564,9 +564,11 @@ "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyMemRefType
&self) -> MlirAttribute { + MlirAttribute memorySpace = mlirMemRefTypeGetMemorySpace(self); + if (mlirAttributeIsNull(memorySpace)) + throw std::runtime_error("Null memory space attribute."); + return memorySpace; }, "Returns the memory space of the given MemRef type."); } @@ -602,9 +604,11 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyUnrankedMemRefType &self) -> MlirAttribute { + MlirAttribute memorySpace = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(memorySpace)) + throw std::runtime_error("Null memory space attribute."); + return memorySpace; }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -44,6 +44,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirAffineMapAttrGetTypeID(void) { + return wrap(AffineMapAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -68,6 +72,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } +MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Dictionary attribute.
//===----------------------------------------------------------------------===// @@ -102,6 +108,10 @@ return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } +MlirTypeID mlirDictionaryAttrGetTypeID(void) { + return wrap(DictionaryAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -124,6 +134,8 @@ return llvm::cast(unwrap(attr)).getValueAsDouble(); } +MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -148,6 +160,10 @@ return llvm::cast(unwrap(attr)).getUInt(); } +MlirTypeID mlirIntegerAttrGetTypeID(void) { + return wrap(IntegerAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -172,6 +188,10 @@ return llvm::isa(unwrap(attr)); } +MlirTypeID mlirIntegerSetAttrGetTypeID(void) { + return wrap(IntegerSetAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -197,6 +217,10 @@ return wrap(llvm::cast(unwrap(attr)).getAttrData()); } +MlirTypeID mlirOpaqueAttrGetTypeID(void) { + return wrap(OpaqueAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // String attribute.
//===----------------------------------------------------------------------===// @@ -217,6 +241,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirStringAttrGetTypeID(void) { + return wrap(StringAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -257,6 +285,10 @@ llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } +MlirTypeID mlirSymbolRefAttrGetTypeID(void) { + return wrap(SymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -273,6 +305,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { + return wrap(FlatSymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -289,6 +325,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -301,6 +339,8 @@ return wrap(UnitAttr::get(unwrap(ctx))); } +MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -329,8 +369,13 @@ // Dense array attribute.
//===----------------------------------------------------------------------===// +MlirTypeID mlirDenseArrayAttrGetTypeID(void) { + return wrap(DenseArrayAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // IsA support. +//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); @@ -356,6 +401,7 @@ //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values) { @@ -395,6 +441,7 @@ //===----------------------------------------------------------------------===// // Accessors. +//===----------------------------------------------------------------------===// intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).size(); @@ -402,6 +449,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; @@ -431,19 +479,27 @@ //===----------------------------------------------------------------------===// // IsA support.
+//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } +MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { + return wrap(DenseIntOrFPElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, @@ -620,6 +676,7 @@ //===----------------------------------------------------------------------===// // Splat accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast(unwrap(attr)).isSplat(); @@ -663,6 +720,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; @@ -705,6 +763,7 @@ //===----------------------------------------------------------------------===// // Raw data accessors. +//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( @@ -876,6 +935,10 @@ return wrap(llvm::cast(unwrap(attr)).getValues()); } +MlirTypeID mlirSparseElementsAttrGetTypeID(void) { + return wrap(SparseElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Strided layout attribute.
//===----------------------------------------------------------------------===// @@ -903,3 +966,7 @@ int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getStrides()[pos]; } + +MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { + return wrap(StridedLayoutAttr::getTypeID()); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -870,6 +870,10 @@ return wrap(unwrap(attr).getTypeID()); } +MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { + return wrap(&unwrap(attr).getDialect()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -266,6 +266,10 @@ # CHECK: #python_test.test_attr print(a) + a = Attribute.parse("#python_test.test_attr") + # CHECK: #python_test.test_attr + print(a) + # The following cast must not assert. 
b = test.TestAttr(a) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -23,7 +23,7 @@ gc.collect() # CHECK: "hello" print(str(t)) - # CHECK: Attribute("hello") + # CHECK: StringAttr("hello") print(repr(t)) @@ -134,7 +134,7 @@ a1 = Attribute.parse('"attr1"') astr = StringAttr(a1) aself = StringAttr(astr) - # CHECK: Attribute("attr1") + # CHECK: StringAttr("attr1") print(repr(astr)) try: tillegal = StringAttr(Attribute.parse("1.0")) @@ -324,32 +324,32 @@ @run def testDenseArrayGetItem(): - def print_item(AttrClass, attr_asm): - attr = AttrClass(Attribute.parse(attr_asm)) + def print_item(attr_asm): + attr = Attribute.parse(attr_asm) print(f"{len(attr)}: {attr[0]}, {attr[1]}") with Context(): # CHECK: 2: 0, 1 - print_item(DenseBoolArrayAttr, "array") + print_item("array") # CHECK: 2: 2, 3 - print_item(DenseI8ArrayAttr, "array") + print_item("array") # CHECK: 2: 4, 5 - print_item(DenseI16ArrayAttr, "array") + print_item("array") # CHECK: 2: 6, 7 - print_item(DenseI32ArrayAttr, "array") + print_item("array") # CHECK: 2: 8, 9 - print_item(DenseI64ArrayAttr, "array") + print_item("array") # CHECK: 2: 1.{{0+}}, 2.{{0+}} - print_item(DenseF32ArrayAttr, "array") + print_item("array") # CHECK: 2: 3.{{0+}}, 4.{{0+}} - print_item(DenseF64ArrayAttr, "array") + print_item("array") # CHECK-LABEL: TEST: testDenseIntAttrGetItem @run def testDenseIntAttrGetItem(): def print_item(attr_asm): - attr = DenseIntElementsAttr(Attribute.parse(attr_asm)) + attr = Attribute.parse(attr_asm) dtype = ShapedType(attr.type).element_type try: item = attr[0] @@ -592,3 +592,14 @@ print(repr(type_attr.value)) # CHECK: F32Type(f32) print(repr(type_attr.value.element_type)) + + +# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip +@run +def testConcreteAttributesRoundTrip(): + with Context(), Location.unknown(): + + # CHECK: FloatAttr(4.200000e+01 : f32) + print(repr(Attribute.parse("42.0 
: f32"))) + + assert IntegerAttr.static_typeid is not None diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -401,8 +401,11 @@ print("memref layout:", memref_layout.layout) # CHECK: memref affine map: (d0, d1) -> (d1, d0) print("memref affine map:", memref_layout.affine_map) - # CHECK: memory space: <> - print("memory space:", memref_layout.memory_space) + try: + memref_layout.memory_space + except RuntimeError as e: + # CHECK: Null memory space attribute. + print(e) none = NoneType.get() try: diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h --- a/mlir/test/python/lib/PythonTestCAPI.h +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -23,6 +23,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context); +MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void); + MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context); diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -23,6 +23,10 @@ return wrap(python_test::TestAttrAttr::get(unwrap(context))); } +MlirTypeID mlirPythonTestTestAttributeGetTypeID() { + return wrap(python_test::TestAttrAttr::getTypeID()); +} + bool mlirTypeIsAPythonTestTestType(MlirType type) { return llvm::isa(unwrap(type)); } diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -35,7 +35,8 @@ py::arg("context"), py::arg("load") = true); mlir_attribute_subclass(m, "TestAttr", - mlirAttributeIsAPythonTestTestAttribute) + mlirAttributeIsAPythonTestTestAttribute, + 
mlirPythonTestTestAttributeGetTypeID) .def_classmethod( "get", [](py::object cls, MlirContext ctx) {