Skip to content

Re-apply ea9ac3519c13 #7517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/Support/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ template <class InputBytes> std::string encodeBase64(InputBytes const &Bytes) {
return Buffer;
}

/// Decodes the Base64 text in \p Input into raw bytes stored in \p Output.
/// Any previous contents of \p Output are discarded first. Returns an error
/// if the input length is not a multiple of 4 or if a character is not valid
/// at its position; in that case \p Output may hold partially decoded data.
llvm::Error decodeBase64(llvm::StringRef Input, std::vector<char> &Output);

// General-purpose Base64 encoder/decoder.
// TODO update WinCOFFObjectWriter.cpp to use this library.
class Base64 {
Expand Down
83 changes: 83 additions & 0 deletions llvm/lib/Support/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,93 @@
//
//===----------------------------------------------------------------------===//

#define INVALID_BASE64_BYTE 64
#include "llvm/Support/Base64.h"

#include <memory>

// Translates a single Base64 character into its 6-bit value (0 - 63).
// The padding character '=' translates to 0. Any other character yields
// INVALID_BASE64_BYTE so the caller can reject the input.
static char decodeBase64Byte(uint8_t Ch) {
  // 'A'..'Z' occupy values 0 - 25.
  if (Ch >= 'A' && Ch <= 'Z')
    return static_cast<char>(Ch - 'A');
  // 'a'..'z' occupy values 26 - 51.
  if (Ch >= 'a' && Ch <= 'z')
    return static_cast<char>(Ch - 'a' + 26);
  // '0'..'9' occupy values 52 - 61.
  if (Ch >= '0' && Ch <= '9')
    return static_cast<char>(Ch - '0' + 52);
  switch (Ch) {
  case '+':
    return 62;
  case '/':
    return 63;
  case '=':
    // Padding contributes zero bits; the caller validates its position.
    return 0;
  default:
    return INVALID_BASE64_BYTE;
  }
}

/// Decodes the Base64 text in \p Input into raw bytes stored in \p Output.
///
/// \p Output is cleared first. An empty \p Input succeeds and leaves
/// \p Output empty. The input length must be a multiple of 4, every character
/// must belong to the Base64 alphabet, and '=' padding may only appear as the
/// last character or as the last two characters.
///
/// \returns Error::success() on success, otherwise a string error with
/// std::errc::illegal_byte_sequence describing the problem. On failure
/// \p Output may contain the bytes decoded before the offending character.
llvm::Error llvm::decodeBase64(llvm::StringRef Input,
                               std::vector<char> &Output) {
  // Sentinel produced by decodeBase64Byte for characters outside the Base64
  // alphabet. It is 64 because valid decoded values are 0 - 63.
  constexpr char Base64InvalidByte = INVALID_BASE64_BYTE;
  Output.clear();
  const uint64_t InputLength = Input.size();
  if (InputLength == 0)
    return Error::success();
  // Make sure we have a valid input string length which must be a multiple
  // of 4.
  if ((InputLength % 4) != 0)
    return createStringError(std::errc::illegal_byte_sequence,
                             "Base64 encoded strings must be a multiple of 4 "
                             "bytes in length");
  // Each group of 4 input characters produces 3 bytes before padding is
  // stripped, so the final size is known up front.
  Output.reserve((InputLength / 4) * 3);
  const uint64_t FirstValidEqualIdx = InputLength - 2;
  char Hex64Bytes[4];
  for (uint64_t Idx = 0; Idx < InputLength; Idx += 4) {
    for (uint64_t ByteOffset = 0; ByteOffset < 4; ++ByteOffset) {
      const uint64_t ByteIdx = Idx + ByteOffset;
      const char Byte = Input[ByteIdx];
      const char DecodedByte = decodeBase64Byte(Byte);
      bool Illegal = DecodedByte == Base64InvalidByte;
      if (!Illegal && Byte == '=') {
        if (ByteIdx < FirstValidEqualIdx) {
          // We have an '=' in the middle of the string which is invalid, only
          // the last two characters can be '=' characters.
          Illegal = true;
        } else if (ByteIdx == FirstValidEqualIdx &&
                   Input[ByteIdx + 1] != '=') {
          // We have an equal second to last from the end and the last
          // character is not also an equal, so the '=' character is invalid.
          Illegal = true;
        }
      }
      if (Illegal) {
        // Go through uint8_t before widening: plain char may be signed, and a
        // sign-extended negative value would print as 0xffffff.. with %x.
        const unsigned ByteValue =
            static_cast<unsigned>(static_cast<uint8_t>(Byte));
        return createStringError(
            std::errc::illegal_byte_sequence,
            "Invalid Base64 character %#2.2x at index %" PRIu64, ByteValue,
            ByteIdx);
      }
      Hex64Bytes[ByteOffset] = DecodedByte;
    }
    // Now we have 6 bits of 3 bytes in value in each of the Hex64Bytes bytes.
    // Extract the right bytes into the Output buffer.
    Output.push_back((Hex64Bytes[0] << 2) + ((Hex64Bytes[1] >> 4) & 0x03));
    Output.push_back((Hex64Bytes[1] << 4) + ((Hex64Bytes[2] >> 2) & 0x0f));
    Output.push_back((Hex64Bytes[2] << 6) + (Hex64Bytes[3] & 0x3f));
  }
  // If we had valid trailing '=' characters strip the right number of bytes
  // from the end of the output buffer. We already know that the Input length
  // is a multiple of 4 and is not zero, so direct character access is safe.
  if (Input.back() == '=') {
    Output.pop_back();
    if (Input[InputLength - 2] == '=')
      Output.pop_back();
  }
  return Error::success();
}

using namespace llvm;

namespace {
Expand Down
49 changes: 49 additions & 0 deletions llvm/unittests/Support/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ void TestBase64(StringRef Input, StringRef Final) {
EXPECT_EQ(Res, Final);
}

// Decodes Input with decodeBase64 and checks the outcome. When
// ExpectedErrorMessage is empty the decode must succeed and produce exactly
// the bytes of Expected; otherwise the decode must fail with that message and
// Expected is not consulted.
void TestBase64Decode(StringRef Input, StringRef Expected,
                      StringRef ExpectedErrorMessage = {}) {
  std::vector<char> Decoded;
  const bool ExpectFailure = !ExpectedErrorMessage.empty();
  if (ExpectFailure) {
    ASSERT_THAT_ERROR(decodeBase64(Input, Decoded),
                      FailedWithMessage(ExpectedErrorMessage));
    return;
  }
  ASSERT_THAT_ERROR(decodeBase64(Input, Decoded), Succeeded());
  EXPECT_EQ(llvm::ArrayRef<char>(Decoded),
            llvm::ArrayRef<char>(Expected.data(), Expected.size()));
}

// Binary test payload containing NUL bytes and bytes with the high bit set.
// It is not NUL-terminated text, so users must pass its size explicitly.
char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
                             0x00, 0x08, (char)0xff, (char)0xee};

Expand Down Expand Up @@ -53,6 +66,42 @@ TEST(Base64Test, Base64) {
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
}

// Checks decodeBase64 in two ways: round-tripping a set of payloads through
// encodeBase64, and feeding malformed inputs that must be rejected with a
// specific message.
TEST(Base64Test, DecodeBase64) {
  // One of every possible char value, so the round trip covers all bytes.
  std::vector<char> EveryByte;
  for (int Value = INT8_MIN; Value <= INT8_MAX; ++Value)
    EveryByte.push_back(Value);

  const std::vector<llvm::StringRef> Payloads = {
      "",
      "f",
      "fo",
      "foo",
      "foob",
      "fooba",
      "foobar",
      llvm::StringRef(NonPrintableVector, sizeof(NonPrintableVector)),
      llvm::StringRef(LargeVector, sizeof(LargeVector)),
      llvm::StringRef(EveryByte.data(), EveryByte.size()),
  };

  // Encoding is covered by Base64Test::Base64, so it is trusted here to
  // produce the input whose decoding must give back the original payload.
  for (llvm::StringRef Payload : Payloads) {
    const auto Encoded = encodeBase64(Payload);
    TestBase64Decode(Encoded, Payload);
  }

  // Malformed inputs paired with the exact diagnostic expected for each.
  struct ErrorInfo {
    llvm::StringRef Input;
    llvm::StringRef ErrorMessage;
  };
  const ErrorInfo Failures[] = {
      {"f", "Base64 encoded strings must be a multiple of 4 bytes in length"},
      {"=abc", "Invalid Base64 character 0x3d at index 0"},
      {"a=bc", "Invalid Base64 character 0x3d at index 1"},
      {"ab=c", "Invalid Base64 character 0x3d at index 2"},
      {"fun!", "Invalid Base64 character 0x21 at index 3"},
  };
  for (const ErrorInfo &Failure : Failures)
    TestBase64Decode(Failure.Input, "", Failure.ErrorMessage);
}

TEST(Base64Test, RoundTrip) {
using byte = unsigned char;
const byte Arr0[] = {0x1};
Expand Down