diff --git a/llvm/include/llvm/Support/Base64.h b/llvm/include/llvm/Support/Base64.h
--- a/llvm/include/llvm/Support/Base64.h
+++ b/llvm/include/llvm/Support/Base64.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_SUPPORT_BASE64_H
 #define LLVM_SUPPORT_BASE64_H
 
+#include "llvm/Support/Error.h"
 #include <cstdint>
 #include <string>
 
@@ -52,6 +53,99 @@
   return Buffer;
 }
 
+/// Decode the Base64 encoded string \p Input into \p Output.
+///
+/// \p Output is cleared first and then filled with the decoded bytes; it may
+/// be any byte container providing clear(), push_back() and pop_back() (for
+/// example std::vector<char>). Up to two trailing '=' padding characters are
+/// accepted. An input whose length is not a multiple of 4, an '=' anywhere
+/// other than in the trailing padding, or a character outside the Base64
+/// alphabet yields an error with std::errc::illegal_byte_sequence; in that
+/// case \p Output may hold a partially decoded prefix.
+template <class OutputBytes>
+llvm::Error decodeBase64(llvm::StringRef Input, OutputBytes &Output) {
+  // Invalid table value with short name to fit in the table init below. The
+  // invalid value is 64 since valid base64 values are 0 - 63.
+  constexpr char Inv = 64;
+  // Maps an ASCII character to its 6-bit Base64 value. '=' maps to 0 so that
+  // padding decodes cleanly; the bytes it produces are stripped at the end.
+  static const char DecodeTable[] = {
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, 62,  Inv, Inv, Inv, 63,  // ...+.../
+      52,  53,  54,  55,  56,  57,  58,  59,  // 01234567
+      60,  61,  Inv, Inv, Inv, 0,   Inv, Inv, // 89...=..
+      Inv, 0,   1,   2,   3,   4,   5,   6,   // .ABCDEFG
+      7,   8,   9,   10,  11,  12,  13,  14,  // HIJKLMNO
+      15,  16,  17,  18,  19,  20,  21,  22,  // PQRSTUVW
+      23,  24,  25,  Inv, Inv, Inv, Inv, Inv, // XYZ.....
+      Inv, 26,  27,  28,  29,  30,  31,  32,  // .abcdefg
+      33,  34,  35,  36,  37,  38,  39,  40,  // hijklmno
+      41,  42,  43,  44,  45,  46,  47,  48,  // pqrstuvw
+      49,  50,  51                            // xyz.....
+  };
+  auto decodeBase64Byte = [](uint8_t Ch) -> char {
+    // Anything past 'z', including every byte >= 0x80, is invalid.
+    if (Ch >= sizeof(DecodeTable))
+      return Inv;
+    return DecodeTable[Ch];
+  };
+  Output.clear();
+  const size_t BytesLength = Input.size();
+  if (BytesLength == 0)
+    return Error::success();
+  // Make sure we have a valid input string length which must be a multiple
+  // of 4.
+  if ((BytesLength % 4) != 0)
+    return createStringError(std::errc::illegal_byte_sequence,
+                             "Base64 encoded strings must be a multiple of 4 "
+                             "bytes in length");
+  // Check for '=' characters in the string. There can be at most two '='
+  // characters at the end of the string and none in the middle.
+  uint32_t NumBytesToStrip = 0;
+  uint64_t EqualPos = Input.find('=');
+  if (EqualPos != StringRef::npos) {
+    NumBytesToStrip = Input.size() - EqualPos;
+    // If we have more than two bytes to strip, then we had an '=' character
+    // in the middle of the string. If we have two bytes to strip, also verify
+    // that the last character is also '='.
+    if (NumBytesToStrip > 2 ||
+        ((NumBytesToStrip == 2) && Input.back() != '=')) {
+      return createStringError(
+          std::errc::illegal_byte_sequence,
+          "Invalid Base64 character %#2.2x at index %" PRIu64, Input[EqualPos],
+          EqualPos);
+    }
+  }
+  char Hex64Bytes[4];
+  for (uint64_t Idx = 0; Idx < BytesLength; Idx += 4) {
+    for (uint64_t ByteIdx = 0; ByteIdx < 4; ++ByteIdx) {
+      const char Byte = Input[Idx + ByteIdx];
+      Hex64Bytes[ByteIdx] = decodeBase64Byte(Byte);
+      // Report the byte as unsigned: plain char may be signed, and a negative
+      // char passed to %x would print as e.g. 0xffffffc3 instead of 0xc3.
+      if (Hex64Bytes[ByteIdx] == Inv)
+        return createStringError(
+            std::errc::illegal_byte_sequence,
+            "Invalid Base64 character %#2.2x at index %" PRIu64,
+            static_cast<uint8_t>(Byte), Idx + ByteIdx);
+    }
+    // Each Hex64Bytes entry now holds 6 bits; the four entries together hold
+    // 24 bits, which are emitted as three output bytes.
+    Output.push_back((Hex64Bytes[0] << 2) + ((Hex64Bytes[1] >> 4) & 0x03));
+    Output.push_back((Hex64Bytes[1] << 4) + ((Hex64Bytes[2] >> 2) & 0x0f));
+    Output.push_back((Hex64Bytes[2] << 6) + (Hex64Bytes[3] & 0x3f));
+  }
+  // If we had valid trailing '=' characters strip the right number of bytes
+  // from the end of the output buffer.
+  while (NumBytesToStrip-- > 0)
+    Output.pop_back();
+  return Error::success();
+}
+
 } // end namespace llvm
 
 #endif
diff --git a/llvm/unittests/Support/Base64Test.cpp b/llvm/unittests/Support/Base64Test.cpp
--- a/llvm/unittests/Support/Base64Test.cpp
+++ b/llvm/unittests/Support/Base64Test.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/Support/Base64.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Testing/Support/Error.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -24,6 +25,28 @@
   EXPECT_EQ(Res, Final);
 }
 
+void TestBase64Decode(StringRef Input, StringRef Expected,
+                      StringRef ExpectedErrorMessage = {}) {
+  std::vector<char> DecodedBytes;
+  if (ExpectedErrorMessage.empty()) {
+    ASSERT_THAT_ERROR(decodeBase64(Input, DecodedBytes), Succeeded());
+    EXPECT_EQ(llvm::ArrayRef<char>(DecodedBytes),
+              llvm::ArrayRef<char>(Expected.data(), Expected.size()));
+  } else {
+    ASSERT_THAT_ERROR(decodeBase64(Input, DecodedBytes),
+                      FailedWithMessage(ExpectedErrorMessage));
+  }
+}
+
+char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
+                             0x00, 0x08, (char)0xff, (char)0xee};
+
+char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
+                      0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
+                      0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
+                      0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
+                      0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};
+
 } // namespace
 
 TEST(Base64Test, Base64) {
@@ -37,16 +60,45 @@
   TestBase64("foobar", "Zm9vYmFy");
 
   // With non-printable values.
-  char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
-                               0x00, 0x08, (char)0xff, (char)0xee};
   TestBase64({NonPrintableVector, sizeof(NonPrintableVector)}, "AAAARgAI/+4=");
 
   // Large test case
-  char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
-                        0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
-                        0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
-                        0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
-                        0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};
   TestBase64({LargeVector, sizeof(LargeVector)},
              "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
 }
+
+TEST(Base64Test, DecodeBase64) {
+  std::vector<llvm::StringRef> Outputs = {"", "f", "fo", "foo",
+                                          "foob", "fooba", "foobar"};
+  Outputs.push_back(
+      llvm::StringRef(NonPrintableVector, sizeof(NonPrintableVector)));
+
+  Outputs.push_back(llvm::StringRef(LargeVector, sizeof(LargeVector)));
+  // Make sure we can encode and decode any byte.
+  std::vector<char> AllChars;
+  for (int Ch = INT8_MIN; Ch <= INT8_MAX; ++Ch)
+    AllChars.push_back(Ch);
+  Outputs.push_back(llvm::StringRef(AllChars.data(), AllChars.size()));
+
+  for (const auto &Output : Outputs) {
+    // We trust that encoding is working after running the Base64Test::Base64()
+    // test function above, so we can use it to encode the string and verify we
+    // can decode it correctly.
+    auto Input = encodeBase64(Output);
+    TestBase64Decode(Input, Output);
+  }
+  struct ErrorInfo {
+    llvm::StringRef Input;
+    llvm::StringRef ErrorMessage;
+  };
+  std::vector<ErrorInfo> ErrorInfos = {
+      {"f", "Base64 encoded strings must be a multiple of 4 bytes in length"},
+      {"=abc", "Invalid Base64 character 0x3d at index 0"},
+      {"a=bc", "Invalid Base64 character 0x3d at index 1"},
+      {"ab=c", "Invalid Base64 character 0x3d at index 2"},
+      {"fun!", "Invalid Base64 character 0x21 at index 3"},
+  };
+
+  for (const auto &EI : ErrorInfos)
+    TestBase64Decode(EI.Input, "", EI.ErrorMessage);
+}