diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -40,7 +40,8 @@ struct TypeAttributeStorage; /// Elements Attributes. -struct DenseElementsAttributeStorage; +struct DenseIntOrFPElementsAttributeStorage; +struct DenseStringElementsAttributeStorage; struct OpaqueElementsAttributeStorage; struct SparseElementsAttributeStorage; } // namespace detail @@ -141,10 +142,11 @@ Unit, /// Elements Attributes. - DenseElements, + DenseIntOrFPElements, + DenseStringElements, OpaqueElements, SparseElements, - FIRST_ELEMENTS_ATTR = DenseElements, + FIRST_ELEMENTS_ATTR = DenseIntOrFPElements, LAST_ELEMENTS_ATTR = SparseElements, /// Locations. @@ -671,15 +673,14 @@ /// An attribute that represents a reference to a dense vector or tensor object. /// -class DenseElementsAttr - : public Attribute::AttrBase { +class DenseElementsAttr : public ElementsAttr { public: - using Base::Base; + using ElementsAttr::ElementsAttr; /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { - return attr.getKind() == StandardAttributes::DenseElements; + return attr.getKind() == StandardAttributes::DenseIntOrFPElements || + attr.getKind() == StandardAttributes::DenseStringElements; } /// Constructs a dense elements attribute from an array of element values. @@ -712,6 +713,10 @@ /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Overload of the above 'get' method that is specialized for StringRef + /// values. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'.
'type' must be a vector or tensor with static @@ -882,6 +887,14 @@ ElementIterator(rawData, splat, getNumElements())}; } + llvm::iterator_range> getValues() const { + auto stringRefs = getRawStringData(); + const char *ptr = reinterpret_cast(stringRefs.data()); + bool splat = isSplat(); + return {ElementIterator(ptr, splat, 0), + ElementIterator(ptr, splat, getNumElements())}; + } + /// Return the held element values as a range of Attributes. llvm::iterator_range getAttributeValues() const; template getRawData() const; + /// Return the raw StringRef data held by this attribute. + ArrayRef getRawStringData() const; + //===--------------------------------------------------------------------===// // Mutation Utilities //===--------------------------------------------------------------------===// @@ -973,6 +989,60 @@ return IntElementIterator(*this, getNumElements()); } + /// Overload of the raw 'get' method that asserts that the given type is of + /// integer or floating-point type. This method is used to verify type + /// invariants that the templatized 'get' method cannot. + static DenseElementsAttr getRawIntOrFloat(ShapedType type, + ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + /// Check the information for a C++ data type, check if this type is valid for + /// the current attribute. This method is used to verify specific type + /// invariants that the templatized 'getValues' method cannot. + bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; +}; + +/// An attribute class for representing dense arrays of strings. It provides +/// storage for, and querying of, a list of densely packed strings. +class DenseStringElementsAttr + : public Attribute::AttrBase { + +public: + using Base::Base; + + /// Method for support type inquiry through isa, cast and dyn_cast.
+ static bool kindof(unsigned kind) { + return kind == StandardAttributes::DenseStringElements; + } + + /// Constructs a dense string elements attribute of the given shaped type + /// from an array of StringRef values. A single value is treated as a splat + /// of the whole shape. + static DenseStringElementsAttr get(ShapedType type, ArrayRef data); + +protected: + friend DenseElementsAttr; +}; + +/// An attribute class for specializing behavior of densely packed integer +/// and floating-point arrays. +class DenseIntOrFPElementsAttr + : public Attribute::AttrBase { + +public: + using Base::Base; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::DenseIntOrFPElements; + } + +protected: + friend DenseElementsAttr; + /// Constructs a dense elements attribute from an array of raw APInt values. /// Each APInt value is expected to have the same bitwidth as the element type /// of 'type'. 'type' must be a vector or tensor with static shape. @@ -990,20 +1060,15 @@ ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); - - /// Check the information for a C++ data type, check if this type is valid for - /// the current attribute. This method is used to verify specific type - /// invariants that the templatized 'getValues' method cannot. - bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; }; /// An attribute that represents a reference to a dense float vector or tensor /// object. Each element is stored as a double. -class DenseFPElementsAttr : public DenseElementsAttr { +class DenseFPElementsAttr : public DenseIntOrFPElementsAttr { public: using iterator = DenseElementsAttr::FloatElementIterator; - using DenseElementsAttr::DenseElementsAttr; + using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; /// Get an instance of a DenseFPElementsAttr with the given arguments.
This /// simply wraps the DenseElementsAttr::get calls. @@ -1035,13 +1100,13 @@ /// An attribute that represents a reference to a dense integer vector or tensor /// object. -class DenseIntElementsAttr : public DenseElementsAttr { +class DenseIntElementsAttr : public DenseIntOrFPElementsAttr { public: /// DenseIntElementsAttr iterates on APInt, so we can use the raw element /// iterator directly. using iterator = DenseElementsAttr::IntElementIterator; - using DenseElementsAttr::DenseElementsAttr; + using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; /// Get an instance of a DenseIntElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. @@ -1266,7 +1331,7 @@ typename... Args> RetT process(Args &... args) const { switch (attrKind) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return ProcessFn()(args...); case StandardAttributes::SparseElements: return ProcessFn()(args...); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1307,6 +1307,16 @@ class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; +def StringElementsAttr : ElementsAttrBase< + CPred<"$_self.isa()" >, + "string elements attribute"> { + + let storageType = [{ DenseElementsAttr }]; + let returnType = [{ DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + // Base class for array attributes.
class ArrayAttrBase : Attr { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1316,7 +1316,7 @@ << opType << ") does not match value type (" << valueType << ")"; return success(); } break; - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: case StandardAttributes::SparseElements: { if (valueType == opType) break; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -973,6 +973,9 @@ /// used instead of individual elements when the elements attr is large. void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); + /// Print a dense string elements attribute. + void printDenseStringElementsAttr(DenseStringElementsAttr attr); + void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1392,7 +1395,7 @@ os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; break; } - case StandardAttributes::DenseElements: { + case StandardAttributes::DenseIntOrFPElements: { auto eltsAttr = attr.cast(); if (printerFlags.shouldElideElementsAttr(eltsAttr)) { printElidedElementsAttr(os); @@ -1403,6 +1406,17 @@ os << '>'; break; } + case StandardAttributes::DenseStringElements: { + auto eltsAttr = attr.cast(); + if (printerFlags.shouldElideElementsAttr(eltsAttr)) { + printElidedElementsAttr(os); + break; + } + os << "dense<"; + printDenseStringElementsAttr(eltsAttr); + os << '>'; + break; + } case StandardAttributes::SparseElements: { auto elementsAttr = attr.cast(); if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) || @@ -1454,6 +1468,13 @@ printFloatValue(value, os); } +static void printDenseStringElement(DenseStringElementsAttr attr, + raw_ostream &os, unsigned index) { + os << "\""; + printEscapedString(attr.getRawStringData()[index], os); + os << "\""; +} + void
ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { auto type = attr.getType(); @@ -1526,6 +1547,63 @@ os << ']'; } +void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); + + // Special case for 0-d and splat tensors. + if (attr.isSplat()) { + printDenseStringElement(attr, os, 0); + return; + } + + // Special case for degenerate tensors. + auto numElements = type.getNumElements(); + if (numElements == 0) { + for (int i = 0; i < rank; ++i) + os << '['; + for (int i = 0; i < rank; ++i) + os << ']'; + return; + } + + // We use a mixed-radix counter to iterate through the shape. When we bump a + // non-least-significant digit, we emit a close bracket. When we next emit an + // element we re-open all closed brackets. + + // The mixed-radix counter, with radices in 'shape'. + SmallVector counter(rank, 0); + // The number of brackets that have been opened and not closed. + unsigned openBrackets = 0; + + auto bumpCounter = [&]() { + // Bump the least significant digit. + ++counter[rank - 1]; + // Iterate backwards bubbling back the increment. + for (unsigned i = rank - 1; i > 0; --i) + if (counter[i] >= shape[i]) { + // Index 'i' is rolled over. Bump (i-1) and close a bracket. + counter[i] = 0; + ++counter[i - 1]; + --openBrackets; + os << ']'; + } + }; + + for (unsigned idx = 0, e = numElements; idx != e; ++idx) { + if (idx != 0) + os << ", "; + while (openBrackets++ < rank) + os << '['; + openBrackets = rank; + printDenseStringElement(attr, os, idx); + bumpCounter(); + } + while (openBrackets-- > 0) + os << ']'; +} + void ModulePrinter::printType(Type type) { if (!type) { os << "<>"; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -385,6 +385,20 @@ /// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage { +public: + DenseElementsAttributeStorage(ShapedType ty, bool isSplat) + : AttributeStorage(ty), isSplat(isSplat) {} + + bool isSplat; // True if all element values are the same (a splat). +}; + +/// Storage for a dense vector or tensor of integer or floating-point elements. +struct DenseIntOrFPElementsAttributeStorage + : public DenseElementsAttributeStorage { + DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) + : DenseElementsAttributeStorage(ty, isSplat), data(data) {} + struct KeyTy { KeyTy(ShapedType type, ArrayRef data, llvm::hash_code hashCode, bool isSplat = false) @@ -403,10 +417,6 @@ bool isSplat; }; - DenseElementsAttributeStorage(ShapedType ty, ArrayRef data, - bool isSplat = false) - : AttributeStorage(ty), data(data), isSplat(isSplat) {} - /// Compare this storage instance with the provided key. bool operator==(const KeyTy &key) const { if (key.type != getType()) @@ -512,7 +522,7 @@ } /// Construct a new storage instance. - static DenseElementsAttributeStorage * + static DenseIntOrFPElementsAttributeStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { // If the data buffer is non-empty, we copy it into the allocator with a // 64-bit alignment. @@ -528,12 +538,129 @@ copy = ArrayRef(rawData, data.size()); } - return new (allocator.allocate()) - DenseElementsAttributeStorage(key.type, copy, key.isSplat); + return new (allocator.allocate()) + DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat); } ArrayRef data; - bool isSplat; +}; + +/// An attribute representing a reference to a dense vector or tensor object +/// containing strings.
+struct DenseStringElementsAttributeStorage + : public DenseElementsAttributeStorage { + DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) + : DenseElementsAttributeStorage(ty, isSplat), data(data) {} + + struct KeyTy { + KeyTy(ShapedType type, ArrayRef data, llvm::hash_code hashCode, + bool isSplat = false) + : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {} + + /// The type of the dense elements. + ShapedType type; + + /// The raw buffer for the data storage. + ArrayRef data; + + /// The computed hash code for the storage data. + llvm::hash_code hashCode; + + /// A boolean that indicates if this data is a splat or not. + bool isSplat; + }; + + /// Compare this storage instance with the provided key. + bool operator==(const KeyTy &key) const { + if (key.type != getType()) + return false; + + // Otherwise, we can default to just checking the data. StringRefs compare + // by contents. + return key.data == data; + } + + /// Construct a key from a shaped type, StringRef data buffer, and a flag that + /// signals if the data is already known to be a splat. Callers to this + /// function are expected to tag preknown splat values when possible, e.g. one + /// element shapes. + static KeyTy getKey(ShapedType ty, ArrayRef data, + bool isKnownSplat) { + // Handle an empty storage instance. + if (data.empty()) + return KeyTy(ty, data, 0); + + // If the data is already known to be a splat, hash only the first element + // so the key matches the one built below for a detected splat. + if (isKnownSplat) + return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat); + + // A single-element shape must already have been tagged as a known splat. + assert(ty.getNumElements() != 1 && + "splat of 1 element should already be detected"); + + // Create the initial hash value with just the first element. + const auto &firstElt = data.front(); + auto hashVal = llvm::hash_value(firstElt); + + // Check to see if this storage represents a splat.
If it doesn't then + // combine the hash for the data starting with the first non splat element. + for (size_t i = 1, e = data.size(); i != e; i++) + if (!firstElt.equals(data[i])) + return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i))); + + // Otherwise, this is a splat; key on a view of just the first element. + return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true); + } + + /// Hash the key for the storage. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine(key.type, key.hashCode); + } + + /// Construct a new storage instance. + static DenseStringElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + // If the data buffer is non-empty, we copy it into the allocator with a + // 64-bit alignment. + ArrayRef copy, data = key.data; + if (data.empty()) { + return new (allocator.allocate()) + DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + } + + size_t numEntries = key.isSplat ? 1 : data.size(); + + // Compute the amount of data needed to store the ArrayRef and StringRef + // contents. + size_t dataSize = sizeof(StringRef) * numEntries; + for (size_t i = 0; i < numEntries; i++) + dataSize += data[i].size(); + + char *rawData = reinterpret_cast( + allocator.allocate(dataSize, alignof(uint64_t))); + + // Setup a mutable array ref of our string refs so that we can update their + // contents.
+ auto mutableCopy = MutableArrayRef( + reinterpret_cast(rawData), numEntries); + auto stringData = rawData + numEntries * sizeof(StringRef); + + for (size_t i = 0; i < numEntries; i++) { + memcpy(stringData, data[i].data(), data[i].size()); + mutableCopy[i] = StringRef(stringData, data[i].size()); + stringData += data[i].size(); + } + + copy = + ArrayRef(reinterpret_cast(rawData), numEntries); + + return new (allocator.allocate()) + DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + } + + ArrayRef data; }; /// An attribute representing a reference to a tensor constant with opaque diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -411,7 +411,7 @@ /// element, then a null attribute is returned. Attribute ElementsAttr::getValue(ArrayRef index) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().getValue(index); case StandardAttributes::OpaqueElements: return cast().getValue(index); @@ -442,7 +442,7 @@ ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().mapValues(newElementType, mapping); default: llvm_unreachable("unsupported ElementsAttr subtype"); @@ -453,7 +453,7 @@ ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().mapValues(newElementType, mapping); default: llvm_unreachable("unsupported ElementsAttr subtype"); @@ -643,7 +643,8 @@ "expected value to have same bitwidth as element type"); writeBits(data.data(), i * storageBitWidth, intVal); } - return getRaw(type, data, /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, data, + /*isSplat=*/(values.size() ==
1)); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -654,7 +655,14 @@ std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); for (int i = 0, e = values.size(); i != e; ++i) setBit(buff.data(), i, values[i]); - return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, buff, + /*isSplat=*/(values.size() == 1)); +} + +DenseElementsAttr DenseElementsAttr::get(ShapedType type, + ArrayRef values) { + assert(!type.getElementType().isIntOrFloat() && !type.getElementType().isIndex() && "string elements require a non-numeric element type"); + return DenseStringElementsAttr::get(type, values); } /// Constructs a dense integer elements attribute from an array of APInt @@ -663,7 +671,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { assert(type.getElementType().isIntOrIndex()); - return getRaw(type, values); + return DenseIntOrFPElementsAttr::getRaw(type, values); } // Constructs a dense float elements attribute from an array of APFloat @@ -677,7 +685,7 @@ std::vector intValues(values.size()); for (unsigned i = 0, e = values.size(); i != e; ++i) intValues[i] = values[i].bitcastToAPInt(); - return getRaw(type, intValues); + return DenseIntOrFPElementsAttr::getRaw(type, intValues); } /// Construct a dense elements attribute from a raw buffer representing the @@ -686,34 +694,7 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef rawBuffer, bool isSplatBuffer) { - return getRaw(type, rawBuffer, isSplatBuffer); -} - -/// Constructs a dense elements attribute from an array of raw APInt values. -/// Each APInt value is expected to have the same bitwidth as the element type -/// of 'type'.
-DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, - ArrayRef values) { - assert(hasSameElementsOrSplat(type, values)); - - size_t bitWidth = getDenseElementBitWidth(type.getElementType()); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - std::vector elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * - values.size()); - for (unsigned i = 0, e = values.size(); i != e; ++i) { - assert(values[i].getBitWidth() == bitWidth); - writeBits(elementData.data(), i * storageBitWidth, values[i]); - } - return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); -} - -DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, - ArrayRef data, bool isSplat) { - assert((type.isa() || type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::DenseElements, type, - data, isSplat); + return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); } /// Check the information for a C++ data type, check if this type is valid for @@ -743,19 +724,14 @@ return intType.isSigned() ? isSigned : !isSigned; } -/// Overload of the 'getRaw' method that asserts that the given type is of -/// integer type. This method is used to verify type invariants that the -/// templatized 'get' method cannot. +/// Forwards to the DenseIntOrFPElementsAttr implementation.
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { - assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); - - int64_t numElements = data.size() / dataEltSize; - assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); + return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, + isInt, isSigned); } /// A method used to verify specific type invariants that the templatized 'get' @@ -767,7 +743,9 @@ /// Returns if this attribute corresponds to a splat, i.e. if all element /// values are the same. -bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } +bool DenseElementsAttr::isSplat() const { + return static_cast(impl)->isSplat; +} /// Return the held element values as a range of Attributes. auto DenseElementsAttr::getAttributeValues() const @@ -827,7 +805,11 @@ /// Return the raw storage data held by this attribute.
ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; + return static_cast(impl)->data; +} + +ArrayRef DenseElementsAttr::getRawStringData() const { + return static_cast(impl)->data; } /// Return a new DenseElementsAttr that has the same data as the current @@ -843,7 +825,7 @@ "expected the same element type"); assert(newType.getNumElements() == curType.getNumElements() && "expected the same number of elements"); - return getRaw(newType, getRawData(), isSplat()); + return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); } DenseElementsAttr @@ -858,6 +840,63 @@ } //===----------------------------------------------------------------------===// +// DenseStringElementsAttr +//===----------------------------------------------------------------------===// + +DenseStringElementsAttr +DenseStringElementsAttr::get(ShapedType type, ArrayRef values) { + return Base::get(type.getContext(), StandardAttributes::DenseStringElements, + type, values, (values.size() == 1)); +} + +//===----------------------------------------------------------------------===// +// DenseIntOrFPElementsAttr +//===----------------------------------------------------------------------===// + +/// Constructs a dense elements attribute from an array of raw APInt values. +/// Each APInt value is expected to have the same bitwidth as the element type +/// of 'type'.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, + ArrayRef values) { + assert(hasSameElementsOrSplat(type, values)); + + size_t bitWidth = getDenseElementBitWidth(type.getElementType()); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); + std::vector elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * + values.size()); + for (unsigned i = 0, e = values.size(); i != e; ++i) { + assert(values[i].getBitWidth() == bitWidth); + writeBits(elementData.data(), i * storageBitWidth, values[i]); + } + return DenseIntOrFPElementsAttr::getRaw(type, elementData, + /*isSplat=*/(values.size() == 1)); +} + +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, + ArrayRef data, + bool isSplat) { + assert((type.isa() || type.isa()) && + "type must be ranked tensor or vector"); + assert(type.hasStaticShape() && "type must have static shape"); + return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, + type, data, isSplat); +} + +/// Overload of the 'getRaw' method that asserts that the given type is of +/// integer or floating-point type. This method is used to verify type +/// invariants that the templatized 'get' method cannot. +DenseElementsAttr +DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned) { + assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); + + int64_t numElements = data.size() / dataEltSize; + assert(numElements == 1 || numElements == type.getNumElements()); + return getRaw(type, data, /*isSplat=*/numElements == 1); +} + +//===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -82,10 +82,11 @@ /// the IR.
struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) { - addAttributes(); + addAttributes(); addAttributes(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1953,7 +1953,7 @@ ArrayRef getShape() const { return shape; } private: - enum class ElementKind { Boolean, Integer, Float }; + enum class ElementKind { Boolean, Integer, Float, String }; /// Return a string to represent the given element kind. const char *getElementKindStr(ElementKind kind) { @@ -1964,6 +1964,8 @@ return "'integer'"; case ElementKind::Float: return "'float'"; + case ElementKind::String: + return "'string'"; } llvm_unreachable("unknown element kind"); } @@ -1975,6 +1977,9 @@ DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy); + /// Build a Dense String attribute for the given type. + DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); + /// Build a Dense attribute with hex data for the given type. DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); @@ -2030,8 +2035,10 @@ /// shaped type. DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, ShapedType type) { - // Check to see if we parsed the literal from a hex string. - if (hexStorage.hasValue()) + Type eltType = type.getElementType(); + + // Check to see if we parsed raw hex data for an integer, index, or float type. + if (hexStorage.hasValue() && (eltType.isIntOrFloat() || eltType.isIndex())) return getHexAttr(loc, type); // Check that the parsed storage size has the same number of elements to the @@ -2044,20 +2051,17 @@ // If the type is an integer, build a set of APInt values from the storage // with the correct bitwidth.
- Type eltType = type.getElementType(); if (auto intTy = eltType.dyn_cast()) return getIntAttr(loc, type, intTy); if (auto indexTy = eltType.dyn_cast()) return getIntAttr(loc, type, indexTy); - // Otherwise, this must be a floating point type. - auto floatTy = eltType.dyn_cast(); - if (!floatTy) { - p.emitError(loc) << "expected floating-point or integer element type, got " - << eltType; - return nullptr; - } - return getFloatAttr(loc, type, floatTy); + // If the type is floating point, build a floating-point attribute. + if (auto floatTy = eltType.dyn_cast()) + return getFloatAttr(loc, type, floatTy); + + // Other types are assumed to be string representations. + return getStringAttr(loc, type, eltType); } /// Build a Dense Integer attribute for the given type. @@ -2163,6 +2167,32 @@ return DenseElementsAttr::get(type, floatValues); } +/// Build a Dense String attribute for the given type. +DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc, + ShapedType type, + Type eltTy) { + if (hexStorage.hasValue()) { + auto stringValue = hexStorage.getValue().getStringValue(); + return DenseStringElementsAttr::get(type, {stringValue}); + } + + std::vector stringValues; + std::vector stringRefValues; + stringValues.reserve(storage.size()); + stringRefValues.reserve(storage.size()); + + for (const auto &val : storage) { + if (val.second.isNot(Token::string)) { + p.emitError(loc) << "expected string literal for element type " << eltTy; + return nullptr; + } + stringValues.push_back(val.second.getStringValue()); + stringRefValues.push_back(stringValues.back()); + } + + return DenseStringElementsAttr::get(type, stringRefValues); +} + /// Build a Dense attribute with hex data for the given type.
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, ShapedType type) { @@ -2214,6 +2244,10 @@ p.consumeToken(); break; + case Token::string: + storage.emplace_back(/*isNegative=*/false, p.getToken()); + p.consumeToken(); + break; default: return p.emitError("expected element literal of primitive type"); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -391,6 +391,40 @@ // ----- //===----------------------------------------------------------------------===// +// Test StringElementsAttr +//===----------------------------------------------------------------------===// + +func @simple_scalar_example() { + "test.string_elements_attr"() { + // CHECK: dense<"example"> + scalar_string_attr = dense<"example"> : tensor<2x!unknown<"">> + } : () -> () + return +} + +// ----- + +func @escape_string_example() { + "test.string_elements_attr"() { + // CHECK: dense<"new\0Aline"> + scalar_string_attr = dense<"new\nline"> : tensor<2x!unknown<"">> + } : () -> () + return +} + +// ----- + +func @simple_array_example() { + "test.string_elements_attr"() { + // CHECK: dense<["example1", "example2"]> + scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!unknown<"">> + } : () -> () + return +} + +// ----- + +//===----------------------------------------------------------------------===// // Test SymbolRefAttr //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -22,10 +22,5 @@ // ----- -// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}} -"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> () - -// ----- - // expected-error@+1 {{elements hex data size is invalid for provided
type}} "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -245,6 +245,12 @@ "$_builder.getI32IntegerAttr($_self)">; } +def StringElementsAttrOp : TEST_Op<"string_elements_attr"> { + let arguments = (ins + StringElementsAttr:$scalar_string_attr + ); +} + //===----------------------------------------------------------------------===// // Test Attribute Constraints //===----------------------------------------------------------------------===//