diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -45,6 +45,9 @@ /// Returns the affine map wrapped in the given affine map attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an AffineMap attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an Array attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -89,6 +95,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name); +/// Returns the typeID of a Dictionary attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -115,6 +124,9 @@ /// the value as double. MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr); +/// Returns the typeID of a Float attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -142,6 +154,9 @@ /// is of unsigned type and fits into an unsigned 64-bit integer.
MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); +/// Returns the typeID of an Integer attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -162,6 +177,9 @@ /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Returns the typeID of an IntegerSet attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -185,6 +203,9 @@ /// the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr); +/// Returns the typeID of an Opaque attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -206,6 +227,9 @@ /// long as the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a String attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -239,6 +263,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a SymbolRef attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a FlatSymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -270,6 +300,9 @@ /// Returns the type stored in the given type attribute. MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a Type attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -280,6 +313,9 @@ /// Creates a unit attribute in the given context. MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx); +/// Returns the typeID of a Unit attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -306,6 +342,9 @@ // Dense array attribute. //===----------------------------------------------------------------------===// +/// Returns the typeID of a DenseArray attribute (any element type). +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void); + /// Checks whether the given attribute is a dense array attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); @@ -361,6 +400,8 @@ // Dense elements attribute.
//===----------------------------------------------------------------------===// +// DenseElementsAttr has no TypeID of its own; see the DenseIntOrFP getter. + // TODO: decide on the interface and add support for complex elements. // TODO: add support for APFloat and APInt to LLVM IR C API, then expose the // relevant functions here. @@ -370,6 +411,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); +/// Returns the typeID of a DenseIntOrFPElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void); + /// Creates a dense elements attribute with the given Shaped type and elements /// in the same context as the type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( @@ -612,6 +656,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +/// Returns the typeID of a SparseElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Strided layout attribute. //===----------------------------------------------------------------------===// @@ -635,6 +682,9 @@ MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a StridedLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -860,6 +860,9 @@ /// Gets the type id of the attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); +/// Gets the dialect of the attribute. +MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute); + /// Checks whether an attribute is null.
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -97,6 +97,7 @@ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); } }; diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -80,6 +80,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; static constexpr const char *pyClassName = "AffineMapAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -259,6 +261,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; static constexpr const char *pyClassName = "ArrayAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; class PyArrayAttributeIterator { public: @@ -339,6 +343,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; static constexpr const char *pyClassName = "FloatAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -385,6 +391,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; static constexpr const char *pyClassName = "IntegerAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerAttrGetTypeID; static void
bindDerived(ClassTy &c) { c.def_static( @@ -438,6 +446,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFlatSymbolRefAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -464,6 +474,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; static constexpr const char *pyClassName = "OpaqueAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -501,6 +513,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; static constexpr const char *pyClassName = "StringAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -545,6 +559,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; static constexpr const char *pyClassName = "DenseElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; + // DenseElementsAttr has no TypeID of its own, so getTypeIdFunction keeps + // its nullptr default; see denseIntOrFPElementsAttributeCaster. static PyDenseElementsAttribute getFromBuffer(py::buffer array, bool signless, @@ -921,6 +937,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; static constexpr const char *pyClassName = "DictAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } @@ -1013,6 +1031,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; static constexpr const char *pyClassName = "TypeAttr"; using PyConcreteAttribute::PyConcreteAttribute; +
static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1035,6 +1055,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; static constexpr const char *pyClassName = "UnitAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1054,6 +1076,8 @@ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; static constexpr const char *pyClassName = "StridedLayoutAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1099,6 +1123,40 @@ } }; +// Downcasts a DenseArrayAttr to the Python class for its element type. +py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + py::repr(py::cast(pyAttribute)).operator std::string() + ")"; + throw py::cast_error(msg); +} + +// Downcasts to DenseFPElementsAttr or DenseIntElementsAttr by element type. +py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { + if
(PyDenseFPElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseIntOrFPElementsAttr (") + + py::repr(py::cast(pyAttribute)).operator std::string() + ")"; + throw py::cast_error(msg); +} + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1118,6 +1176,9 @@ PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); PyDenseF64ArrayAttribute::bind(m); PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseArrayAttrGetTypeID(), + pybind11::cpp_function(denseArrayAttributeCaster)); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); @@ -1125,6 +1186,10 @@ PyDenseElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseIntOrFPElementsAttrGetTypeID(), + pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDictAttribute::bind(m); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2640,10 +2640,7 @@ "Context that owns the Location") .def_property_readonly( "attr", - [](PyLocation &self) { - return PyAttribute(self.getContext(), - mlirLocationGetAttribute(self)); - }, + [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") .def( "emit_error", @@ -3139,7 +3136,7 @@ context->get(), toMlirStringRef(attrSpec)); if (mlirAttributeIsNull(type)) throw MLIRError("Unable to parse attribute", errors.take()); - return PyAttribute(context->getRef(), type); + return type; }, py::arg("asm"), py::arg("context") =
py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " @@ -3175,18 +3172,43 @@ return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and are - // printed. This may need to be re-evaluated if debug dumps end up - // being excessive. - PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive.
+ PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = + pybind11::repr(pybind11::cast(self)).cast(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + // assert() is compiled out in release builds; never look up a null key. + if (mlirTypeIDIsNull(mlirTypeID)) return py::cast(self); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirAttributeGetDialect(self)); + if (!typeCaster) + return py::cast(self); + return typeCaster.value()(self); }); //---------------------------------------------------------------------------- @@ -3216,13 +3238,7 @@ "The name of the NamedAttribute binding") .def_property_readonly( "attr", - [](PyNamedAttribute &self) { - // TODO: When named attribute is removed/refactored, also remove - // this constructor (it does an inefficient table lookup).
- auto contextRef = PyMlirContext::forContext( - mlirAttributeGetContext(self.namedAttr.attribute)); - return PyAttribute(std::move(contextRef), self.namedAttr.attribute); - }, + [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -986,6 +986,8 @@ // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) @@ -1017,6 +1019,34 @@ pybind11::arg("other")); cls.def_property_readonly( "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + }); + cls.def_property_readonly("typeid", [](PyAttribute &self) { + return py::cast(self).attr("typeid").cast(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + })); + } + DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -44,6 +44,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirAffineMapAttrGetTypeID(void) { + return wrap(AffineMapAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -68,6 +72,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } +MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -102,6 +108,10 @@ return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } +MlirTypeID mlirDictionaryAttrGetTypeID(void) { + return wrap(DictionaryAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -124,6 +134,8 @@ return llvm::cast(unwrap(attr)).getValueAsDouble(); } +MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -148,6 +160,10 @@ return llvm::cast(unwrap(attr)).getUInt(); } +MlirTypeID mlirIntegerAttrGetTypeID(void) { + return wrap(IntegerAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Bool attribute.
//===----------------------------------------------------------------------===// @@ -172,6 +188,10 @@ return llvm::isa(unwrap(attr)); } +MlirTypeID mlirIntegerSetAttrGetTypeID(void) { + return wrap(IntegerSetAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -197,6 +217,10 @@ return wrap(llvm::cast(unwrap(attr)).getAttrData()); } +MlirTypeID mlirOpaqueAttrGetTypeID(void) { + return wrap(OpaqueAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -217,6 +241,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirStringAttrGetTypeID(void) { + return wrap(StringAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -257,6 +285,10 @@ llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } +MlirTypeID mlirSymbolRefAttrGetTypeID(void) { + return wrap(SymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -273,6 +305,10 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { + return wrap(FlatSymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Type attribute.
//===----------------------------------------------------------------------===// @@ -289,6 +325,8 @@ return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -301,6 +339,8 @@ return wrap(UnitAttr::get(unwrap(ctx))); } +MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -329,8 +369,13 @@ // Dense array attribute. //===----------------------------------------------------------------------===// +MlirTypeID mlirDenseArrayAttrGetTypeID(void) { + return wrap(DenseArrayAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // IsA support. +//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); @@ -356,6 +401,7 @@ //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values) { @@ -395,6 +441,7 @@ //===----------------------------------------------------------------------===// // Accessors. +//===----------------------------------------------------------------------===// intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).size(); @@ -402,6 +449,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors.
+//===----------------------------------------------------------------------===// bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; @@ -425,25 +473,33 @@ return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// // Dense elements attribute. //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // IsA support.
+//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } +MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { + return wrap(DenseIntOrFPElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, @@ -620,6 +676,7 @@ //===----------------------------------------------------------------------===// // Splat accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast(unwrap(attr)).isSplat(); @@ -663,6 +720,7 @@ //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; @@ -705,6 +763,7 @@ //===----------------------------------------------------------------------===// // Raw data accessors. +//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( @@ -876,6 +935,10 @@ return wrap(llvm::cast(unwrap(attr)).getValues()); } +MlirTypeID mlirSparseElementsAttrGetTypeID(void) { + return wrap(SparseElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Strided layout attribute.
//===----------------------------------------------------------------------===// @@ -903,3 +966,7 @@ int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getStrides()[pos]; } + +MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { + return wrap(StridedLayoutAttr::getTypeID()); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -870,6 +870,10 @@ return wrap(unwrap(attr).getTypeID()); } +MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { + return wrap(&unwrap(attr).getDialect()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -23,7 +23,7 @@ gc.collect() # CHECK: "hello" print(str(t)) - # CHECK: Attribute("hello") + # CHECK: StringAttr("hello") print(repr(t)) @@ -134,7 +134,7 @@ a1 = Attribute.parse('"attr1"') astr = StringAttr(a1) aself = StringAttr(astr) - # CHECK: Attribute("attr1") + # CHECK: StringAttr("attr1") print(repr(astr)) try: tillegal = StringAttr(Attribute.parse("1.0")) @@ -324,32 +324,32 @@ @run def testDenseArrayGetItem(): - def print_item(AttrClass, attr_asm): - attr = AttrClass(Attribute.parse(attr_asm)) + def print_item(attr_asm): + attr = Attribute.parse(attr_asm) print(f"{len(attr)}: {attr[0]}, {attr[1]}") with Context(): # CHECK: 2: 0, 1 - print_item(DenseBoolArrayAttr, "array") + print_item("array") # CHECK: 2: 2, 3 - print_item(DenseI8ArrayAttr, "array") + print_item("array") # CHECK: 2: 4, 5 - print_item(DenseI16ArrayAttr, "array") + print_item("array") # CHECK: 2: 6, 7 - print_item(DenseI32ArrayAttr, "array") + print_item("array") # CHECK: 2: 8, 9 - print_item(DenseI64ArrayAttr, "array") + print_item("array") # CHECK: 2: 1.{{0+}}, 2.{{0+}} - print_item(DenseF32ArrayAttr, "array") +
print_item("array") # CHECK: 2: 3.{{0+}}, 4.{{0+}} - print_item(DenseF64ArrayAttr, "array") + print_item("array") # CHECK-LABEL: TEST: testDenseIntAttrGetItem @run def testDenseIntAttrGetItem(): def print_item(attr_asm): - attr = DenseIntElementsAttr(Attribute.parse(attr_asm)) + attr = Attribute.parse(attr_asm) dtype = ShapedType(attr.type).element_type try: item = attr[0] @@ -592,3 +592,15 @@ print(repr(type_attr.value)) # CHECK: F32Type(f32) print(repr(type_attr.value.element_type)) + + +# CHECK-LABEL: TEST: testConcreteAttributesRoundTrip +@run +def testConcreteAttributesRoundTrip(): + with Context(), Location.unknown(): + # CHECK: FloatAttr(4.200000e+01 : f32) + print(repr(Attribute.parse("42.0 : f32"))) + # CHECK: IntegerAttr(42 : i32) + print(repr(Attribute.parse("42 : i32"))) + # CHECK: StringAttr("hello") + print(repr(Attribute.parse('"hello"')))