diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -40,7 +40,8 @@ struct TypeAttributeStorage; /// Elements Attributes. -struct DenseElementsAttributeStorage; +struct DenseIntOrFPElementsAttributeStorage; +struct DenseStringElementsAttributeStorage; struct OpaqueElementsAttributeStorage; struct SparseElementsAttributeStorage; } // namespace detail @@ -142,6 +143,7 @@ /// Elements Attributes. DenseElements, + DenseStringElements, OpaqueElements, SparseElements, FIRST_ELEMENTS_ATTR = DenseElements, @@ -671,15 +673,14 @@ /// An attribute that represents a reference to a dense vector or tensor object. /// -class DenseElementsAttr - : public Attribute::AttrBase { +class DenseElementsAttr : public ElementsAttr { public: - using Base::Base; + using ElementsAttr::ElementsAttr; /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { - return attr.getKind() == StandardAttributes::DenseElements; + return attr.getKind() == StandardAttributes::DenseElements || + attr.getKind() == StandardAttributes::DenseStringElements; } /// Constructs a dense elements attribute from an array of element values. @@ -712,6 +713,10 @@ /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Overload of the above 'get' method that is specialized for StringRef + /// values. The element type of 'type' must not be an integer or float. + static DenseElementsAttr get(ShapedType type, ArrayRef values); + /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static @@ -973,6 +978,59 @@ return IntElementIterator(*this, getNumElements()); } + /// Overload of the raw 'get' method that asserts that the given type is of + /// integer or floating-point type.
This method is used to verify type + /// invariants that the templatized 'get' method cannot. + static DenseElementsAttr getRawIntOrFloat(ShapedType type, + ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + /// Check the information for a C++ data type, check if this type is valid for + /// the current attribute. This method is used to verify specific type + /// invariants that the templatized 'getValues' method cannot. + bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; +}; + +class DenseStringElementsAttr + : public Attribute::AttrBase { + +public: + using Base::Base; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::DenseStringElements; + } + + /// Return the held element values as a range of StringRefs. + ArrayRef getStringValues() const; + + /// Constructs a dense string elements attribute of the given shaped type + /// from 'data'. The strings are copied into the attribute storage, so + /// 'data' only needs to remain valid for the duration of this call. + static DenseStringElementsAttr get(ShapedType type, ArrayRef data); + +protected: + friend DenseElementsAttr; +}; + +class DenseIntOrFPElementsAttr + : public Attribute::AttrBase { + +public: + using Base::Base; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardAttributes::DenseElements; + } + +protected: + friend DenseElementsAttr; + /// Constructs a dense elements attribute from an array of raw APInt values. + /// Each APInt value is expected to have the same bitwidth as the element type + /// of 'type'. 'type' must be a vector or tensor with static shape. @@ -990,20 +1048,15 @@ ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned); - - /// Check the information for a C++ data type, check if this type is valid for - /// the current attribute.
This method is used to verify specific type - /// invariants that the templatized 'getValues' method cannot. - bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; }; /// An attribute that represents a reference to a dense float vector or tensor /// object. Each element is stored as a double. -class DenseFPElementsAttr : public DenseElementsAttr { +class DenseFPElementsAttr : public DenseIntOrFPElementsAttr { public: using iterator = DenseElementsAttr::FloatElementIterator; - using DenseElementsAttr::DenseElementsAttr; + using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; /// Get an instance of a DenseFPElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. @@ -1035,13 +1088,13 @@ /// An attribute that represents a reference to a dense integer vector or tensor /// object. -class DenseIntElementsAttr : public DenseElementsAttr { +class DenseIntElementsAttr : public DenseIntOrFPElementsAttr { public: /// DenseIntElementsAttr iterates on APInt, so we can use the raw element /// iterator directly. using iterator = DenseElementsAttr::IntElementIterator; - using DenseElementsAttr::DenseElementsAttr; + using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr; /// Get an instance of a DenseIntElementsAttr with the given arguments. This /// simply wraps the DenseElementsAttr::get calls. diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -25,6 +25,7 @@ DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect +DEFINE_SYM_KIND_RANGE(TEST) // Test dialect // The following ranges are reserved for experimenting with MLIR dialects in a // private context without having to register them here.
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1307,6 +1307,16 @@ class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; +def StringElementsAttr : ElementsAttrBase< + CPred<"$_self.isa()" >, + "string elements attribute"> { + + let storageType = [{ DenseElementsAttr }]; + let returnType = [{ DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + // Base class for array attributes. class ArrayAttrBase : Attr { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -973,6 +973,9 @@ /// used instead of individual elements when the elements attr is large. void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); + /// Print the quoted string elements of a dense string elements attribute. + void printDenseStringElementsAttr(DenseStringElementsAttr attr); + void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1403,6 +1406,17 @@ os << '>'; break; } + case StandardAttributes::DenseStringElements: { + auto eltsAttr = attr.cast(); + if (printerFlags.shouldElideElementsAttr(eltsAttr)) { + printElidedElementsAttr(os); + break; + } + os << "dense<"; + printDenseStringElementsAttr(eltsAttr); + os << '>'; + break; + } case StandardAttributes::SparseElements: { auto elementsAttr = attr.cast(); if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) || @@ -1454,6 +1468,13 @@ printFloatValue(value, os); } +static void printDenseStringElement(DenseStringElementsAttr attr, + raw_ostream &os, unsigned index) { + os << "\""; + printEscapedString(attr.getStringValues()[index], os); + os << "\""; +} + void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { auto type = attr.getType(); @@ -1526,6 +1547,64 @@ os << ']'; } +void
ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) { + auto type = attr.getType(); + auto shape = type.getShape(); + auto rank = type.getRank(); + + // Special case for 0-d and splat tensors: only a single value is printed. + if (attr.isSplat()) { + // TODO(suderman): Add a print for a single value. + printDenseStringElement(attr, os, 0); + return; + } + + // Special case for degenerate (zero-element) tensors: print only brackets. + auto numElements = type.getNumElements(); + if (numElements == 0) { + for (int i = 0; i < rank; ++i) + os << '['; + for (int i = 0; i < rank; ++i) + os << ']'; + return; + } + + // We use a mixed-radix counter to iterate through the shape. When we bump a + // non-least-significant digit, we emit a close bracket. When we next emit an + // element we re-open all closed brackets. + + // The mixed-radix counter, with radices in 'shape'. + SmallVector counter(rank, 0); + // The number of brackets that have been opened and not closed. + unsigned openBrackets = 0; + + auto bumpCounter = [&]() { + // Bump the least significant digit. + ++counter[rank - 1]; + // Iterate backwards bubbling back the increment. + for (unsigned i = rank - 1; i > 0; --i) + if (counter[i] >= shape[i]) { + // Index 'i' is rolled over. Bump (i-1) and close a bracket.
+ counter[i] = 0; + ++counter[i - 1]; + --openBrackets; + os << ']'; + } + }; + + for (unsigned idx = 0, e = numElements; idx != e; ++idx) { + if (idx != 0) + os << ", "; + while (openBrackets++ < rank) + os << '['; + openBrackets = rank; + printDenseStringElement(attr, os, idx); + bumpCounter(); + } + while (openBrackets-- > 0) + os << ']'; +} + void ModulePrinter::printType(Type type) { if (!type) { os << "<>"; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -372,8 +372,21 @@ // Elements Attributes //===----------------------------------------------------------------------===// -/// An attribute representing a reference to a dense vector or tensor object. struct DenseElementsAttributeStorage : public AttributeStorage { +public: + DenseElementsAttributeStorage(ShapedType ty, bool isSplat) + : AttributeStorage(ty), isSplat(isSplat) {} + + bool isSplat; +}; + +/// Attribute storage for a dense vector or tensor of integer or float data. +struct DenseIntOrFPElementsAttributeStorage + : public DenseElementsAttributeStorage { + DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) + : DenseElementsAttributeStorage(ty, isSplat), data(data) {} + struct KeyTy { KeyTy(ShapedType type, ArrayRef data, llvm::hash_code hashCode, bool isSplat = false) @@ -392,10 +405,6 @@ bool isSplat; }; - DenseElementsAttributeStorage(ShapedType ty, ArrayRef data, - bool isSplat = false) - : AttributeStorage(ty), data(data), isSplat(isSplat) {} - /// Compare this storage instance with the provided key. bool operator==(const KeyTy &key) const { if (key.type != getType()) @@ -506,7 +515,7 @@ } /// Construct a new storage instance.
- static DenseElementsAttributeStorage * + static DenseIntOrFPElementsAttributeStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { // If the data buffer is non-empty, we copy it into the allocator with a // 64-bit alignment. @@ -522,12 +531,127 @@ copy = ArrayRef(rawData, data.size()); } - return new (allocator.allocate()) - DenseElementsAttributeStorage(key.type, copy, key.isSplat); + return new (allocator.allocate()) + DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat); } ArrayRef data; - bool isSplat; +}; + +/// An attribute representing a reference to a dense vector or tensor object +/// containing strings. +struct DenseStringElementsAttributeStorage + : public DenseElementsAttributeStorage { + DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) + : DenseElementsAttributeStorage(ty, isSplat), data(data) {} + + struct KeyTy { + KeyTy(ShapedType type, ArrayRef data, llvm::hash_code hashCode, + bool isSplat = false) + : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {} + + /// The type of the dense elements. + ShapedType type; + + /// The raw buffer for the data storage. + ArrayRef data; + + /// The computed hash code for the storage data. + llvm::hash_code hashCode; + + /// A boolean that indicates if this data is a splat or not. + bool isSplat; + }; + + /// Compare this storage instance with the provided key. + bool operator==(const KeyTy &key) const { + if (key.type != getType()) + return false; + + // Otherwise, we can default to just checking the data. StringRefs compare + // by contents. + return key.data == data; + } + + /// Construct a key from a shaped type, StringRef data buffer, and a flag that + /// signals if the data is already known to be a splat. Callers to this + /// function are expected to tag preknown splat values when possible, e.g. one + /// element shapes.
+ static KeyTy getKey(ShapedType ty, ArrayRef data, + bool isKnownSplat) { + // Handle an empty storage instance. + if (data.empty()) + return KeyTy(ty, data, 0); + + // If the data is already known to be a splat, hash just its first element + // so that the key matches the one computed for a detected splat below. + if (isKnownSplat) + return KeyTy(ty, data.take_front(), llvm::hash_value(data.front()), + /*isSplat=*/true); + + // A single-element shape is always a splat, so callers are expected to + // have tagged it via 'isKnownSplat'. + assert(ty.getNumElements() != 1 && "splat of 1 element should already be detected"); + + // Create the initial hash value with just the first element. + const auto &firstElt = data.front(); + auto hashVal = llvm::hash_value(firstElt); + + // Check to see if this storage represents a splat. If it doesn't then + // combine the hash for the data starting with the first non splat element. + for (size_t i = 1, e = data.size(); i != e; i++) + if (!firstElt.equals(data[i])) + return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i))); + + // Otherwise, this is a splat; keep a view of only the first element. + return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true); + } + + /// Hash the key for the storage. + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_combine(key.type, key.hashCode); + } + + /// Construct a new storage instance. + static DenseStringElementsAttributeStorage * + construct(AttributeStorageAllocator &allocator, KeyTy key) { + // If the data buffer is non-empty, we copy it into the allocator with an + // alignment suitable for the StringRef array stored at its start. + ArrayRef copy, data = key.data; + if (!data.empty()) { + int numEntries = key.isSplat ? 1 : data.size(); + + // Compute the space needed for the StringRef array followed by the + // character data of every entry.
+ size_t dataSize = sizeof(StringRef) * numEntries; + for (int i = 0; i < numEntries; i++) { + dataSize += data[i].size(); + } + + char *rawData = reinterpret_cast( + allocator.allocate(dataSize, alignof(StringRef))); + + // Set up the StringRef array at the start of the buffer. + auto mutableCopy = MutableArrayRef( + reinterpret_cast(rawData), numEntries); + auto stringData = rawData + numEntries * sizeof(StringRef); + + for (int i = 0; i < numEntries; i++) { + std::copy(data[i].begin(), data[i].end(), stringData); + mutableCopy[i] = StringRef(stringData, data[i].size()); + stringData += data[i].size(); + } + + copy = ArrayRef(reinterpret_cast(rawData), + numEntries); + } + + return new (allocator.allocate()) + DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + } + + ArrayRef data; }; /// An attribute representing a reference to a tensor constant with opaque diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -648,7 +648,8 @@ "expected value to have same bitwidth as element type"); writeBits(data.data(), i * storageBitWidth, intVal); } - return getRaw(type, data, /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, data, + /*isSplat=*/(values.size() == 1)); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -659,7 +660,14 @@ std::vector buff(llvm::divideCeil(values.size(), CHAR_BIT)); for (int i = 0, e = values.size(); i != e; ++i) setBit(buff.data(), i, values[i]); - return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); + return DenseIntOrFPElementsAttr::getRaw(type, buff, + /*isSplat=*/(values.size() == 1)); +} + +DenseElementsAttr DenseElementsAttr::get(ShapedType type, + ArrayRef values) { + assert(!type.getElementType().isIntOrFloat()); + return DenseStringElementsAttr::get(type, values); } /// Constructs a dense integer elements attribute from an array of APInt @@ -668,7 +676,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef
values) { assert(type.getElementType().isa()); - return getRaw(type, values); + return DenseIntOrFPElementsAttr::getRaw(type, values); } // Constructs a dense float elements attribute from an array of APFloat @@ -682,7 +690,7 @@ std::vector intValues(values.size()); for (unsigned i = 0, e = values.size(); i != e; ++i) intValues[i] = values[i].bitcastToAPInt(); - return getRaw(type, intValues); + return DenseIntOrFPElementsAttr::getRaw(type, intValues); } /// Construct a dense elements attribute from a raw buffer representing the @@ -691,35 +699,9 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef rawBuffer, bool isSplatBuffer) { - return getRaw(type, rawBuffer, isSplatBuffer); + return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer); } -/// Constructs a dense elements attribute from an array of raw APInt values. -/// Each APInt value is expected to have the same bitwidth as the element type -/// of 'type'. -DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, - ArrayRef values) { - assert(hasSameElementsOrSplat(type, values)); - - size_t bitWidth = getDenseElementBitwidth(type.getElementType()); - size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - std::vector elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * - values.size()); - for (unsigned i = 0, e = values.size(); i != e; ++i) { - assert(values[i].getBitWidth() == bitWidth); - writeBits(elementData.data(), i * storageBitWidth, values[i]); - } - return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); -} - -DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, - ArrayRef data, bool isSplat) { - assert((type.isa() || type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::DenseElements, type, - data, isSplat); -} /// Check the information for a C++ data type, check if this type is valid for ///
the current attribute. This method is used to verify specific type @@ -745,19 +727,14 @@ return intType.isSigned() ? isSigned : !isSigned; } -/// Overload of the 'getRaw' method that asserts that the given type is of -/// integer type. This method is used to verify type invariants that the -/// templatized 'get' method cannot. +/// Forwards to DenseIntOrFPElementsAttr::getRawIntOrFloat. DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { - assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); - - int64_t numElements = data.size() / dataEltSize; - assert(numElements == 1 || numElements == type.getNumElements()); - return getRaw(type, data, /*isSplat=*/numElements == 1); + return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, + isInt, isSigned); } /// A method used to verify specific type invariants that the templatized 'get' @@ -769,7 +746,9 @@ /// Returns if this attribute corresponds to a splat, i.e. if all element /// values are the same. -bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } +bool DenseElementsAttr::isSplat() const { + return static_cast(impl)->isSplat; +} /// Return the held element values as a range of Attributes. auto DenseElementsAttr::getAttributeValues() const @@ -832,7 +811,7 @@ /// Return the raw storage data held by this attribute.
ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; + return static_cast(impl)->data; } /// Return a new DenseElementsAttr that has the same data as the current @@ -848,7 +827,7 @@ "expected the same element type"); assert(newType.getNumElements() == curType.getNumElements() && "expected the same number of elements"); - return getRaw(newType, getRawData(), isSplat()); + return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat()); } DenseElementsAttr @@ -862,6 +841,73 @@ return cast().mapValues(newElementType, mapping); } +//===----------------------------------------------------------------------===// +// DenseStringElementsAttr +//===----------------------------------------------------------------------===// + +DenseStringElementsAttr +DenseStringElementsAttr::get(ShapedType type, ArrayRef values) { + assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && + "type must be ranked tensor or vector"); + assert(type.hasStaticShape() && "type must have static shape"); + assert((values.size() == 1 || + type.getNumElements() == static_cast<int64_t>(values.size())) && + "expected a single splat value or one value per element"); + return Base::get(type.getContext(), StandardAttributes::DenseStringElements, + type, values, (values.size() == 1)); +} + +ArrayRef DenseStringElementsAttr::getStringValues() const { + return static_cast(impl)->data; +} + +//===----------------------------------------------------------------------===// +// DenseIntOrFPElementsAttr +//===----------------------------------------------------------------------===// + +/// Constructs a dense elements attribute from an array of raw APInt values. +/// Each APInt value is expected to have the same bitwidth as the element type +/// of 'type'.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, + ArrayRef values) { + assert(hasSameElementsOrSplat(type, values)); + + size_t bitWidth = getDenseElementBitwidth(type.getElementType()); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); + std::vector elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * + values.size()); + for (unsigned i = 0, e = values.size(); i != e; ++i) { + assert(values[i].getBitWidth() == bitWidth); + writeBits(elementData.data(), i * storageBitWidth, values[i]); + } + return DenseIntOrFPElementsAttr::getRaw(type, elementData, + /*isSplat=*/(values.size() == 1)); +} + +DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, + ArrayRef data, + bool isSplat) { + assert((type.isa() || type.isa()) && + "type must be ranked tensor or vector"); + assert(type.hasStaticShape() && "type must have static shape"); + return Base::get(type.getContext(), StandardAttributes::DenseElements, type, + data, isSplat); +} + +/// Overload of the 'getRaw' method that asserts that the given type is of +/// integer or floating-point type. This method is used to verify type +/// invariants that the templatized 'get' method cannot. +DenseElementsAttr +DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned) { + assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); + + int64_t numElements = data.size() / dataEltSize; + assert(numElements == 1 || numElements == type.getNumElements()); + return getRaw(type, data, /*isSplat=*/numElements == 1); +} + //===----------------------------------------------------------------------===// // DenseFPElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -82,10 +82,11 @@ /// the IR.
struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) { - addAttributes(); + addAttributes(); addAttributes(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1952,7 +1952,7 @@ ArrayRef getShape() const { return shape; } private: - enum class ElementKind { Boolean, Integer, Float }; + enum class ElementKind { Boolean, Integer, Float, String }; /// Return a string to represent the given element kind. const char *getElementKindStr(ElementKind kind) { @@ -1963,6 +1963,8 @@ return "'integer'"; case ElementKind::Float: return "'float'"; + case ElementKind::String: + return "'string'"; } llvm_unreachable("unknown element kind"); } @@ -1975,6 +1977,9 @@ DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy); + /// Build a Dense String attribute for the given type. + DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy); + /// Build a Dense attribute with hex data for the given type. DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); @@ -2030,8 +2035,10 @@ /// shaped type. DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, ShapedType type) { - // Check to see if we parsed the literal from a hex string. - if (hexStorage.hasValue()) + Type eltType = type.getElementType(); + + // Decode a parsed hex string as raw data only for int or float elements. + if (hexStorage.hasValue() && eltType.isIntOrFloat()) return getHexAttr(loc, type); // Check that the parsed storage size has the same number of elements to the @@ -2048,13 +2055,12 @@ return getIntAttr(loc, type, intTy); // Otherwise, this must be a floating point type.
- auto floatTy = type.getElementType().dyn_cast(); - if (!floatTy) { - p.emitError(loc) << "expected floating-point or integer element type, got " - << type.getElementType(); - return nullptr; + if (auto floatTy = type.getElementType().dyn_cast()) { + return getFloatAttr(loc, type, floatTy); } - return getFloatAttr(loc, type, floatTy); + + // Any other element type is built from string literal elements. + return getStringAttr(loc, type, type.getElementType()); } /// Build a Dense Integer attribute for the given type. @@ -2160,6 +2166,36 @@ return DenseElementsAttr::get(type, floatValues); } +/// Build a Dense String attribute for the given type. +DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc, + ShapedType type, + Type eltTy) { + if (hexStorage.hasValue()) { + auto stringValue = hexStorage.getValue().getStringValue(); + return DenseStringElementsAttr::get(type, {stringValue}); + } + + std::vector stringValues; + std::vector stringRefValues; + stringValues.reserve(storage.size()); + stringRefValues.reserve(storage.size()); + + for (const auto &val : storage) { + // Only string literals can initialize elements of a non-int/float type. + if (val.second.isNot(Token::string)) { + p.emitError(loc) << "expected string literal elements for element type " + << eltTy; + return nullptr; + } + // 'stringValues' was reserved up front, so this push_back never + // reallocates and the StringRefs taken into it stay valid. + stringValues.push_back(val.second.getStringValue()); + stringRefValues.push_back(stringValues.back()); + } + + return DenseStringElementsAttr::get(type, stringRefValues); +} + /// Build a Dense attribute with hex data for the given type. DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, ShapedType type) { @@ -2211,6 +2247,10 @@ p.consumeToken(); break; + case Token::string: + storage.emplace_back(/*isNegative*/ false, p.getToken()); + p.consumeToken(); + break; default: return p.emitError("expected element literal of primitive type"); }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -390,6 +390,40 @@ // ----- +//===----------------------------------------------------------------------===// +// Test StringElementsAttr +//===----------------------------------------------------------------------===// + +func @simple_scalar_example() { + "test.string_elements_attr"() { + // CHECK: dense<"example"> + scalar_string_attr = dense<"example"> : tensor<2x!test.custom_type> + } : () -> () + return +} + +// ----- + +func @escape_string_example() { + "test.string_elements_attr"() { + // CHECK: dense<"new\0Aline"> + scalar_string_attr = dense<"new\nline"> : tensor<2x!test.custom_type> + } : () -> () + return +} + +// ----- + +func @simple_array_example() { + "test.string_elements_attr"() { + // CHECK: dense<["example1", "example2"]> + scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!test.custom_type> + } : () -> () + return +} + +// ----- + //===----------------------------------------------------------------------===// // Test SymbolRefAttr //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir --- a/mlir/test/IR/dense-elements-hex.mlir +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -22,10 +22,10 @@ // ----- -// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}} -"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> () +// expected-error@+1 {{expected string literal elements for element type '!unknown<"">'}} +"foo.op"() {dense.attr = dense<[1, 2]> : tensor<2x!unknown<"">>} : () -> () // ----- // expected-error@+1 {{elements hex data size is invalid for provided type}} "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> :
tensor<4xf64>} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -30,8 +30,38 @@ namespace mlir { +namespace TestTypes { +enum Kind { + FIRST_USED_TEST_TYPE = Type::FIRST_TEST_TYPE, + CustomTestType, + LAST_USED_TEST_TYPE +}; +} // namespace TestTypes + #include "TestOpsDialect.h.inc" +class TestType : public Type { +public: + using Type::Type; + + static bool classof(Type type) { + return type.getKind() >= TestTypes::FIRST_USED_TEST_TYPE && + type.getKind() <= TestTypes::LAST_USED_TEST_TYPE; + } +}; + +class CustomTestType : public Type::TypeBase { +public: + using Base::Base; + static CustomTestType get(MLIRContext *context) { + return Base::get(context, TestTypes::CustomTestType); + } + + static bool kindof(unsigned kind) { + return kind == TestTypes::CustomTestType; + } +}; + #define GET_OP_CLASSES #include "TestOps.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" @@ -129,6 +130,8 @@ TestDialect::TestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { + addTypes(); + addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" @@ -163,6 +166,27 @@ return success(); } +Type TestDialect::parseType(DialectAsmParser &parser) const { + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); + llvm::StringRef spec = parser.getFullSymbolSpec(); + if (spec == "custom_type") { + return CustomTestType::get(getContext()); + } + emitError(loc, "unknown TestDialect type: ") << spec; + + return Type(); +} + +void
TestDialect::printType(Type type, DialectAsmPrinter &os) const { + switch (type.getKind()) { + case TestTypes::CustomTestType: + os << "custom_type"; + break; + default: + llvm_unreachable("unhandled test dialect type"); + } +} + //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -32,6 +32,14 @@ // Test Types //===----------------------------------------------------------------------===// +def Test_CustomTestType : DialectType< + Test_Dialect, + CPred<"$_self.isa()">, "custom_type"> { + let typeDescription = [{ + Custom type example. + }]; +} + def IntTypesOp : TEST_Op<"int_types"> { let results = (outs AnyI16:$any_i16, @@ -245,6 +253,12 @@ "$_builder.getI32IntegerAttr($_self)">; } +def StringElementsAttrOp : TEST_Op<"string_elements_attr"> { + let arguments = (ins + StringElementsAttr:$scalar_string_attr + ); +} + //===----------------------------------------------------------------------===// // Test Attribute Constraints //===----------------------------------------------------------------------===//