diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -577,6 +577,9 @@ /// Parse a quoted string token if present. virtual ParseResult parseOptionalString(std::string *string) = 0; + /// Parse a quoted Base64 string, decoding it into `bytes` if non-null. + virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; + /// Parse a `(` token. virtual ParseResult parseLParen() = 0; diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -13,6 +13,7 @@ #include "mlir/AsmParser/AsmParserState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/Support/Base64.h" namespace mlir { namespace detail { @@ -245,6 +246,28 @@ return success(); } + /// Parse a quoted Base64 string, decoding it into `bytes` if non-null. + ParseResult parseBase64Bytes(std::vector<char> *bytes) override { + auto loc = getCurrentLocation(); + if (!parser.getToken().is(Token::string)) + return emitError(loc, "expected string"); + + if (bytes) { + // decodeBase64 doesn't modify its input so we can use the token spelling + // and just slice off the quotes/whitespaces if there are any. Whitespace + // and quotes cannot appear as part of a (standard) base64 encoded string, + // so this is safe to do. + StringRef b64QuotedString = parser.getTokenSpelling(); + StringRef b64String = + b64QuotedString.ltrim("\" \t\n\v\f\r").rtrim("\" \t\n\v\f\r"); + if (auto err = llvm::decodeBase64(b64String, *bytes)) + return emitError(loc, toString(std::move(err))); + } + + parser.consumeToken(); + return success(); + } + /// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override { bool isNegative = parser.consumeIf(Token::minus); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1185,6 +1185,20 @@ return } +// CHECK-LABEL: func @parse_base64_test +func.func @parse_base64_test() { + // CHECK: test.parse_b64 "hello world" + test.parse_b64 "aGVsbG8gd29ybGQ=" + return +} + +// CHECK-LABEL: func @parse_base64_empty_test +func.func @parse_base64_empty_test() { + // CHECK: test.parse_b64 "" + test.parse_b64 "" + return +} + // CHECK-LABEL: func @"\22_string_symbol_reference\22" func.func @"\"_string_symbol_reference\""() { // CHECK: ref = @"\22_string_symbol_reference\22" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -862,6 +862,23 @@ void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } +ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, + OperationState &result) { + std::vector<char> bytes; + if (parser.parseBase64Bytes(&bytes)) + return failure(); + // `bytes` is empty for an empty input string; use data() because front() + // must not be called on an empty vector. + result.addAttribute("b64", parser.getBuilder().getStringAttr( + StringRef(bytes.data(), bytes.size()))); + return success(); +} + +void ParseB64BytesOp::print(OpAsmPrinter &p) { + // Don't print the base64 version to check that we decoded it correctly. + p << " \"" << getB64() << "\""; +} + //===----------------------------------------------------------------------===// // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1766,6 +1766,11 @@ let hasCustomAssemblyFormat = 1; } +def ParseB64BytesOp : TEST_Op<"parse_b64"> { + let arguments = (ins StrAttr:$b64); + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Test region argument list parsing.