diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -764,12 +764,26 @@ /// shape. static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Constructs a dense complex elements attribute from an array of APInt + /// values. Each APInt value is expected to have the same bitwidth as the + /// element type of 'type'. 'type' must be a vector or tensor with static + /// shape. + static DenseElementsAttr get(ShapedType type, + ArrayRef> values); + /// Constructs a dense float elements attribute from an array of APFloat /// values. Each APFloat value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static /// shape. static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Constructs a dense complex elements attribute from an array of APFloat + /// values. Each APFloat value is expected to have the same bitwidth as the + /// element type of 'type'. 'type' must be a vector or tensor with static + /// shape. + static DenseElementsAttr get(ShapedType type, + ArrayRef> values); + /// Construct a dense elements attribute for an initializer_list of values. /// Each value is expected to be the same bitwidth of the element type of /// 'type'. 'type' must be a vector or tensor with static shape. @@ -868,6 +882,26 @@ size_t bitWidth; }; + /// A utility iterator that allows walking over the internal raw complex APInt + /// values. + class ComplexIntElementIterator + : public detail::DenseElementIndexedIteratorImpl< + ComplexIntElementIterator, std::complex, std::complex, + std::complex> { + public: + /// Accesses the raw std::complex value at this iterator position. + std::complex operator*() const; + + private: + friend DenseElementsAttr; + + /// Constructs a new iterator. + ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex); + + /// The bitwidth of the element type. + size_t bitWidth; + }; + /// Iterator for walking over APFloat values. class FloatElementIterator final : public llvm::mapped_iterator(const std::complex &)>> { + friend DenseElementsAttr; + + /// Initializes the float element iterator to the specified iterator. + ComplexFloatElementIterator(const llvm::fltSemantics &smt, + ComplexIntElementIterator it); + + public: + using reference = std::complex; + }; + //===--------------------------------------------------------------------===// // Value Querying //===--------------------------------------------------------------------===// @@ -1004,6 +1053,15 @@ IntElementIterator int_value_begin() const; IntElementIterator int_value_end() const; + /// Return the held element values as a range of complex APInts. The element + /// type of this attribute must be a complex of integer type. + llvm::iterator_range getComplexIntValues() const; + template >::value>::type> + llvm::iterator_range getValues() const { + return getComplexIntValues(); + } + /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. llvm::iterator_range getFloatValues() const; @@ -1015,6 +1073,16 @@ FloatElementIterator float_value_begin() const; FloatElementIterator float_value_end() const; + /// Return the held element values as a range of complex APFloat. The element + /// type of this attribute must be a complex of float type. + llvm::iterator_range + getComplexFloatValues() const; + template >::value>::type> + llvm::iterator_range getValues() const { + return getComplexFloatValues(); + } + /// Return the raw storage data held by this attribute. Users should generally /// not use this directly, as the internal storage format is not always in the /// form the user might expect. @@ -1120,10 +1188,17 @@ protected: friend DenseElementsAttr; + /// Constructs a dense elements attribute from an array of raw APFloat values. + /// Each APFloat value is expected to have the same bitwidth as the element + /// type of 'type'. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + ArrayRef values, bool isSplat); + /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element type /// of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, ArrayRef values); + static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + ArrayRef values, bool isSplat); /// Get or create a new dense elements attribute instance with the given raw /// data buffer. 'type' must be a vector or tensor with static shape. @@ -1343,19 +1418,34 @@ getZeroValue() const { return getZeroAPInt(); } + template + typename std::enable_if, T>::value, T>::type + getZeroValue() const { + APInt intZero = getZeroAPInt(); + return {intZero, intZero}; + } /// Get a zero for an APFloat. template typename std::enable_if::value, T>::type getZeroValue() const { return getZeroAPFloat(); } + template + typename std::enable_if, T>::value, + T>::type + getZeroValue() const { + APFloat floatZero = getZeroAPFloat(); + return {floatZero, floatZero}; + } /// Get a zero for an C++ integer, float, StringRef, or complex type. template typename std::enable_if< std::numeric_limits::is_integer || llvm::is_one_of::value || - detail::is_complex_t::value, + (detail::is_complex_t::value && + !llvm::is_one_of, + std::complex>::value), T>::type getZeroValue() const { return T(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1456,23 +1456,15 @@ } } -/// Print the integer element of the given DenseElementsAttr at 'index'. -static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os, - unsigned index, bool isSigned) { - APInt value = *std::next(attr.int_value_begin(), index); +/// Print the integer element of a DenseElementsAttr. +static void printDenseIntElement(const APInt &value, raw_ostream &os, + bool isSigned) { if (value.getBitWidth() == 1) os << (value.getBoolValue() ? "true" : "false"); else value.print(os, isSigned); } -/// Print the float element of the given DenseElementsAttr at 'index'. -static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os, - unsigned index) { - APFloat value = *std::next(attr.float_value_begin(), index); - printFloatValue(value, os); -} - static void printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, function_ref printEltFn) { @@ -1543,26 +1535,45 @@ auto elementType = type.getElementType(); // Check to see if we should format this attribute as a hex string. - // TODO: Add support for formatting complex elements nicely. auto numElements = type.getNumElements(); - if (type.getElementType().isa() || - (!attr.isSplat() && allowHex && - shouldPrintElementsAttrWithHex(numElements))) { + if (!attr.isSplat() && allowHex && + shouldPrintElementsAttrWithHex(numElements)) { ArrayRef rawData = attr.getRawData(); os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\""; return; } - if (elementType.isIntOrIndex()) { + if (ComplexType complexTy = elementType.dyn_cast()) { + auto printComplexValue = [&](auto complexValues, auto printFn, + raw_ostream &os, auto &&... params) { + printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { + auto complexValue = *(complexValues.begin() + index); + os << "("; + printFn(complexValue.real(), os, params...); + os << ","; + printFn(complexValue.imag(), os, params...); + os << ")"; + }); + }; + + Type complexElementType = complexTy.getElementType(); + if (complexElementType.isa()) + printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os, + /*isSigned=*/!complexElementType.isUnsignedInteger()); + else + printComplexValue(attr.getComplexFloatValues(), printFloatValue, os); + } else if (elementType.isIntOrIndex()) { bool isSigned = !elementType.isUnsignedInteger(); + auto intValues = attr.getIntValues(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - printDenseIntElement(attr, os, index, isSigned); + printDenseIntElement(*(intValues.begin() + index), os, isSigned); }); } else { assert(elementType.isa() && "unexpected element type"); + auto floatValues = attr.getFloatValues(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { - printDenseFloatElement(attr, os, index); + printFloatValue(*(floatValues.begin() + index), os); }); } } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -374,8 +374,9 @@ /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { - if (ComplexType complex = eltType.dyn_cast()) - return getDenseElementBitWidth(complex.getElementType()) * 2; + // Align the width for complex to 8 to make storage and interpretation easier. + if (ComplexType comp = eltType.dyn_cast()) + return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored // with double semantics. if (eltType.isBF16()) diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -537,6 +537,9 @@ static size_t getDenseElementStorageWidth(size_t origWidth) { return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); } +static size_t getDenseElementStorageWidth(Type elementType) { + return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); +} /// Set a bit to a specific value. static void setBit(char *rawData, size_t bitPos, bool value) { @@ -613,14 +616,15 @@ // DenseElementAttr Iterators //===----------------------------------------------------------------------===// -/// Constructs a new iterator. +//===----------------------------------------------------------------------===// +// AttributeElementIterator + DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( DenseElementsAttr attr, size_t index) : llvm::indexed_accessor_iterator( attr.getAsOpaquePointer(), index) {} -/// Accesses the Attribute value at this iterator position. Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = getFromOpaquePointer(base).cast(); Type eltTy = owner.getType().getElementType(); @@ -640,31 +644,57 @@ llvm_unreachable("unexpected element type"); } -/// Constructs a new iterator. +//===----------------------------------------------------------------------===// +// BoolElementIterator + DenseElementsAttr::BoolElementIterator::BoolElementIterator( DenseElementsAttr attr, size_t dataIndex) : DenseElementIndexedIteratorImpl( attr.getRawData().data(), attr.isSplat(), dataIndex) {} -/// Accesses the bool value at this iterator position. bool DenseElementsAttr::BoolElementIterator::operator*() const { return getBit(getData(), getDataIndex()); } -/// Constructs a new iterator. +//===----------------------------------------------------------------------===// +// IntElementIterator + DenseElementsAttr::IntElementIterator::IntElementIterator( DenseElementsAttr attr, size_t dataIndex) : DenseElementIndexedIteratorImpl( attr.getRawData().data(), attr.isSplat(), dataIndex), bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {} -/// Accesses the raw APInt value at this iterator position. APInt DenseElementsAttr::IntElementIterator::operator*() const { return readBits(getData(), getDataIndex() * getDenseElementStorageWidth(bitWidth), bitWidth); } +//===----------------------------------------------------------------------===// +// ComplexIntElementIterator + +DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( + DenseElementsAttr attr, size_t dataIndex) + : DenseElementIndexedIteratorImpl, std::complex, + std::complex>( + attr.getRawData().data(), attr.isSplat(), dataIndex) { + auto complexType = attr.getType().getElementType().cast(); + bitWidth = getDenseElementBitWidth(complexType.getElementType()); +} + +std::complex +DenseElementsAttr::ComplexIntElementIterator::operator*() const { + size_t storageWidth = getDenseElementStorageWidth(bitWidth); + size_t offset = getDataIndex() * storageWidth * 2; + return {readBits(getData(), offset, bitWidth), + readBits(getData(), offset + storageWidth, bitWidth)}; +} + +//===----------------------------------------------------------------------===// +// FloatElementIterator + DenseElementsAttr::FloatElementIterator::FloatElementIterator( const llvm::fltSemantics &smt, IntElementIterator it) : llvm::mapped_iterator(const std::complex &)>>( + it, [&](const std::complex &val) -> std::complex { + return {APFloat(smt, val.real()), APFloat(smt, val.imag())}; + }) {} + +//===----------------------------------------------------------------------===// // DenseElementsAttr //===----------------------------------------------------------------------===// @@ -753,7 +795,21 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrIndex()); - return DenseIntOrFPElementsAttr::getRaw(type, values); + assert(hasSameElementsOrSplat(type, values)); + size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, + /*isSplat=*/(values.size() == 1)); +} +DenseElementsAttr DenseElementsAttr::get(ShapedType type, + ArrayRef> values) { + ComplexType complex = type.getElementType().cast(); + assert(complex.getElementType().isa()); + assert(hasSameElementsOrSplat(type, values)); + size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; + ArrayRef intVals(reinterpret_cast(values.data()), + values.size() * 2); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals, + /*isSplat=*/(values.size() == 1)); } // Constructs a dense float elements attribute from an array of APFloat @@ -762,12 +818,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isa()); - - // Convert the APFloat values to APInt and create a dense elements attribute. - std::vector intValues(values.size()); - for (unsigned i = 0, e = values.size(); i != e; ++i) - intValues[i] = values[i].bitcastToAPInt(); - return DenseIntOrFPElementsAttr::getRaw(type, intValues); + assert(hasSameElementsOrSplat(type, values)); + size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values, + /*isSplat=*/(values.size() == 1)); +} +DenseElementsAttr +DenseElementsAttr::get(ShapedType type, + ArrayRef> values) { + ComplexType complex = type.getElementType().cast(); + assert(complex.getElementType().isa()); + assert(hasSameElementsOrSplat(type, values)); + ArrayRef apVals(reinterpret_cast(values.data()), + values.size() * 2); + size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; + return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals, + /*isSplat=*/(values.size() == 1)); } /// Construct a dense elements attribute from a raw buffer representing the @@ -783,8 +849,7 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, bool &detectedSplat) { - size_t elementWidth = getDenseElementBitWidth(type.getElementType()); - size_t storageWidth = getDenseElementStorageWidth(elementWidth); + size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; // Storage width of 1 is special as it is packed by the bit. @@ -904,13 +969,20 @@ assert(getType().getElementType().isIntOrIndex() && "expected integral type"); return raw_int_end(); } +auto DenseElementsAttr::getComplexIntValues() const + -> llvm::iterator_range { + Type eltTy = getType().getElementType().cast().getElementType(); + (void)eltTy; + assert(eltTy.isa() && "expected complex integral type"); + return {ComplexIntElementIterator(*this, 0), + ComplexIntElementIterator(*this, getNumElements())}; +} /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. auto DenseElementsAttr::getFloatValues() const -> llvm::iterator_range { auto elementType = getType().getElementType().cast(); - assert(elementType.isa() && "expected float type"); const auto &elementSemantics = elementType.getFloatSemantics(); return {FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; @@ -921,6 +993,14 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { return getFloatValues().end(); } +auto DenseElementsAttr::getComplexFloatValues() const + -> llvm::iterator_range { + Type eltTy = getType().getElementType().cast().getElementType(); + assert(eltTy.isa() && "expected complex float type"); + const auto &semantics = eltTy.cast().getFloatSemantics(); + return {{semantics, {*this, 0}}, + {semantics, {*this, static_cast(getNumElements())}}}; +} /// Return the raw storage data held by this attribute. ArrayRef DenseElementsAttr::getRawData() const { @@ -972,23 +1052,42 @@ // DenseIntOrFPElementsAttr //===----------------------------------------------------------------------===// +/// Utility method to write a range of APInt values to a buffer. +template +static void writeAPIntsToBuffer(size_t storageWidth, std::vector &data, + APRangeT &&values) { + data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values)); + size_t offset = 0; + for (auto it = values.begin(), e = values.end(); it != e; + ++it, offset += storageWidth) { + assert((*it).getBitWidth() <= storageWidth); + writeBits(data.data(), offset, *it); + } +} + +/// Constructs a dense elements attribute from an array of raw APFloat values. +/// Each APFloat value is expected to have the same bitwidth as the element +/// type of 'type'. 'type' must be a vector or tensor with static shape. +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, + size_t storageWidth, + ArrayRef values, + bool isSplat) { + std::vector data; + auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; + writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); + return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); +} + /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element type /// of 'type'. DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, - ArrayRef values) { - assert(hasSameElementsOrSplat(type, values)); - - size_t bitWidth = getDenseElementBitWidth(type.getElementType()); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - std::vector elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * - values.size()); - for (unsigned i = 0, e = values.size(); i != e; ++i) { - assert(values[i].getBitWidth() == bitWidth); - writeBits(elementData.data(), i * storageBitWidth, values[i]); - } - return DenseIntOrFPElementsAttr::getRaw(type, elementData, - /*isSplat=*/(values.size() == 1)); + size_t storageWidth, + ArrayRef values, + bool isSplat) { + std::vector data; + writeAPIntsToBuffer(storageWidth, data, values); + return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat); } DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1956,29 +1956,13 @@ ArrayRef getShape() const { return shape; } private: - enum class ElementKind { Boolean, Integer, Float, String }; - - /// Return a string to represent the given element kind. - const char *getElementKindStr(ElementKind kind) { - switch (kind) { - case ElementKind::Boolean: - return "'boolean'"; - case ElementKind::Integer: - return "'integer'"; - case ElementKind::Float: - return "'float'"; - case ElementKind::String: - return "'string'"; - } - llvm_unreachable("unknown element kind"); - } - - /// Build a Dense Integer attribute for the given type. - DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); + /// Get the parsed elements for an integer attribute. + ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy, + std::vector &intValues); - /// Build a Dense Float attribute for the given type. - DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, - FloatType eltTy); + /// Get the parsed elements for a float attribute. + ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, + std::vector &floatValues); /// Build a Dense String attribute for the given type. DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); @@ -2011,9 +1995,6 @@ /// Storage used when parsing elements, this is a pair of . std::vector> storage; - /// A flag that indicates the type of elements that have been parsed. - Optional knownEltKind; - /// Storage used when parsing elements that were stored as hex values. Optional hexStorage; }; @@ -2053,22 +2034,40 @@ return nullptr; } - // If the type is an integer, build a set of APInt values from the storage - // with the correct bitwidth. - if (auto intTy = eltType.dyn_cast()) - return getIntAttr(loc, type, intTy); - if (auto indexTy = eltType.dyn_cast()) - return getIntAttr(loc, type, indexTy); - - // If parsing a floating point type. - if (auto floatTy = eltType.dyn_cast()) - return getFloatAttr(loc, type, floatTy); + // Handle complex types in the specific element type cases below. + bool isComplex = false; + if (ComplexType complexTy = eltType.dyn_cast()) { + eltType = complexTy.getElementType(); + isComplex = true; + } - // If parsing a complex type. - // TODO: Support complex elements with pretty element printing. - if (eltType.isa()) { - p.emitError(loc) << "complex elements only support hex formatting"; - return nullptr; + // Handle integer and index types. + if (eltType.isIntOrIndex()) { + std::vector intValues; + if (failed(getIntAttrElements(loc, eltType, intValues))) + return nullptr; + if (isComplex) { + // If this is a complex, treat the parsed values as complex values. + auto complexData = llvm::makeArrayRef( + reinterpret_cast *>(intValues.data()), + intValues.size() / 2); + return DenseElementsAttr::get(type, complexData); + } + return DenseElementsAttr::get(type, intValues); + } + // Handle floating point types. + if (FloatType floatTy = eltType.dyn_cast()) { + std::vector floatValues; + if (failed(getFloatAttrElements(loc, floatTy, floatValues))) + return nullptr; + if (isComplex) { + // If this is a complex, treat the parsed values as complex values. + auto complexData = llvm::makeArrayRef( + reinterpret_cast *>(floatValues.data()), + floatValues.size() / 2); + return DenseElementsAttr::get(type, complexData); + } + return DenseElementsAttr::get(type, floatValues); } // Other types are assumed to be string representations. @@ -2076,39 +2075,36 @@ } /// Build a Dense Integer attribute for the given type. -DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc, - ShapedType type, Type eltTy) { - std::vector intElements; - intElements.reserve(storage.size()); - auto isUintType = type.getElementType().isUnsignedInteger(); +ParseResult +TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy, + std::vector &intValues) { + intValues.reserve(storage.size()); + bool isUintType = eltTy.isUnsignedInteger(); for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; auto tokenLoc = token.getLoc(); if (isNegative && isUintType) { - p.emitError(tokenLoc) - << "expected unsigned integer elements, but parsed negative value"; - return nullptr; + return p.emitError(tokenLoc) + << "expected unsigned integer elements, but parsed negative value"; } // Check to see if floating point values were parsed. if (token.is(Token::floatliteral)) { - p.emitError(tokenLoc) - << "expected integer elements, but parsed floating-point"; - return nullptr; + return p.emitError(tokenLoc) + << "expected integer elements, but parsed floating-point"; } assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) && "unexpected token type"); if (token.isAny(Token::kw_true, Token::kw_false)) { if (!eltTy.isInteger(1)) { - p.emitError(tokenLoc) - << "expected i1 type for 'true' or 'false' values"; - return nullptr; + return p.emitError(tokenLoc) + << "expected i1 type for 'true' or 'false' values"; } APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false); - intElements.push_back(apInt); + intValues.push_back(apInt); continue; } @@ -2116,19 +2112,16 @@ Optional apInt = buildAttributeAPInt(eltTy, isNegative, token.getSpelling()); if (!apInt) - return (p.emitError(tokenLoc, "integer constant out of range for type"), - nullptr); - intElements.push_back(*apInt); + return p.emitError(tokenLoc, "integer constant out of range for type"); + intValues.push_back(*apInt); } - - return DenseElementsAttr::get(type, intElements); + return success(); } /// Build a Dense Float attribute for the given type. -DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc, - ShapedType type, - FloatType eltTy) { - std::vector floatValues; +ParseResult +TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy, + std::vector &floatValues) { floatValues.reserve(storage.size()); for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; @@ -2137,34 +2130,31 @@ // Handle hexadecimal float literals. if (token.is(Token::integer) && token.getSpelling().startswith("0x")) { if (isNegative) { - p.emitError(token.getLoc()) - << "hexadecimal float literal should not have a leading minus"; - return nullptr; + return p.emitError(token.getLoc()) + << "hexadecimal float literal should not have a leading minus"; } auto val = token.getUInt64IntegerValue(); if (!val.hasValue()) { - p.emitError("hexadecimal float constant out of range for attribute"); - return nullptr; + return p.emitError( + "hexadecimal float constant out of range for attribute"); } Optional apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val); if (!apVal) - return nullptr; + return failure(); floatValues.push_back(*apVal); continue; } // Check to see if any decimal integers or booleans were parsed. - if (!token.is(Token::floatliteral)) { - p.emitError() << "expected floating-point elements, but parsed integer"; - return nullptr; - } + if (!token.is(Token::floatliteral)) + return p.emitError() + << "expected floating-point elements, but parsed integer"; // Build the float values from tokens. auto val = token.getFloatingPointValue(); - if (!val.hasValue()) { - p.emitError("floating point value too large for attribute"); - return nullptr; - } + if (!val.hasValue()) + return p.emitError("floating point value too large for attribute"); + // Treat BF16 as double because it is not supported in LLVM's APFloat. APFloat apVal(isNegative ? -*val : *val); if (!eltTy.isBF16() && !eltTy.isF64()) { @@ -2174,8 +2164,7 @@ } floatValues.push_back(apVal); } - - return DenseElementsAttr::get(type, floatValues); + return success(); } /// Build a Dense String attribute for the given type. @@ -2250,6 +2239,17 @@ storage.emplace_back(/*isNegative=*/ false, p.getToken()); p.consumeToken(); break; + + // Parse a complex element of the form '(' element ',' element ')'. + case Token::l_paren: + p.consumeToken(Token::l_paren); + if (parseElement() || + p.parseToken(Token::comma, "expected ',' between complex elements") || + parseElement() || + p.parseToken(Token::r_paren, "expected ')' after complex elements")) + return failure(); + break; + default: return p.emitError("expected element literal of primitive type"); } diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -7,11 +7,8 @@ // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64> "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> () -// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex> -"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex>} : () -> () - -// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex> -"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex>} : () -> () +// CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex> +"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex>} : () -> () // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16> "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> () diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -689,6 +689,22 @@ // ----- +"foo"(){bar = dense<[()]> : tensor>} : () -> () // expected-error {{expected element literal of primitive type}} + +// ----- + +"foo"(){bar = dense<[(10)]> : tensor>} : () -> () // expected-error {{expected ',' between complex elements}} + +// ----- + +"foo"(){bar = dense<[(10,)]> : tensor>} : () -> () // expected-error {{expected element literal of primitive type}} + +// ----- + +"foo"(){bar = dense<[(10,10]> : tensor>} : () -> () // expected-error {{expected ')' after complex elements}} + +// ----- + func @elementsattr_malformed_opaque() -> () { ^bb0: "foo"(){bar = opaque<10, "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected dialect namespace}} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -702,6 +702,15 @@ "index"(){bar = dense<1> : tensor} : () -> () // CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> () "index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> () + + // CHECK: dense<(1,1)> : tensor> + "complex_attr"(){bar = dense<(1,1)> : tensor>} : () -> () + // CHECK: dense<[(1,1), (2,2)]> : tensor<2xcomplex> + "complex_attr"(){bar = dense<[(1,1), (2,2)]> : tensor<2xcomplex>} : () -> () + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + "complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> () + // CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex> + "complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex>} : () -> () return } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -27,7 +27,7 @@ EXPECT_EQ(detectedSplat, splat); for (auto newValue : detectedSplat.template getValues()) - EXPECT_EQ(newValue, splatElt); + EXPECT_TRUE(newValue == splatElt); } namespace { @@ -179,4 +179,18 @@ testSplat(complexType, value); } +TEST(DenseComplexTest, ComplexAPFloatSplat) { + MLIRContext context; + ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); + std::complex value(APFloat(10.0f), APFloat(15.0f)); + testSplat(complexType, value); +} + +TEST(DenseComplexTest, ComplexAPIntSplat) { + MLIRContext context; + ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); + std::complex value(APInt(64, 10), APInt(64, 15)); + testSplat(complexType, value); +} + } // end namespace