diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -13,6 +13,7 @@ #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/PointerLikeTypeTraits.h" +#include namespace mlir { class AffineMap; @@ -687,6 +688,11 @@ /// Return the data base pointer. const char *getData() const { return this->base.getPointer(); } }; + +/// Type trait detector that checks if a given type T is a complex type. +template struct is_complex_t : public std::false_type {}; +template +struct is_complex_t> : public std::true_type {}; } // namespace detail /// An attribute that represents a reference to a dense vector or tensor object. @@ -724,11 +730,27 @@ /// Constructs a dense integer elements attribute from a single element. template ::is_integer || - llvm::is_one_of::value>::type> + llvm::is_one_of::value || + detail::is_complex_t::value>::type> static DenseElementsAttr get(const ShapedType &type, T value) { return get(type, llvm::makeArrayRef(value)); } + /// Constructs a dense complex elements attribute from an array of complex + /// values. Each value is expected to be the same bitwidth of the element type + /// of 'type'. 'type' must be a vector or tensor with static shape. + template ::value && + (std::numeric_limits::is_integer || + llvm::is_one_of::value)>::type> + static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { + const char *data = reinterpret_cast(values.data()); + return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), + sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed); + } + /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef values); @@ -764,6 +786,12 @@ ArrayRef rawBuffer, bool isSplatBuffer); + /// Returns true if the given buffer is a valid raw buffer for the given type.
+ /// `detectedSplat` is set if the buffer is valid and represents a splat + /// buffer. + static bool isValidRawBuffer(ShapedType type, ArrayRef rawBuffer, + bool &detectedSplat); + //===--------------------------------------------------------------------===// // Iterators //===--------------------------------------------------------------------===// @@ -900,12 +928,28 @@ llvm::iterator_range> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); - auto rawData = getRawData().data(); + const char *rawData = getRawData().data(); + bool splat = isSplat(); + return {ElementIterator(rawData, splat, 0), + ElementIterator(rawData, splat, getNumElements())}; + } + + /// Return the held element values as a range of std::complex. + template ::value && + (std::numeric_limits::is_integer || + llvm::is_one_of::value)>::type> + llvm::iterator_range> getValues() const { + assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); + const char *rawData = getRawData().data(); bool splat = isSplat(); return {ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } + /// Return the held element values as a range of StringRef. template ::value>::type> llvm::iterator_range> getValues() const { @@ -1011,6 +1055,13 @@ } /// Overload of the raw 'get' method that asserts that the given type is of + /// complex type. This method is used to verify type invariants that the + /// templatized 'get' method cannot. + static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. static DenseElementsAttr getRawIntOrFloat(ShapedType type, @@ -1022,6 +1073,11 @@ /// the current attribute.
This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; + + /// Check the information for a C++ data type, check if this type is valid for + /// the current attribute. This method is used to verify specific type + /// invariants that the templatized 'getValues' method cannot. + bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const; }; /// An attribute class for representing dense arrays of strings. The structure @@ -1075,6 +1131,13 @@ bool isSplat); /// Overload of the raw 'get' method that asserts that the given type is of + /// complex type. This method is used to verify type invariants that the + /// templatized 'get' method cannot. + static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. static DenseElementsAttr getRawIntOrFloat(ShapedType type, @@ -1287,20 +1350,15 @@ return getZeroAPFloat(); } - /// Get a zero for a StringRef. + /// Get a zero for a C++ integer, float, StringRef, or complex type. template - typename std::enable_if::value, T>::type - getZeroValue() const { - return StringRef(); - } - - /// Get a zero for an C++ integer or float type.
- template - typename std::enable_if::is_integer || - llvm::is_one_of::value, - T>::type + typename std::enable_if< + std::numeric_limits::is_integer || + llvm::is_one_of::value || + detail::is_complex_t::value, + T>::type getZeroValue() const { - return T(0); + return T(); } /// Flatten, and return, all of the sparse indices in this attribute in diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1468,50 +1468,21 @@ /// Print the float element of the given DenseElementsAttr at 'index'. static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os, - unsigned index, bool isSigned) { - assert(isSigned && "floating point values are always signed"); + unsigned index) { APFloat value = *std::next(attr.float_value_begin(), index); printFloatValue(value, os); } -static void printDenseStringElement(DenseStringElementsAttr attr, - raw_ostream &os, unsigned index) { - os << "\""; - printEscapedString(attr.getRawStringData()[index], os); - os << "\""; -} - -void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, - bool allowHex) { - if (auto stringAttr = attr.dyn_cast()) { - printDenseStringElementsAttr(stringAttr); - return; - } - - printDenseIntOrFPElementsAttr(attr.cast(), - allowHex); -} - -void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, - bool allowHex) { - auto type = attr.getType(); - auto shape = type.getShape(); - auto rank = type.getRank(); - bool isSigned = !type.getElementType().isUnsignedInteger(); - - // The function used to print elements of this attribute. - auto printEltFn = type.getElementType().isIntOrIndex() - ? printDenseIntElement - : printDenseFloatElement; - +static void +printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, + function_ref printEltFn) { // Special case for 0-d and splat tensors.
- if (attr.isSplat()) { - printEltFn(attr, os, 0, isSigned); - return; - } + if (isSplat) + return printEltFn(0); // Special case for degenerate tensors. auto numElements = type.getNumElements(); + int64_t rank = type.getRank(); if (numElements == 0) { for (int i = 0; i < rank; ++i) os << '['; @@ -1520,14 +1491,6 @@ return; } - // Check to see if we should format this attribute as a hex string. - if (allowHex && shouldPrintElementsAttrWithHex(numElements)) { - ArrayRef rawData = attr.getRawData(); - os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size())) - << "\""; - return; - } - // We use a mixed-radix counter to iterate through the shape. When we bump a // non-least-significant digit, we emit a close bracket. When we next emit an // element we re-open all closed brackets. @@ -1537,7 +1500,8 @@ // The number of brackets that have been opened and not closed. unsigned openBrackets = 0; - auto bumpCounter = [&]() { + auto shape = type.getShape(); + auto bumpCounter = [&] { // Bump the least significant digit. ++counter[rank - 1]; // Iterate backwards bubbling back the increment. @@ -1557,68 +1521,60 @@ while (openBrackets++ < rank) os << '['; openBrackets = rank; - printEltFn(attr, os, idx, isSigned); + printEltFn(idx); bumpCounter(); } while (openBrackets-- > 0) os << ']'; } -void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { - auto type = attr.getType(); - auto shape = type.getShape(); - auto rank = type.getRank(); +void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, + bool allowHex) { + if (auto stringAttr = attr.dyn_cast()) + return printDenseStringElementsAttr(stringAttr); - // Special case for 0-d and splat tensors. - if (attr.isSplat()) { - printDenseStringElement(attr, os, 0); - return; - } + printDenseIntOrFPElementsAttr(attr.cast(), + allowHex); +} - // Special case for degenerate tensors.
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, + bool allowHex) { + auto type = attr.getType(); + auto elementType = type.getElementType(); + + // Check to see if we should format this attribute as a hex string. + // TODO: Add support for formatting complex elements nicely. auto numElements = type.getNumElements(); - if (numElements == 0) { - for (int i = 0; i < rank; ++i) - os << '['; - for (int i = 0; i < rank; ++i) - os << ']'; + if (elementType.isa<ComplexType>() || + (!attr.isSplat() && allowHex && + shouldPrintElementsAttrWithHex(numElements))) { + ArrayRef rawData = attr.getRawData(); + os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size())) + << "\""; return; } - // We use a mixed-radix counter to iterate through the shape. When we bump a - // non-least-significant digit, we emit a close bracket. When we next emit an - // element we re-open all closed brackets. - - // The mixed-radix counter, with radices in 'shape'. - SmallVector counter(rank, 0); - // The number of brackets that have been opened and not closed. - unsigned openBrackets = 0; + if (elementType.isIntOrIndex()) { + bool isSigned = !elementType.isUnsignedInteger(); + printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { + printDenseIntElement(attr, os, index, isSigned); + }); + } else { + assert(elementType.isa() && "unexpected element type"); + printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { + printDenseFloatElement(attr, os, index); + }); + } +} - auto bumpCounter = [&]() { - // Bump the least significant digit. - ++counter[rank - 1]; - // Iterate backwards bubbling back the increment. - for (unsigned i = rank - 1; i > 0; --i) - if (counter[i] >= shape[i]) { - // Index 'i' is rolled over. Bump (i-1) and close a bracket.
- counter[i] = 0; - ++counter[i - 1]; - --openBrackets; - os << ']'; - } +void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { + ArrayRef data = attr.getRawStringData(); + auto printFn = [&](unsigned index) { + os << "\""; + printEscapedString(data[index], os); + os << "\""; }; - - for (unsigned idx = 0, e = numElements; idx != e; ++idx) { - if (idx != 0) - os << ", "; - while (openBrackets++ < rank) - os << '['; - openBrackets = rank; - printDenseStringElement(attr, os, idx); - bumpCounter(); - } - while (openBrackets-- > 0) - os << ']'; + printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } void ModulePrinter::printType(Type type) { diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -374,6 +374,8 @@ /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { + if (ComplexType complex = eltType.dyn_cast()) + return getDenseElementBitWidth(complex.getElementType()) * 2; // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored // with double semantics. if (eltType.isBF16()) diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -779,24 +779,56 @@ return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); } +/// Returns true if the given buffer is a valid raw buffer for the given type. +bool DenseElementsAttr::isValidRawBuffer(ShapedType type, + ArrayRef rawBuffer, + bool &detectedSplat) { + size_t elementWidth = getDenseElementBitWidth(type.getElementType()); + size_t storageWidth = getDenseElementStorageWidth(elementWidth); + size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT; + + // Storage width of 1 is special as it is packed by the bit. + if (storageWidth == 1) { + // A single byte is a splat when there is at most one element, when it is + // too small to hold a bit for every element, or when all of its bits + // agree. Otherwise it may be a packed non-splat buffer of up to 8 + // elements, which must not be misread as a splat of its first bit.
+ int64_t numElements = type.getNumElements(); + if (rawBuffer.size() == 1) { + auto rawByte = static_cast<uint8_t>(rawBuffer[0]); + if (numElements <= 1 || numElements > CHAR_BIT || rawByte == 0 || + rawByte == 0xff) { + detectedSplat = true; + return true; + } + } + // Otherwise, this is a valid non-splat buffer if it has the right size. + detectedSplat = false; + return rawBufferWidth == llvm::alignTo<8>(numElements); + } + // All other types are 8-bit aligned. + if ((detectedSplat = rawBufferWidth == storageWidth)) + return true; + return rawBufferWidth == (storageWidth * type.getNumElements()); +} + /// Check the information for a C++ data type, check if this type is valid for /// the current attribute. This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. -static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt, +static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, bool isSigned) { // Make sure that the data element size is the same as the type element width. - if (getDenseElementBitWidth(type.getElementType()) != + if (getDenseElementBitWidth(type) != static_cast(dataEltSize * CHAR_BIT)) return false; // Check that the element type is either float or integer or index. if (!isInt) - return type.getElementType().isa(); - - if (type.getElementType().isIndex()) + return type.isa(); + if (type.isIndex()) return true; - auto intType = type.getElementType().dyn_cast(); + auto intType = type.dyn_cast(); if (!intType) return false; @@ -807,6 +839,15 @@ } /// Defaults down the subclass implementation. +DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, + ArrayRef data, + int64_t dataEltSize, + bool isInt, bool isSigned) { + return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, + isSigned); +} + +/// Defaults down the subclass implementation. DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, @@ -820,7 +861,17 @@ /// method cannot.
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const { - return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned); + return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt, + isSigned); +} + +/// Check if a C++ std::complex type of size `dataEltSize` is valid for this +/// attribute; returns false if the element type is not a ComplexType. +bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, + bool isSigned) const { + auto complexType = getType().getElementType().dyn_cast<ComplexType>(); + return complexType && ::isValidIntOrFloat(complexType.getElementType(), + dataEltSize / 2, isInt, isSigned); } /// Returns if this attribute corresponds to a splat, i.e. if all element @@ -964,6 +1001,23 @@ type, data, isSplat); } +/// Overload of the raw 'get' method that asserts that the given type is of +/// complex type. This method is used to verify type invariants that the +/// templatized 'get' method cannot. +DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, + ArrayRef data, + int64_t dataEltSize, + bool isInt, + bool isSigned) { + assert(::isValidIntOrFloat( + type.getElementType().cast().getElementType(), + dataEltSize / 2, isInt, isSigned)); + + int64_t numElements = data.size() / dataEltSize; + assert(numElements == 1 || numElements == type.getNumElements()); + return getRaw(type, data, /*isSplat=*/numElements == 1); +} + /// Overload of the 'getRaw' method that asserts that the given type is of /// integer type. This method is used to verify type invariants that the /// templatized 'get' method cannot.
@@ -971,7 +1025,8 @@ DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { - assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); + assert( + ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); int64_t numElements = data.size() / dataEltSize; assert(numElements == 1 || numElements == type.getNumElements()); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2041,7 +2041,8 @@ Type eltType = type.getElementType(); // Check to see if we parse the literal from a hex string. - if (hexStorage.hasValue() && eltType.isIntOrFloat()) + if (hexStorage.hasValue() && + (eltType.isIntOrIndexOrFloat() || eltType.isa())) return getHexAttr(loc, type); // Check that the parsed storage size has the same number of elements to the @@ -2063,6 +2064,13 @@ if (auto floatTy = eltType.dyn_cast()) return getFloatAttr(loc, type, floatTy); + // If parsing a complex type. + // TODO: Support complex elements with pretty element printing. + if (eltType.isa()) { + p.emitError(loc) << "complex elements only support hex formatting"; + return nullptr; + } + // Other types are assumed to be string representations.
return getStringAttr(loc, type, type.getElementType()); } @@ -2196,9 +2204,10 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, ShapedType type) { Type elementType = type.getElementType(); - if (!elementType.isa() && !elementType.isa()) { - p.emitError(loc) << "expected floating-point or integer element type, got " - << elementType; + if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) { + p.emitError(loc) + << "expected floating-point, integer, or complex element type, got " + << elementType; return nullptr; } @@ -2206,21 +2215,15 @@ if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) return nullptr; - // Check that the size of the hex data corresponds to the size of the type, or - // a splat of the type. - // TODO: bf16 is currently stored as a double, this should be removed when - // APFloat properly supports it. - int64_t elementWidth = - elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth(); - if (static_cast(data.size() * CHAR_BIT) != - (type.getNumElements() * elementWidth)) { + ArrayRef rawData(data.data(), data.size()); + bool detectedSplat = false; + if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) { p.emitError(loc) << "elements hex data size is invalid for provided type: " << type; return nullptr; } - return DenseElementsAttr::getFromRawBuffer( - type, ArrayRef(data.data(), data.size()), /*isSplatBuffer=*/false); + return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat); } ParseResult TensorLiteralParser::parseElement() { diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -7,6 +7,15 @@ // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64> "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> () +// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex> +"foo.op"() {dense.attr
= dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex>} : () -> () + +// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex> +"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex>} : () -> () + +// CHECK: dense<[1, 2]> : tensor<2xindex> +"foo.op"() {dense.attr = dense<"0x01000000000000000200000000000000"> : tensor<2xindex>} : () -> () + // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16> "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> () diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -25,6 +25,9 @@ auto detectedSplat = DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt})); EXPECT_EQ(detectedSplat, splat); + + for (auto newValue : detectedSplat.template getValues()) + EXPECT_EQ(newValue, splatElt); } namespace { @@ -162,4 +165,18 @@ testSplat(stringType, stringAttr); } +TEST(DenseComplexTest, ComplexFloatSplat) { + MLIRContext context; + ComplexType complexType = ComplexType::get(FloatType::getF32(&context)); + std::complex value(10.0, 15.0); + testSplat(complexType, value); +} + +TEST(DenseComplexTest, ComplexIntSplat) { + MLIRContext context; + ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); + std::complex value(10, 15); + testSplat(complexType, value); +} + } // end namespace