diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -155,8 +155,11 @@ }]; } -def Builtin_DenseArray : Builtin_Attr< - "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> { +def Builtin_DenseArray : Builtin_Attr<"DenseArray", [ + DeclareAttrInterfaceMethods, + TypedAttrInterface + ]> { let summary = "A dense array of integer or floating point elements."; let description = [{ A dense array attribute is an attribute that represents a dense array of @@ -206,24 +209,6 @@ }]>, ]; - let extraClassDeclaration = [{ - /// Allow implicit conversion to ElementsAttr. - operator ElementsAttr() const { - return *this ? cast() : nullptr; - } - - /// ElementsAttr implementation. - using ContiguousIterableTypesT = - std::tuple; - const bool *value_begin_impl(OverloadToken) const; - const int8_t *value_begin_impl(OverloadToken) const; - const int16_t *value_begin_impl(OverloadToken) const; - const int32_t *value_begin_impl(OverloadToken) const; - const int64_t *value_begin_impl(OverloadToken) const; - const float *value_begin_impl(OverloadToken) const; - const double *value_begin_impl(OverloadToken) const; - }]; - let genVerifyDecl = 1; } diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -855,9 +855,11 @@ } // namespace void DenseArrayElementParser::append(const APInt &data) { - llvm::append_range( - rawData, ArrayRef(reinterpret_cast(data.getRawData()), - data.getBitWidth() / 8)); + unsigned byteSize = data.getBitWidth() / 8; + size_t offset = rawData.size(); + rawData.insert(rawData.end(), byteSize, 0); + llvm::StoreIntToMemory( + data, reinterpret_cast(rawData.data() + offset), byteSize); ++size; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1892,11 +1892,11 @@ /// Print the integer element of a DenseElementsAttr. static void printDenseIntElement(const APInt &value, raw_ostream &os, - bool isSigned) { - if (value.getBitWidth() == 1) + Type type) { + if (type.isInteger(1)) os << (value.getBoolValue() ? "true" : "false"); else - value.print(os, isSigned); + value.print(os, !type.isUnsignedInteger()); } static void @@ -1990,14 +1990,13 @@ // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 // and hence was replaced. if (complexElementType.isa()) { - bool isSigned = !complexElementType.isUnsignedInteger(); auto valueIt = attr.value_begin>(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { auto complexValue = *(valueIt + index); os << "("; - printDenseIntElement(complexValue.real(), os, isSigned); + printDenseIntElement(complexValue.real(), os, complexElementType); os << ","; - printDenseIntElement(complexValue.imag(), os, isSigned); + printDenseIntElement(complexValue.imag(), os, complexElementType); os << ")"; }); } else { @@ -2012,10 +2011,9 @@ }); } } else if (elementType.isIntOrIndex()) { - bool isSigned = !elementType.isUnsignedInteger(); auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - printDenseIntElement(*(valueIt + index), os, isSigned); + printDenseIntElement(*(valueIt + index), os, elementType); }); } else { assert(elementType.isa() && "unexpected element type"); @@ -2036,19 +2034,17 @@ void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { Type type = attr.getElementType(); unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); + unsigned byteSize = bitwidth / 8; ArrayRef data = attr.getRawData(); auto printElementAt = [&](unsigned i) { - // FIXME: The data needs to be padded, requiring an extra copy. - SmallVector padded(data.slice(bitwidth / 8 * i, bitwidth / 8)); - padded.append( - APInt::APINT_WORD_SIZE - padded.size() % APInt::APINT_WORD_SIZE, 0); - APInt value(type.getIntOrFloatBitWidth(), - {reinterpret_cast(padded.data()), - padded.size() / APInt::APINT_WORD_SIZE}); + APInt value(bitwidth, 0); + llvm::LoadIntFromMemory( + value, reinterpret_cast(data.begin() + byteSize * i), + byteSize); // Print the data as-is or as a float. if (type.isIntOrIndex()) { - printDenseIntElement(value, getStream(), !type.isUnsignedInteger()); + printDenseIntElement(value, getStream(), type); } else { APFloat fltVal(type.cast().getFloatSemantics(), value); printFloatValue(fltVal, getStream()); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -709,28 +709,46 @@ return success(); } -const bool *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const float *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); -} -const double *DenseArrayAttr::value_begin_impl(OverloadToken) const { - return cast().asArrayRef().begin(); +//===----------------------------------------------------------------------===// +// ElementsAttr Implementation + +/// A dense array is never a splat. +bool DenseArrayAttr::isSplat() const { return false; } + +/// Try to dispatch to the given dense array subclass to contiguously iterate +/// its elements. +template +static FailureOr +tryContainerBegin(DenseArrayAttr attr) { + if (auto container = attr.dyn_cast()) + return detail::ElementsAttrIndexer::contiguous( + /*isSplat=*/false, container.asArrayRef().begin()); + return failure(); +} + +/// Re-implement `getValuesImpl` to dynamically dispatch to subclasses. +FailureOr +DenseArrayAttr::getValuesImpl(TypeID elementID) const { + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + if (elementID == TypeID::get()) + return tryContainerBegin(*this); + return failure(); } +//===----------------------------------------------------------------------===// +// DenseArrayAttrUtil + namespace { /// Instantiations of this class provide utilities for interacting with native /// data types in the context of DenseArrayAttr. @@ -814,6 +832,9 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// DenseArrayAttrImpl + template void DenseArrayAttrImpl::print(AsmPrinter &printer) const { print(printer.getStream());