diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -410,22 +410,16 @@
 // TensorLiteralParser
 //===----------------------------------------------------------------------===//
 
-/// Parse elements values stored within a hex etring. On success, the values are
+/// Parse element values stored within a hex string. On success, the values are
 /// stored into 'result'.
 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
                                              std::string &result) {
-  std::string val = tok.getStringValue();
-  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
-    return parser.emitError(tok.getLoc(),
-                            "elements hex string should start with '0x'");
-
-  StringRef hexValues = StringRef(val).drop_front(2);
-  if (!llvm::all_of(hexValues, llvm::isHexDigit))
-    return parser.emitError(tok.getLoc(),
-                            "elements hex string only contains hex digits");
-
-  result = llvm::fromHex(hexValues);
-  return success();
+  if (Optional<std::string> value = tok.getHexStringValue()) {
+    result = std::move(*value);
+    return success();
+  }
+  return parser.emitError(
+      tok.getLoc(), "expected string containing hex digits starting with `0x`");
 }
 
 namespace {
diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
--- a/mlir/lib/Parser/Token.h
+++ b/mlir/lib/Parser/Token.h
@@ -91,6 +91,11 @@
   /// removing the quote characters and unescaping the contents of the string.
   std::string getStringValue() const;
 
+  /// Given a token containing a hex string literal, return the decoded bytes
+  /// or None if the token does not contain a valid hex string. A hex string
+  /// literal is a string starting with `0x` and only containing hex digits.
+  Optional<std::string> getHexStringValue() const;
+
   /// Given a token containing a symbol reference, return the unescaped string
   /// value.
   std::string getSymbolReference() const;
diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
--- a/mlir/lib/Parser/Token.cpp
+++ b/mlir/lib/Parser/Token.cpp
@@ -124,6 +124,21 @@
   return result;
 }
 
+/// Given a token containing a hex string literal, return the decoded bytes or
+/// None if the token does not contain a valid hex string.
+Optional<std::string> Token::getHexStringValue() const {
+  assert(getKind() == string);
+
+  // Get the internal string data, without the quotes.
+  StringRef bytes = getSpelling().drop_front().drop_back();
+
+  // Try to extract the binary data from the hex string.
+  std::string hex;
+  if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex))
+    return llvm::None;
+  return hex;
+}
+
 /// Given a token containing a symbol reference, return the unescaped string
 /// value.
 std::string Token::getSymbolReference() const {
diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -15,12 +15,12 @@
 
 // -----
 
-// expected-error@+1 {{elements hex string should start with '0x'}}
+// expected-error@+1 {{expected string containing hex digits starting with `0x`}}
 "foo.op"() {dense.attr = dense<"00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
 
 // -----
 
-// expected-error@+1 {{elements hex string only contains hex digits}}
+// expected-error@+1 {{expected string containing hex digits starting with `0x`}}
 "foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144X"> : tensor<2xf64>} : () -> ()
 
 // -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -718,14 +718,14 @@
 
 func @elementsattr_malformed_opaque1() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string only contains hex digits}}
+  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
 }
 
 // -----
 
 func @elementsattr_malformed_opaque2() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string should start with '0x'}}
+  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}}
 }
 
 // -----