diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -723,6 +723,13 @@ return get(type, ArrayRef(list)); } + /// Construct a dense elements attribute from a raw buffer representing the + /// data for this attribute. Users should generally not use this method as + /// the expected buffer format may not be a form the user expects. + static DenseElementsAttr getFromRawBuffer(ShapedType type, + ArrayRef rawBuffer, + bool isSplatBuffer); + //===--------------------------------------------------------------------===// // Iterators //===--------------------------------------------------------------------===// @@ -918,6 +925,11 @@ FloatElementIterator float_value_begin() const; FloatElementIterator float_value_end() const; + /// Return the raw storage data held by this attribute. Users should generally + /// not use this directly, as the internal storage format is not always in the + /// form the user might expect. + ArrayRef getRawData() const; + //===--------------------------------------------------------------------===// // Mutation Utilities //===--------------------------------------------------------------------===// @@ -941,9 +953,6 @@ function_ref mapping) const; protected: - /// Return the raw storage data held by this attribute. - ArrayRef getRawData() const; - /// Get iterators to the raw APInt values for each element in this attribute.
IntElementIterator raw_int_begin() const { return IntElementIterator(*this, 0); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -63,6 +63,13 @@ // OpPrintingFlags //===----------------------------------------------------------------------===// +static llvm::cl::opt printElementsAttrWithHexIfLarger( + "mlir-print-elementsattrs-with-hex-if-larger", + llvm::cl::desc( + "Print DenseElementsAttrs with a hex string that have " + "more elements than the given upper limit (use -1 to disable)"), + llvm::cl::init(100)); + static llvm::cl::opt elideElementsAttrIfLarger( "mlir-elide-elementsattrs-if-larger", llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " @@ -887,7 +894,10 @@ bool withKeyword = false); void printTrailingLocation(Location loc); void printLocationInternal(LocationAttr loc, bool pretty = false); - void printDenseElementsAttr(DenseElementsAttr attr); + + /// Print a dense elements attribute. If 'allowHex' is true, a hex string is + /// used instead of individual elements when the elements attr is large. 
+ void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1321,7 +1331,7 @@ break; } os << "dense<"; - printDenseElementsAttr(eltsAttr); + printDenseElementsAttr(eltsAttr, /*allowHex=*/true); os << '>'; break; } @@ -1333,9 +1343,9 @@ break; } os << "sparse<"; - printDenseElementsAttr(elementsAttr.getIndices()); + printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false); os << ", "; - printDenseElementsAttr(elementsAttr.getValues()); + printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); os << '>'; break; } @@ -1375,7 +1385,8 @@ printFloatValue(value, os); } -void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { +void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, + bool allowHex) { auto type = attr.getType(); auto shape = type.getShape(); auto rank = type.getRank(); @@ -1401,6 +1412,15 @@ return; } + // Check to see if we should format this attribute as a hex string. + if (allowHex && printElementsAttrWithHexIfLarger != -1 && + numElements > printElementsAttrWithHexIfLarger) { + ArrayRef rawData = attr.getRawData(); + os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size())) + << "\""; + return; + } + // We use a mixed-radix counter to iterate through the shape. When we bump a // non-least-significant digit, we emit a close bracket. When we next emit an // element we re-open all closed brackets. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -664,9 +664,18 @@ return getRaw(type, intValues); } -// Constructs a dense elements attribute from an array of raw APInt values. -// Each APInt value is expected to have the same bitwidth as the element type -// of 'type'. +/// Construct a dense elements attribute from a raw buffer representing the +/// data for this attribute. 
Users should generally not use this method as +/// the expected buffer format may not be a form the user expects. +DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type, + ArrayRef rawBuffer, + bool isSplatBuffer) { + return getRaw(type, rawBuffer, isSplatBuffer); +} + +/// Constructs a dense elements attribute from an array of raw APInt values. +/// Each APInt value is expected to have the same bitwidth as the element type +/// of 'type'. DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); @@ -727,11 +736,6 @@ return ::isValidIntOrFloat(getType(), dataEltSize, isInt); } -/// Return the raw storage data held by this attribute. -ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; -} - /// Returns if this attribute corresponds to a splat, i.e. if all element /// values are the same. bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } @@ -795,6 +799,11 @@ return getFloatValues().end(); } +/// Return the raw storage data held by this attribute. +ArrayRef DenseElementsAttr::getRawData() const { + return static_cast(impl)->data; +} + /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1808,6 +1808,24 @@ return builder.getIntegerAttr(type, isNegative ? -apInt : apInt); } +/// Parse elements values stored within a hex string. On success, the values are +/// stored into 'result'.
+static ParseResult parseElementAttrHexValues(Parser &parser, Token tok, + std::string &result) { + std::string val = tok.getStringValue(); + if (val.size() < 2 || val[0] != '0' || val[1] != 'x') + return parser.emitError(tok.getLoc(), + "elements hex string should start with '0x'"); + + StringRef hexValues = StringRef(val).drop_front(2); + if (!llvm::all_of(hexValues, llvm::isHexDigit)) + return parser.emitError(tok.getLoc(), + "elements hex string only contains hex digits"); + + result = llvm::fromHex(hexValues); + return success(); +} + /// Parse an opaque elements attribute. Attribute Parser::parseOpaqueElementsAttr(Type attrType) { consumeToken(Token::kw_opaque); @@ -1825,31 +1843,23 @@ if (!dialect) return (emitError("no registered dialect with namespace '" + name + "'"), nullptr); - consumeToken(Token::string); + if (parseToken(Token::comma, "expected ','")) return nullptr; - if (getToken().getKind() != Token::string) - return (emitError("opaque string should start with '0x'"), nullptr); - - auto val = getToken().getStringValue(); - if (val.size() < 2 || val[0] != '0' || val[1] != 'x') - return (emitError("opaque string should start with '0x'"), nullptr); - - val = val.substr(2); - if (!llvm::all_of(val, llvm::isHexDigit)) - return (emitError("opaque string only contains hex digits"), nullptr); - - consumeToken(Token::string); - if (parseToken(Token::greater, "expected '>'")) + Token hexTok = getToken(); + if (parseToken(Token::string, "elements hex string should start with '0x'") || + parseToken(Token::greater, "expected '>'")) return nullptr; - auto type = parseElementsLiteralType(attrType); if (!type) return nullptr; - return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val)); + std::string data; + if (parseElementAttrHexValues(*this, hexTok, data)) + return nullptr; + return OpaqueElementsAttr::get(dialect, type, data); } namespace { @@ -1857,11 +1867,9 @@ public: TensorLiteralParser(Parser &p) : p(p) {} - ParseResult parse() { - if 
(p.getToken().is(Token::l_square)) - return parseList(shape); - return parseElement(); - } + /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser + /// may also parse a tensor literal that is stored as a hex string. + ParseResult parse(bool allowHex); /// Build a dense attribute instance with the parsed elements and the given /// shaped type. @@ -1893,6 +1901,9 @@ DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type, FloatType eltTy); + /// Build a Dense attribute with hex data for the given type. + DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type); + /// Parse a single element, returning failure if it isn't a valid element /// literal. For example: /// parseElement(1) -> Success, 1 @@ -1907,6 +1918,9 @@ /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure ParseResult parseList(SmallVectorImpl &dims); + /// Parse a literal that was printed as a hex string. + ParseResult parseHexElements(); + Parser &p; /// The shape inferred from the parsed elements. @@ -1917,13 +1931,35 @@ /// A flag that indicates the type of elements that have been parsed. Optional knownEltKind; + + /// Storage used when parsing elements that were stored as hex values. + Optional hexStorage; }; } // namespace +/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser +/// may also parse a tensor literal that is stored as a hex string. +ParseResult TensorLiteralParser::parse(bool allowHex) { + // If hex is allowed, check for a string literal. + if (allowHex && p.getToken().is(Token::string)) { + hexStorage = p.getToken(); + p.consumeToken(Token::string); + return success(); + } + // Otherwise, parse a list or an individual element. + if (p.getToken().is(Token::l_square)) + return parseList(shape); + return parseElement(); +} + /// Build a dense attribute instance with the parsed elements and the given /// shaped type.
DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc, ShapedType type) { + // Check to see if we parsed the literal from a hex string. + if (hexStorage.hasValue()) + return getHexAttr(loc, type); + // Check that the parsed storage size has the same number of elements to the // type, or is a known splat. if (!shape.empty() && getShape() != type.getShape()) { @@ -2045,6 +2081,33 @@ return DenseElementsAttr::get(type, floatValues); } +/// Build a Dense attribute with hex data for the given type. +DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc, + ShapedType type) { + Type elementType = type.getElementType(); + if (!elementType.isIntOrFloat()) { + p.emitError(loc) << "expected floating-point or integer element type, got " + << elementType; + return nullptr; + } + + std::string data; + if (parseElementAttrHexValues(p, hexStorage.getValue(), data)) + return nullptr; + + // Check that the size of the hex data corresponds to the size of the type; + // splat-sized buffers are not accepted. + if (static_cast(data.size() * CHAR_BIT) != + (type.getNumElements() * elementType.getIntOrFloatBitWidth())) { + p.emitError(loc) << "elements hex data size is invalid for provided type: " + << type; + return nullptr; + } + + return DenseElementsAttr::getFromRawBuffer( + type, ArrayRef(data.data(), data.size()), /*isSplatBuffer=*/false); +} + ParseResult TensorLiteralParser::parseElement() { switch (p.getToken().getKind()) { // Parse a boolean element. @@ -2125,7 +2188,7 @@ // Parse the literal data. TensorLiteralParser literalParser(*this); - if (literalParser.parse()) + if (literalParser.parse(/*allowHex=*/true)) return nullptr; if (parseToken(Token::greater, "expected '>'")) @@ -2170,19 +2233,20 @@ if (parseToken(Token::less, "Expected '<' after 'sparse'")) return nullptr; - /// Parse indices + /// Parse the indices. We don't allow hex values here as we may need to use + /// the inferred shape.
auto indicesLoc = getToken().getLoc(); TensorLiteralParser indiceParser(*this); - if (indiceParser.parse()) + if (indiceParser.parse(/*allowHex=*/false)) return nullptr; if (parseToken(Token::comma, "expected ','")) return nullptr; - /// Parse values. + /// Parse the values. auto valuesLoc = getToken().getLoc(); TensorLiteralParser valuesParser(*this); - if (valuesParser.parse()) + if (valuesParser.parse(/*allowHex=*/true)) return nullptr; if (parseToken(Token::greater, "expected '>'")) diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dense-elements-hex.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s -verify-diagnostics -split-input-file -mlir-print-elementsattrs-with-hex-if-larger=1 | FileCheck %s --check-prefix=HEX +// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s + +// HEX: dense<"0x00000000000024400000000000001440"> : tensor<2xf64> +"foo.op"() {dense.attr = dense<[10.0, 5.0]> : tensor<2xf64>} : () -> () + +// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64> +"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> () + +// ----- + +// expected-error@+1 {{elements hex string should start with '0x'}} +"foo.op"() {dense.attr = dense<"00000000000024400000000000001440"> : tensor<2xf64>} : () -> () + +// ----- + +// expected-error@+1 {{elements hex string only contains hex digits}} +"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144X"> : tensor<2xf64>} : () -> () + +// ----- + +// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}} +"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> () + +// ----- + +// expected-error@+1 {{elements hex data size is invalid for provided type}} +"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> () diff --git 
a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -703,14 +703,14 @@ func @elementsattr_malformed_opaque1() -> () { ^bb0: - "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{opaque string only contains hex digits}} + "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string only contains hex digits}} } // ----- func @elementsattr_malformed_opaque2() -> () { ^bb0: - "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{opaque string should start with '0x'}} + "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string should start with '0x'}} } // -----