diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -64,8 +64,87 @@
 template <typename T> struct is_complex_t : public std::false_type {};
 template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
+
+class DenseArrayAttributeStorage;
+
+/// Scalar element types that a DenseArrayAttr can hold.
+enum class DenseArrayAttributeElementType { I8, I16, I32, I64, F32, F64 };
 } // namespace detail
 
+/// Base class for attributes that hold a flat, densely packed array of
+/// scalars of a single element type. The elements are stored as raw bytes;
+/// the typed DenseArrayAttr<T> subclasses expose them as an ArrayRef<T>.
+class DenseArrayAttrBase
+    : public Attribute::AttrBase<DenseArrayAttrBase, Attribute,
+                                 detail::DenseArrayAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// Allow implicit conversion to ElementsAttr. Yields a null attribute if
+  /// this attribute does not implement the ElementsAttr interface.
+  operator ElementsAttr() const {
+    return *this ? dyn_cast<ElementsAttr>() : nullptr;
+  }
+
+  /// Return the scalar type of the stored elements.
+  detail::DenseArrayAttributeElementType getElementType() const;
+
+  /// Return the stored elements as raw bytes.
+  ArrayRef<char> getRawData() const;
+};
+
+/// Typed view of a DenseArrayAttrBase whose elements have C++ type `T`.
+template <typename T>
+class DenseArrayAttr : public DenseArrayAttrBase {
+public:
+  using DenseArrayAttrBase::DenseArrayAttrBase;
+
+  /// Reinterpret the raw element bytes as an array of `T`. The storage
+  /// allocates the bytes with the alignment of `uint64_t`, which satisfies
+  /// every supported element type.
+  operator ArrayRef<T>() const {
+    ArrayRef<char> raw = getRawData();
+    assert(raw.size() % sizeof(T) == 0 &&
+           "raw data size is not a multiple of the element size");
+    // `raw.size()` counts bytes; the returned array counts elements.
+    return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+                       raw.size() / sizeof(T));
+  }
+};
+
+/// A DenseArrayAttr with `int32_t` elements.
+class DenseI32ArrayAttr : public DenseArrayAttr<int32_t> {
+public:
+  using DenseArrayAttr::DenseArrayAttr;
+  static bool classof(Attribute attr) {
+    return attr.isa<DenseArrayAttrBase>() &&
+           attr.cast<DenseArrayAttrBase>().getElementType() ==
+               detail::DenseArrayAttributeElementType::I32;
+  }
+};
+
+/// A DenseArrayAttr with `int64_t` elements.
+class DenseI64ArrayAttr : public DenseArrayAttr<int64_t> {
+public:
+  using DenseArrayAttr::DenseArrayAttr;
+  static bool classof(Attribute attr) {
+    return attr.isa<DenseArrayAttrBase>() &&
+           attr.cast<DenseArrayAttrBase>().getElementType() ==
+               detail::DenseArrayAttributeElementType::I64;
+  }
+};
+
+/// A DenseArrayAttr with `float` elements.
+class DenseF32ArrayAttr : public DenseArrayAttr<float> {
+public:
+  using DenseArrayAttr::DenseArrayAttr;
+  static bool classof(Attribute attr) {
+    return attr.isa<DenseArrayAttrBase>() &&
+           attr.cast<DenseArrayAttrBase>().getElementType() ==
+               detail::DenseArrayAttributeElementType::F32;
+  }
+};
+
 /// An attribute that represents a reference to a dense vector or tensor object.
 ///
 class DenseElementsAttr : public Attribute {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1517,6 +1517,15 @@
   let convertFromStorage = "$_self";
 }
 
+def DenseI32ArrayAttr :
+    ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseI32ArrayAttr>()">,
+                     "i32 dense array attribute"> {
+  let storageType = [{ ::mlir::DenseI32ArrayAttr }];
+  let returnType = [{ ::mlir::DenseI32ArrayAttr }];
+
+  let convertFromStorage = "$_self";
+}
+
 def IndexElementsAttr
     : IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
                                       .getType()
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -40,6 +40,53 @@
   return eltType.getIntOrFloatBitWidth();
 }
 
+/// Storage for DenseArrayAttrBase: the attribute type, the scalar element
+/// type, and the elements as raw bytes. The bytes are copied into the
+/// storage allocator with the alignment of `uint64_t` so that they can be
+/// reinterpreted as an array of any supported element type.
+struct DenseArrayAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<::mlir::ShapedType, DenseArrayAttributeElementType,
+                           ::llvm::ArrayRef<char>>;
+  DenseArrayAttributeStorage(::mlir::ShapedType type,
+                             DenseArrayAttributeElementType eltType,
+                             ArrayRef<char> elements)
+      : ::mlir::AttributeStorage(type), eltType(eltType), elements(elements) {}
+
+  bool operator==(const KeyTy &key) const {
+    return getType() == std::get<0>(key) && eltType == std::get<1>(key) &&
+           elements == std::get<2>(key);
+  }
+
+  /// Hash every key component, including the element bytes, so that arrays
+  /// that differ only in their contents do not collide.
+  static ::llvm::hash_code hashKey(const KeyTy &key) {
+    return ::llvm::hash_combine(std::get<0>(key), std::get<1>(key),
+                                std::get<2>(key));
+  }
+
+  static DenseArrayAttributeStorage *
+  construct(::mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
+    ::mlir::ShapedType type = std::get<0>(key);
+    DenseArrayAttributeElementType eltType = std::get<1>(key);
+    ::llvm::ArrayRef<char> elements = std::get<2>(key);
+    // `copyInto` would allocate with `alignof(char)`, but the bytes are later
+    // reinterpreted as wider scalars; align for the widest (64-bit) one.
+    if (!elements.empty()) {
+      char *data = static_cast<char *>(
+          allocator.allocate(elements.size(), alignof(uint64_t)));
+      std::copy(elements.begin(), elements.end(), data);
+      elements = ::llvm::ArrayRef<char>(data, elements.size());
+    }
+    return new (allocator.allocate<DenseArrayAttributeStorage>())
+        DenseArrayAttributeStorage(type, eltType, elements);
+  }
+
+  /// Scalar type of the stored elements.
+  DenseArrayAttributeElementType eltType;
+  /// Densely packed element bytes, owned by the storage allocator.
+  ::llvm::ArrayRef<char> elements;
+};
+
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
 public:
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -664,6 +664,19 @@
       readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+detail::DenseArrayAttributeElementType
+DenseArrayAttrBase::getElementType() const {
+  return getImpl()->eltType;
+}
+
+ArrayRef<char> DenseArrayAttrBase::getRawData() const {
+  return getImpl()->elements;
+}
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -263,7 +263,8 @@
 def IntElementsAttrOp : TEST_Op<"int_elements_attr"> {
   let arguments = (ins
     AnyI32ElementsAttr:$any_i32_attr,
-    I32ElementsAttr:$i32_attr
+    I32ElementsAttr:$i32_attr,
+    OptionalAttr<DenseI32ArrayAttr>:$i32_array_attr
   );
 }
 