diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -142,11 +142,11 @@ Unit, /// Elements Attributes. - DenseElements, + DenseIntOrFPElements, DenseStringElements, OpaqueElements, SparseElements, - FIRST_ELEMENTS_ATTR = DenseElements, + FIRST_ELEMENTS_ATTR = DenseIntOrFPElements, LAST_ELEMENTS_ATTR = SparseElements, /// Locations. @@ -679,7 +679,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { - return attr.getKind() == StandardAttributes::DenseElements || + return attr.getKind() == StandardAttributes::DenseIntOrFPElements || attr.getKind() == StandardAttributes::DenseStringElements; } @@ -887,6 +887,11 @@ ElementIterator(rawData, splat, getNumElements())}; } + llvm::iterator_range::iterator> getValues() const { + auto stringRefs = getStringRefs(); + return {stringRefs.begin(), stringRefs.end()}; + } + /// Return the held element values as a range of Attributes. llvm::iterator_range getAttributeValues() const; template getRawData() const; + /// Return the raw StringRef data held by this attribute. + ArrayRef getStringRefs() const; + //===--------------------------------------------------------------------===// // Mutation Utilities //===--------------------------------------------------------------------===// @@ -992,6 +1000,8 @@ bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; }; +/// An attribute class for representing dense arrays of strings. The elements +/// are stored and queried as a densely packed list of StringRefs. class DenseStringElementsAttr : public Attribute::AttrBase { @@ -1004,9 +1014,6 @@ return kind == StandardAttributes::DenseStringElements; } - /// Return the held element values as a range of StringRefs.
- ArrayRef getStringValues() const; - /// Overload of the raw 'get' method that asserts that the given type is of /// integer or floating-point type. This method is used to verify type /// invariants that the templatized 'get' method cannot. @@ -1016,6 +1023,8 @@ friend DenseElementsAttr; }; +/// An attribute class for specializing behavior of dense arrays whose +/// elements are integer or floating-point values. class DenseIntOrFPElementsAttr : public Attribute::AttrBase { @@ -1025,7 +1034,7 @@ /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(unsigned kind) { - return kind == StandardAttributes::DenseElements; + return kind == StandardAttributes::DenseIntOrFPElements; } protected: @@ -1319,7 +1328,7 @@ typename... Args> RetT process(Args &... args) const { switch (attrKind) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return ProcessFn()(args...); case StandardAttributes::SparseElements: return ProcessFn()(args...); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1316,7 +1316,7 @@ << opType << ") does not match value type (" << valueType << ")"; return success(); } break; - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: case StandardAttributes::SparseElements: { if (valueType == opType) break; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1395,7 +1395,7 @@ os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; break; } - case StandardAttributes::DenseElements: { + case StandardAttributes::DenseIntOrFPElements: { auto eltsAttr = attr.cast(); if (printerFlags.shouldElideElementsAttr(eltsAttr)) { printElidedElementsAttr(os); @@ -1554,7 +1554,6 @@ // Special case for 0-d and splat tensors.
if (attr.isSplat()) { - // TODO(suderman): Add a print for a single value. printDenseStringElement(attr, os, 0); return; } diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -586,9 +586,8 @@ // If the data is already known to be a splat, the key hash value is // directly the data buffer. - if (isKnownSplat) { + if (isKnownSplat) return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat); - } // Handle the simple case of only one element. size_t numElements = ty.getNumElements(); @@ -628,10 +627,9 @@ - // Compute the amount data needed to store the ArrayRef and StringRef - // contents. + // Compute the amount of data needed to store the StringRef array and the + // string contents. - size_t dataSize = sizeof(ArrayRef) * numEntries; - for (int i = 0; i < numEntries; i++) { + size_t dataSize = sizeof(StringRef) * numEntries; + for (int i = 0; i < numEntries; i++) dataSize += data[i].size(); - } char *rawData = reinterpret_cast( allocator.allocate(dataSize, alignof(uint64_t))); diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -411,7 +411,7 @@ /// element, then a null attribute is returned.
Attribute ElementsAttr::getValue(ArrayRef index) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().getValue(index); case StandardAttributes::OpaqueElements: return cast().getValue(index); @@ -442,7 +442,7 @@ ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().mapValues(newElementType, mapping); default: llvm_unreachable("unsupported ElementsAttr subtype"); @@ -453,7 +453,7 @@ ElementsAttr::mapValues(Type newElementType, function_ref mapping) const { switch (getKind()) { - case StandardAttributes::DenseElements: + case StandardAttributes::DenseIntOrFPElements: return cast().mapValues(newElementType, mapping); default: llvm_unreachable("unsupported ElementsAttr subtype"); @@ -814,6 +814,12 @@ return static_cast(impl)->data; } +ArrayRef DenseElementsAttr::getStringRefs() const { + assert(getKind() == StandardAttributes::DenseStringElements && + "getStringRefs requires a DenseStringElementsAttr"); + return static_cast(impl)->data; +} + /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. @@ -851,9 +857,6 @@ type, values, (values.size() == 1)); } -ArrayRef DenseStringElementsAttr::getStringValues() const { - return static_cast(impl)->data; -} //===----------------------------------------------------------------------===// // DenseIntOrFPElementsAttr @@ -884,8 +887,8 @@ assert((type.isa() || type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::DenseElements, type, - data, isSplat); + return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, + type, data, isSplat); } /// Overload of the 'getRaw' method that asserts that the given type is of