diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -45,6 +45,9 @@ /// Returns the affine map wrapped in the given affine map attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an AffineMap attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an Array attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -89,6 +95,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name); +/// Returns the typeID of a Dictionary attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -115,6 +124,9 @@ /// the value as double. MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr); +/// Returns the typeID of a Float attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -142,6 +154,9 @@ /// is of unsigned type and fits into an unsigned 64-bit integer. 
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); +/// Returns the typeID of an Integer attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -162,6 +177,9 @@ /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Returns the typeID of an IntegerSet attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -185,6 +203,9 @@ /// the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr); +/// Returns the typeID of an Opaque attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -206,6 +227,9 @@ /// long as the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a String attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -239,6 +263,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a SymbolRef attribute. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a FlatSymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -270,6 +300,9 @@ /// Returns the type stored in the given type attribute. MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a Type attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -280,6 +313,9 @@ /// Creates a unit attribute in the given context. MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx); +/// Returns the typeID of a Unit attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -306,6 +342,9 @@ // Dense array attribute. //===----------------------------------------------------------------------===// +/// Returns the typeID of a DenseArray attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void); + /// Checks whether the given attribute is a dense array attribute. 
MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); @@ -370,6 +409,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); +/// Returns the typeID of a DenseIntOrFPElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void); + /// Creates a dense elements attribute with the given Shaped type and elements /// in the same context as the type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( @@ -612,6 +654,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +/// Returns the typeID of a SparseElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Strided layout attribute. //===----------------------------------------------------------------------===// @@ -635,6 +680,9 @@ MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a StridedLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -860,6 +860,9 @@ /// Gets the type id of the attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); +/// Gets the dialect of the attribute. +MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute); + /// Checks whether an attribute is null. 
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -97,6 +97,7 @@ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); } }; @@ -370,21 +371,25 @@ class mlir_attribute_subclass : public pure_subclass { public: using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); /// Subclasses by looking up the super-class dynamically. mlir_attribute_subclass(py::handle scope, const char *attrClassName, - IsAFunctionTy isaFunction) + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : mlir_attribute_subclass( scope, attrClassName, isaFunction, py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute")) {} + .attr("Attribute"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must /// be used if the subclass is being defined in the same extension module /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superCls) + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor. 
Note that it hard, if not impossible, to properly // call chain to parent `__init__` in pybind11 due to its special handling @@ -418,6 +423,15 @@ "isinstance", [isaFunction](MlirAttribute other) { return isaFunction(other); }, py::arg("other_attribute")); + if (getTypeIDFunction) { + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction(), + pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } } }; diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -80,6 +80,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; static constexpr const char *pyClassName = "AffineMapAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -259,6 +261,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; static constexpr const char *pyClassName = "ArrayAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; class PyArrayAttributeIterator { public: @@ -266,12 +270,11 @@ PyArrayAttributeIterator &dunderIter() { return *this; } - PyAttribute dunderNext() { + MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. 
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw py::stop_iteration(); - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); } static void bind(py::module &m) { @@ -286,8 +289,8 @@ int nextIndex = 0; }; - PyAttribute getItem(intptr_t i) { - return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); + MlirAttribute getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); } static void bindDerived(ClassTy &c) { @@ -339,6 +342,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; static constexpr const char *pyClassName = "FloatAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -406,6 +411,10 @@ return mlirIntegerAttrGetValueUInt(self); }, "Returns the value of the integer attribute"); + c.def_property_readonly_static("static_typeid", + [](py::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } }; @@ -438,6 +447,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFlatSymbolRefAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -464,6 +475,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; static constexpr const char *pyClassName = "OpaqueAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -501,6 +514,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; static constexpr const char *pyClassName = "StringAttr"; using 
PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -767,7 +782,7 @@ "unsupported data type for conversion to Python buffer"); } - static void bindDerived(ClassTy &c) { + static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, py::arg("array"), py::arg("signless") = true, @@ -782,13 +797,12 @@ return mlirDenseElementsAttrIsSplat(self); }) .def("get_splat_value", - [](PyDenseElementsAttribute &self) -> PyAttribute { + [](PyDenseElementsAttribute &self) -> MlirAttribute { if (!mlirDenseElementsAttrIsSplat(self)) { throw py::value_error( "get_splat_value called on a non-splat attribute"); } - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)); + return mlirDenseElementsAttrGetSplatValue(self); }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } @@ -921,6 +935,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; static constexpr const char *pyClassName = "DictAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } @@ -958,7 +974,7 @@ if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(self.getContext(), attr); + return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { @@ -1013,6 +1029,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; static constexpr const char *pyClassName = "TypeAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; static void bindDerived(ClassTy &c) { 
c.def_static( @@ -1035,6 +1053,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; static constexpr const char *pyClassName = "UnitAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1054,6 +1074,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; static constexpr const char *pyClassName = "StridedLayoutAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1099,6 +1121,50 @@ } }; +py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + if 
(PyDenseIntElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + std::string msg = + std::string( + "Can't cast unknown element type DenseIntOrFPElementsAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { + if (PyBoolAttribute::isaFunction(pyAttribute)) + return py::cast(PyBoolAttribute(pyAttribute)); + if (PyIntegerAttribute::isaFunction(pyAttribute)) + return py::cast(PyIntegerAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type IntegerAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1118,6 +1184,9 @@ PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); PyDenseF64ArrayAttribute::bind(m); PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseArrayAttrGetTypeID(), + pybind11::cpp_function(denseArrayAttributeCaster)); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); @@ -1125,6 +1194,10 @@ PyDenseElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseIntOrFPElementsAttrGetTypeID(), + pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDictAttribute::bind(m); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1132,6 +1205,9 @@ PyIntegerAttribute::bind(m); PyStringAttribute::bind(m); PyTypeAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirIntegerAttrGetTypeID(), + pybind11::cpp_function(integerOrBoolAttributeCaster)); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ 
b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1907,19 +1907,17 @@ erase(py::cast(operation)); } -PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { operation->checkValid(); symbol.getOperation().checkValid(); MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute( - symbol.getOperation().getContext(), - mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); + return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } -PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { // Op must already be a symbol. PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -1928,7 +1926,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); + return existingNameAttr; } void PySymbolTable::setSymbolName(PyOperationBase &symbol, @@ -1946,7 +1944,7 @@ mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); } -PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { PyOperation &operation = symbol.getOperation(); operation.checkValid(); MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); @@ -1954,7 +1952,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) throw py::value_error("Expected operation to have a symbol visibility."); - return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); + return existingVisAttr; } 
void PySymbolTable::setVisibility(PyOperationBase &symbol, @@ -2286,13 +2284,13 @@ PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - PyAttribute dunderGetItemNamed(const std::string &name) { + MlirAttribute dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(operation->getContext(), attr); + return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { @@ -2640,10 +2638,7 @@ "Context that owns the Location") .def_property_readonly( "attr", - [](PyLocation &self) { - return PyAttribute(self.getContext(), - mlirLocationGetAttribute(self)); - }, + [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") .def( "emit_error", @@ -3135,11 +3130,11 @@ "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { PyMlirContext::ErrorCapture errors(context->getRef()); - MlirAttribute type = mlirAttributeParseGet( + MlirAttribute attr = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - if (mlirAttributeIsNull(type)) + if (mlirAttributeIsNull(attr)) throw MLIRError("Unable to parse attribute", errors.take()); - return PyAttribute(context->getRef(), type); + return attr; }, py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " @@ -3175,18 +3170,41 @@ return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and are - // printed. This may need to be re-evaluated if debug dumps end up - // being excessive. 
- PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. + PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = + pybind11::repr(pybind11::cast(self)).cast(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirAttributeGetDialect(self)); + if (!typeCaster) + return py::cast(self); + return typeCaster.value()(self); }); //---------------------------------------------------------------------------- @@ -3216,13 +3234,7 @@ "The name of the NamedAttribute binding") .def_property_readonly( "attr", - [](PyNamedAttribute &self) { - // TODO: When named attribute is removed/refactored, also remove - // this constructor (it does an inefficient table lookup). 
- auto contextRef = PyMlirContext::forContext( - mlirAttributeGetContext(self.namedAttr.attribute)); - return PyAttribute(std::move(contextRef), self.namedAttr.attribute); - }, + [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -986,6 +986,8 @@ // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) @@ -1017,6 +1019,34 @@ pybind11::arg("other")); cls.def_property_readonly( "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + }); + cls.def_property_readonly("typeid", [](PyAttribute &self) { + return py::cast(self).attr("typeid").cast(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + })); + } + DerivedTy::bindDerived(cls); } @@ -1144,14 +1174,14 @@ /// Inserts the given 
operation into the symbol table. The operation must have /// the symbol trait. - PyAttribute insert(PyOperationBase &symbol); + MlirAttribute insert(PyOperationBase &symbol); /// Gets and sets the name of a symbol op. - static PyAttribute getSymbolName(PyOperationBase &symbol); + static MlirAttribute getSymbolName(PyOperationBase &symbol); static void setSymbolName(PyOperationBase &symbol, const std::string &name); /// Gets and sets the visibility of a symbol op. - static PyAttribute getVisibility(PyOperationBase &symbol); + static MlirAttribute getVisibility(PyOperationBase &symbol); static void setVisibility(PyOperationBase &symbol, const std::string &visibility); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -485,11 +485,12 @@ py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional { + "encoding", + [](PyRankedTensorType &self) -> std::optional { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; - return PyAttribute(self.getContext(), encoding); + return encoding; }); } }; @@ -550,9 +551,8 @@ py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly( "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); }, "The layout of the MemRef type.") .def_property_readonly( @@ -564,9 +564,11 @@ "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyMemRefType 
&self) -> MlirAttribute { + MlirAttribute memorySpace = mlirMemRefTypeGetMemorySpace(self); + if (mlirAttributeIsNull(memorySpace)) + throw std::runtime_error("Null memory space attribute."); + return memorySpace; }, "Returns the memory space of the given MemRef type."); } @@ -602,9 +604,11 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyUnrankedMemRefType &self) -> MlirAttribute { + MlirAttribute memorySpace = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(memorySpace)) + throw std::runtime_error("Null memory space attribute."); + return memorySpace; }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -44,6 +44,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirAffineMapAttrGetTypeID(void) { + return wrap(AffineMapAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -68,6 +72,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } +MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Dictionary attribute. 
//===----------------------------------------------------------------------===// @@ -102,6 +108,10 @@ return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } +MlirTypeID mlirDictionaryAttrGetTypeID(void) { + return wrap(DictionaryAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -124,6 +134,8 @@ return llvm::cast(unwrap(attr)).getValueAsDouble(); } +MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -148,6 +160,10 @@ return llvm::cast(unwrap(attr)).getUInt(); } +MlirTypeID mlirIntegerAttrGetTypeID(void) { + return wrap(IntegerAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -172,6 +188,10 @@ return llvm::isa(unwrap(attr)); } +MlirTypeID mlirIntegerSetAttrGetTypeID(void) { + return wrap(IntegerSetAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -197,6 +217,10 @@ return wrap(llvm::cast(unwrap(attr)).getAttrData()); } +MlirTypeID mlirOpaqueAttrGetTypeID(void) { + return wrap(OpaqueAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // String attribute. 
//===----------------------------------------------------------------------===// @@ -217,6 +241,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirStringAttrGetTypeID(void) { + return wrap(StringAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -257,6 +285,10 @@ llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } +MlirTypeID mlirSymbolRefAttrGetTypeID(void) { + return wrap(SymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -273,6 +305,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { + return wrap(FlatSymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -289,6 +325,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -301,6 +339,8 @@ return wrap(UnitAttr::get(unwrap(ctx))); } +MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -329,8 +369,13 @@ // Dense array attribute. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirDenseArrayAttrGetTypeID() { + return wrap(DenseArrayAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // IsA support. +//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); @@ -356,6 +401,7 @@ //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values) { @@ -395,6 +441,7 @@ //===----------------------------------------------------------------------===// // Accessors. +//===----------------------------------------------------------------------===// intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).size(); @@ -402,6 +449,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; @@ -431,19 +479,27 @@ //===----------------------------------------------------------------------===// // IsA support. 
+//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } +MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { + return wrap(DenseIntOrFPElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, @@ -620,6 +676,7 @@ //===----------------------------------------------------------------------===// // Splat accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast(unwrap(attr)).isSplat(); @@ -663,6 +720,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; @@ -705,6 +763,7 @@ //===----------------------------------------------------------------------===// // Raw data accessors. +//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( @@ -876,6 +935,10 @@ return wrap(llvm::cast(unwrap(attr)).getValues()); } +MlirTypeID mlirSparseElementsAttrGetTypeID(void) { + return wrap(SparseElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Strided layout attribute. 
//===----------------------------------------------------------------------===// @@ -903,3 +966,7 @@ int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getStrides()[pos]; } + +MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { + return wrap(StridedLayoutAttr::getTypeID()); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -870,6 +870,10 @@ return wrap(unwrap(attr).getTypeID()); } +MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { + return wrap(&unwrap(attr).getDialect()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -266,6 +266,10 @@ # CHECK: #python_test.test_attr print(a) + a = Attribute.parse("#python_test.test_attr") + # CHECK: #python_test.test_attr + print(a) + # The following cast must not assert. 
b = test.TestAttr(a) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -23,7 +23,7 @@ gc.collect() # CHECK: "hello" print(str(t)) - # CHECK: Attribute("hello") + # CHECK: StringAttr("hello") print(repr(t)) @@ -95,9 +95,12 @@ a1 = Attribute.parse("42") a2 = Attribute.parse("[42]") assert IntegerAttr.isinstance(a1) + assert IntegerAttr.static_typeid is not None + assert isinstance(a1, IntegerAttr) assert not IntegerAttr.isinstance(a2) assert not ArrayAttr.isinstance(a1) assert ArrayAttr.isinstance(a2) + assert isinstance(a2, ArrayAttr) # CHECK-LABEL: TEST: testAttrEqDoesNotRaise @@ -134,7 +137,7 @@ a1 = Attribute.parse('"attr1"') astr = StringAttr(a1) aself = StringAttr(astr) - # CHECK: Attribute("attr1") + # CHECK: StringAttr("attr1") print(repr(astr)) try: tillegal = StringAttr(Attribute.parse("1.0")) @@ -166,7 +169,8 @@ @run def testFloatAttr(): with Context(), Location.unknown(): - fattr = FloatAttr(Attribute.parse("42.0 : f32")) + fattr = Attribute.parse("42.0 : f32") + assert isinstance(fattr, FloatAttr) # CHECK: fattr value: 42.0 print("fattr value:", fattr.value) @@ -191,18 +195,22 @@ @run def testIntegerAttr(): with Context() as ctx: - i_attr = IntegerAttr(Attribute.parse("42")) + i_attr = Attribute.parse("42") + assert isinstance(i_attr, IntegerAttr) # CHECK: i_attr value: 42 print("i_attr value:", i_attr.value) # CHECK: i_attr type: i64 print("i_attr type:", i_attr.type) - si_attr = IntegerAttr(Attribute.parse("-1 : si8")) + si_attr = Attribute.parse("-1 : si8") + assert isinstance(si_attr, IntegerAttr) # CHECK: si_attr value: -1 print("si_attr value:", si_attr.value) - ui_attr = IntegerAttr(Attribute.parse("255 : ui8")) + ui_attr = Attribute.parse("255 : ui8") + assert isinstance(ui_attr, IntegerAttr) # CHECK: ui_attr value: 255 print("ui_attr value:", ui_attr.value) - idx_attr = IntegerAttr(Attribute.parse("-1 : index")) + idx_attr = 
Attribute.parse("-1 : index") + assert isinstance(idx_attr, IntegerAttr) # CHECK: idx_attr value: -1 print("idx_attr value:", idx_attr.value) @@ -215,7 +223,8 @@ @run def testBoolAttr(): with Context() as ctx: - battr = BoolAttr(Attribute.parse("true")) + battr = Attribute.parse("true") + assert isinstance(battr, BoolAttr) # CHECK: iattr value: True print("iattr value:", battr.value) @@ -228,7 +237,8 @@ @run def testFlatSymbolRefAttr(): with Context() as ctx: - sattr = FlatSymbolRefAttr(Attribute.parse("@symbol")) + sattr = Attribute.parse("@symbol") + assert isinstance(sattr, FlatSymbolRefAttr) # CHECK: symattr value: symbol print("symattr value:", sattr.value) @@ -242,7 +252,8 @@ def testOpaqueAttr(): with Context() as ctx: ctx.allow_unregistered_dialects = True - oattr = OpaqueAttr(Attribute.parse("#pytest_dummy.dummyattr<>")) + oattr = Attribute.parse("#pytest_dummy.dummyattr<>") + assert isinstance(oattr, OpaqueAttr) # CHECK: oattr value: pytest_dummy print("oattr value:", oattr.dialect_namespace) # CHECK: oattr value: b'dummyattr<>' @@ -260,7 +271,8 @@ @run def testStringAttr(): with Context() as ctx: - sattr = StringAttr(Attribute.parse('"stringattr"')) + sattr = Attribute.parse('"stringattr"') + assert isinstance(sattr, StringAttr) # CHECK: sattr value: stringattr print("sattr value:", sattr.value) # CHECK: sattr value: b'stringattr' @@ -291,11 +303,11 @@ @run def testDenseIntAttr(): with Context(): - raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") + a = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]> - print("attr:", raw) + print("attr:", a) - a = DenseIntElementsAttr(raw) + assert isinstance(a, DenseIntElementsAttr) assert len(a) == 6 # CHECK: 0 1 2 3 4 5 @@ -306,11 +318,10 @@ # CHECK: i32 print(ShapedType(a.type).element_type) - raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>") + a = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>") # 
CHECK: attr: dense<[true, false, true, false]> - print("attr:", raw) - - a = DenseIntElementsAttr(raw) + print("attr:", a) + assert isinstance(a, DenseIntElementsAttr) assert len(a) == 4 # CHECK: 1 0 1 0 @@ -324,32 +335,32 @@ @run def testDenseArrayGetItem(): - def print_item(AttrClass, attr_asm): - attr = AttrClass(Attribute.parse(attr_asm)) - print(f"{len(attr)}: {attr[0]}, {attr[1]}") + def print_item(attr_asm): + attr = Attribute.parse(attr_asm) + print(repr(attr)) with Context(): - # CHECK: 2: 0, 1 - print_item(DenseBoolArrayAttr, "array") - # CHECK: 2: 2, 3 - print_item(DenseI8ArrayAttr, "array") - # CHECK: 2: 4, 5 - print_item(DenseI16ArrayAttr, "array") - # CHECK: 2: 6, 7 - print_item(DenseI32ArrayAttr, "array") - # CHECK: 2: 8, 9 - print_item(DenseI64ArrayAttr, "array") - # CHECK: 2: 1.{{0+}}, 2.{{0+}} - print_item(DenseF32ArrayAttr, "array") - # CHECK: 2: 3.{{0+}}, 4.{{0+}} - print_item(DenseF64ArrayAttr, "array") + # CHECK: DenseBoolArrayAttr(array) + print_item("array") + # CHECK: DenseI8ArrayAttr(array) + print_item("array") + # CHECK: DenseI16ArrayAttr(array) + print_item("array") + # CHECK: DenseI32ArrayAttr(array) + print_item("array") + # CHECK: DenseI64ArrayAttr(array) + print_item("array") + # CHECK: DenseF32ArrayAttr(array) + print_item("array") + # CHECK: DenseF64ArrayAttr(array) + print_item("array") # CHECK-LABEL: TEST: testDenseIntAttrGetItem @run def testDenseIntAttrGetItem(): def print_item(attr_asm): - attr = DenseIntElementsAttr(Attribute.parse(attr_asm)) + attr = Attribute.parse(attr_asm) dtype = ShapedType(attr.type).element_type try: item = attr[0] @@ -393,14 +404,13 @@ @run def testDenseFPAttr(): with Context(): - raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") - # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> - - print("attr:", raw) - - a = DenseFPElementsAttr(raw) + a = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") + assert isinstance(a, DenseFPElementsAttr) 
assert len(a) == 4 + # CHECK: DenseFPElementsAttr(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<4xf32>) + print(repr(a)) + # CHECK: 0.0 1.0 2.0 3.0 for value in a: print(value, end=" ") @@ -461,38 +471,39 @@ @run def testTypeAttr(): with Context(): - raw = Attribute.parse("vector<4xf32>") + a = Attribute.parse("vector<4xf32>") + assert isinstance(a, TypeAttr) # CHECK: attr: vector<4xf32> - print("attr:", raw) - type_attr = TypeAttr(raw) + print("attr:", a) # CHECK: f32 - print(ShapedType(type_attr.value).element_type) + print(ShapedType(a.value).element_type) # CHECK-LABEL: TEST: testArrayAttr @run def testArrayAttr(): with Context(): - raw = Attribute.parse("[42, true, vector<4xf32>]") + a = Attribute.parse("[42, true, vector<4xf32>]") + assert isinstance(a, ArrayAttr) # CHECK: attr: [42, true, vector<4xf32>] - print("raw attr:", raw) + print("raw attr:", a) # CHECK: - 42 # CHECK: - true # CHECK: - vector<4xf32> - for attr in ArrayAttr(raw): + for attr in a: print("- ", attr) with Context(): intAttr = Attribute.parse("42") vecAttr = Attribute.parse("vector<4xf32>") boolAttr = BoolAttr.get(True) - raw = ArrayAttr.get([vecAttr, boolAttr, intAttr]) - # CHECK: attr: [vector<4xf32>, true, 42] - print("raw attr:", raw) + arr = ArrayAttr.get([vecAttr, boolAttr, intAttr]) + # CHECK: ArrayAttr([vector<4xf32>, true, 42]) + print(repr(arr)) # CHECK: - vector<4xf32> # CHECK: - true # CHECK: - 42 - arr = ArrayAttr(raw) + assert isinstance(arr, ArrayAttr) for attr in arr: print("- ", attr) # CHECK: attr[0]: vector<4xf32> @@ -583,10 +594,10 @@ # CHECK: F64Type(f64) print_container_item("dense<1.0> : tensor") - raw = Attribute.parse("vector<4xf32>") - # CHECK: attr: vector<4xf32> - print("attr:", raw) - type_attr = TypeAttr(raw) + type_attr = Attribute.parse("vector<4xf32>") + assert isinstance(type_attr, TypeAttr) + # CHECK: TypeAttr(vector<4xf32>) + print(repr(type_attr)) # CHECK: VectorType(vector<4xf32>) print(repr(type_attr.value)) diff --git 
a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -401,8 +401,11 @@ print("memref layout:", memref_layout.layout) # CHECK: memref affine map: (d0, d1) -> (d1, d0) print("memref affine map:", memref_layout.affine_map) - # CHECK: memory space: <> - print("memory space:", memref_layout.memory_space) + try: + memref_layout.memory_space + except RuntimeError as e: + # CHECK: Null memory space attribute. + print(e) none = NoneType.get() try: diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py --- a/mlir/test/python/ir/location.py +++ b/mlir/test/python/ir/location.py @@ -31,7 +31,9 @@ def testLocationAttr(): with Context() as ctxt: loc = Location.unknown() + # CHECK: Attribute(loc(unknown)) attr = loc.attr + print(repr(attr)) clone = Location.from_attr(attr) gc.collect() # CHECK: loc: loc(unknown) diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -73,6 +73,8 @@ foo2 = m2.body.operations[0] m1.body.append(foo2) updated_name = symbol_table.insert(foo2) + # CHECK: StringAttr("foo_{{.*}}") + print(repr(updated_name)) assert foo2.name.value != "foo" assert foo2.name == updated_name @@ -116,6 +118,8 @@ # CHECK: Bar symbol: "bam" print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}") print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}") + # CHECK: StringAttr("foo") + print(repr(SymbolTable.get_symbol_name(foo))) # CHECK-LABEL: testSymbolTableVisibility @@ -130,8 +134,10 @@ """ ) foo = m.operation.regions[0].blocks[0].operations[0] - # CHECK: Existing visibility: "private" - print(f"Existing visibility: {SymbolTable.get_visibility(foo)}") + vis = SymbolTable.get_visibility(foo) + # CHECK: StringAttr("private") + print(repr(vis)) + SymbolTable.set_visibility(foo, "public") # CHECK: func public @foo print(m) diff 
--git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h --- a/mlir/test/python/lib/PythonTestCAPI.h +++ b/mlir/test/python/lib/PythonTestCAPI.h @@ -23,6 +23,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context); +MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void); + MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context); diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -23,6 +23,10 @@ return wrap(python_test::TestAttrAttr::get(unwrap(context))); } +MlirTypeID mlirPythonTestTestAttributeGetTypeID() { + return wrap(python_test::TestAttrAttr::getTypeID()); +} + bool mlirTypeIsAPythonTestTestType(MlirType type) { return llvm::isa(unwrap(type)); } diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -35,7 +35,8 @@ py::arg("context"), py::arg("load") = true); mlir_attribute_subclass(m, "TestAttr", - mlirAttributeIsAPythonTestTestAttribute) + mlirAttributeIsAPythonTestTestAttribute, + mlirPythonTestTestAttributeGetTypeID) .def_classmethod( "get", [](py::object cls, MlirContext ctx) {