Skip to content

Commit 8bc1e87

Browse files
Re-apply ea9ac35 (#7517)
ea9ac35 re-adds the llvm::decodeBase64 function, but on our branch it seems to have been lost in the merge 8b549f6. This change adds it back to fix build issues introduced by upstream changes that rely on this function. This should fix #7515.
1 parent d8fd9bc commit 8bc1e87

File tree

3 files changed

+134
-0
lines changed

3 files changed

+134
-0
lines changed

llvm/include/llvm/Support/Base64.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ template <class InputBytes> std::string encodeBase64(InputBytes const &Bytes) {
5757
return Buffer;
5858
}
5959

60+
// Decodes the Base64 text in Input and stores the resulting bytes in Output.
// Output is cleared before decoding starts; an empty Input succeeds and
// leaves Output empty. Returns an error if the length of Input is not a
// multiple of 4, if Input contains a character outside the Base64 alphabet,
// or if '=' appears anywhere other than as the final one or two characters.
llvm::Error decodeBase64(llvm::StringRef Input, std::vector<char> &Output);
61+
6062
// General-purpose Base64 encoder/decoder.
6163
// TODO update WinCOFFObjectWriter.cpp to use this library.
6264
class Base64 {

llvm/lib/Support/Base64.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,93 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
// Sentinel returned by decodeBase64Byte() for bytes that are not part of the
// Base64 alphabet. 64 is used because valid decoded values are 0 - 63.
#define INVALID_BASE64_BYTE 64
910
#include "llvm/Support/Base64.h"
1011

1112
#include <memory>
1213

14+
static char decodeBase64Byte(uint8_t Ch) {
15+
constexpr char Inv = INVALID_BASE64_BYTE;
16+
static const char DecodeTable[] = {
17+
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
18+
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
19+
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
20+
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
21+
Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
22+
Inv, Inv, Inv, 62, Inv, Inv, Inv, 63, // ...+.../
23+
52, 53, 54, 55, 56, 57, 58, 59, // 01234567
24+
60, 61, Inv, Inv, Inv, 0, Inv, Inv, // 89...=..
25+
Inv, 0, 1, 2, 3, 4, 5, 6, // .ABCDEFG
26+
7, 8, 9, 10, 11, 12, 13, 14, // HIJKLMNO
27+
15, 16, 17, 18, 19, 20, 21, 22, // PQRSTUVW
28+
23, 24, 25, Inv, Inv, Inv, Inv, Inv, // XYZ.....
29+
Inv, 26, 27, 28, 29, 30, 31, 32, // .abcdefg
30+
33, 34, 35, 36, 37, 38, 39, 40, // hijklmno
31+
41, 42, 43, 44, 45, 46, 47, 48, // pqrstuvw
32+
49, 50, 51 // xyz.....
33+
};
34+
if (Ch >= sizeof(DecodeTable))
35+
return Inv;
36+
return DecodeTable[Ch];
37+
}
38+
39+
// Decodes the Base64 text in Input and stores the resulting bytes in Output.
//
// Output is always cleared first. An empty Input succeeds and leaves Output
// empty. An error is returned when:
//   - the length of Input is not a multiple of 4,
//   - Input contains a character outside the Base64 alphabet, or
//   - '=' appears anywhere other than as the final one or two characters.
// When an error is returned, Output may still hold the bytes decoded from the
// groups that preceded the offending character.
llvm::Error llvm::decodeBase64(llvm::StringRef Input,
                               std::vector<char> &Output) {
  // Sentinel produced by decodeBase64Byte() for bytes outside the Base64
  // alphabet. Valid decoded values are 0 - 63, so 64 never collides.
  constexpr char Base64InvalidByte = INVALID_BASE64_BYTE;
  Output.clear();
  const uint64_t InputLength = Input.size();
  if (InputLength == 0)
    return Error::success();
  // Make sure we have a valid input string length which must be a multiple
  // of 4.
  if ((InputLength % 4) != 0)
    return createStringError(std::errc::illegal_byte_sequence,
                             "Base64 encoded strings must be a multiple of 4 "
                             "bytes in length");
  // Every group of 4 input characters produces at most 3 output bytes, so the
  // final size is known up front; allocate once instead of growing.
  Output.reserve(Input.size() / 4 * 3);
  // Padding is only legal in the last two positions of the input.
  const uint64_t FirstValidEqualIdx = InputLength - 2;
  char Hex64Bytes[4];
  for (uint64_t Idx = 0; Idx < InputLength; Idx += 4) {
    for (uint64_t ByteOffset = 0; ByteOffset < 4; ++ByteOffset) {
      const uint64_t ByteIdx = Idx + ByteOffset;
      const char Byte = Input[ByteIdx];
      const char DecodedByte = decodeBase64Byte(Byte);
      bool Illegal = DecodedByte == Base64InvalidByte;
      if (!Illegal && Byte == '=') {
        if (ByteIdx < FirstValidEqualIdx) {
          // We have an '=' in the middle of the string which is invalid, only
          // the last two characters can be '=' characters.
          Illegal = true;
        } else if (ByteIdx == FirstValidEqualIdx &&
                   Input[ByteIdx + 1] != '=') {
          // We have an '=' second to last and the last character is not also
          // an '=', so the '=' character is invalid.
          Illegal = true;
        }
      }
      if (Illegal) {
        // Go through unsigned char so that bytes >= 0x80 are not
        // sign-extended (and printed as 0xffffffXX) where char is signed.
        const unsigned ByteValue = static_cast<unsigned char>(Byte);
        return createStringError(
            std::errc::illegal_byte_sequence,
            "Invalid Base64 character %#2.2x at index %" PRIu64, ByteValue,
            ByteIdx);
      }
      Hex64Bytes[ByteOffset] = DecodedByte;
    }
    // Each entry of Hex64Bytes now holds 6 bits of the 3 output bytes.
    // Reassemble them; the shifted and masked parts never overlap.
    Output.push_back(static_cast<char>((Hex64Bytes[0] << 2) |
                                       ((Hex64Bytes[1] >> 4) & 0x03)));
    Output.push_back(static_cast<char>((Hex64Bytes[1] << 4) |
                                       ((Hex64Bytes[2] >> 2) & 0x0f)));
    Output.push_back(
        static_cast<char>((Hex64Bytes[2] << 6) | (Hex64Bytes[3] & 0x3f)));
  }
  // If we had valid trailing '=' characters strip the right number of bytes
  // from the end of the output buffer. We already know that the Input length
  // is a multiple of 4 and is not zero, so direct character access is safe.
  if (Input.back() == '=') {
    Output.pop_back();
    if (Input[InputLength - 2] == '=')
      Output.pop_back();
  }
  return Error::success();
}
95+
1396
using namespace llvm;
1497

1598
namespace {

llvm/unittests/Support/Base64Test.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ void TestBase64(StringRef Input, StringRef Final) {
2424
EXPECT_EQ(Res, Final);
2525
}
2626

27+
// Runs decodeBase64() on Input and verifies the outcome. When
// ExpectedErrorMessage is empty the decode must succeed and yield exactly the
// bytes of Expected; otherwise the decode must fail with that message and
// Expected is not consulted.
void TestBase64Decode(StringRef Input, StringRef Expected,
                      StringRef ExpectedErrorMessage = {}) {
  std::vector<char> Decoded;
  const bool ExpectFailure = !ExpectedErrorMessage.empty();
  if (ExpectFailure) {
    ASSERT_THAT_ERROR(decodeBase64(Input, Decoded),
                      FailedWithMessage(ExpectedErrorMessage));
    return;
  }
  ASSERT_THAT_ERROR(decodeBase64(Input, Decoded), Succeeded());
  const llvm::ArrayRef<char> Actual(Decoded);
  const llvm::ArrayRef<char> Wanted(Expected.data(), Expected.size());
  EXPECT_EQ(Actual, Wanted);
}
39+
2740
// Eight raw bytes containing NUL bytes and values above 0x7f, used as a
// binary (non-text) payload by the tests in this file.
char NonPrintableVector[] = {0x00,
                             0x00,
                             0x00,
                             0x46,
                             0x00,
                             0x08,
                             static_cast<char>(0xff),
                             static_cast<char>(0xee)};
2942

@@ -53,6 +66,42 @@ TEST(Base64Test, Base64) {
5366
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
5467
}
5568

69+
// Checks decodeBase64() in two ways: a set of payloads is encoded with
// encodeBase64() and must decode back to the original bytes, and a set of
// malformed inputs must be rejected with the expected error message.
TEST(Base64Test, DecodeBase64) {
  // Make sure we can encode and decode any byte.
  std::vector<char> EveryByte;
  for (int Value = INT8_MIN; Value <= INT8_MAX; ++Value)
    EveryByte.push_back(Value);

  const std::vector<llvm::StringRef> Payloads = {
      "",
      "f",
      "fo",
      "foo",
      "foob",
      "fooba",
      "foobar",
      llvm::StringRef(NonPrintableVector, sizeof(NonPrintableVector)),
      llvm::StringRef(LargeVector, sizeof(LargeVector)),
      llvm::StringRef(EveryByte.data(), EveryByte.size())};

  // Encoding is covered by the Base64Test::Base64() test above, so it is
  // trusted here to produce the input for the decoder.
  for (llvm::StringRef Payload : Payloads) {
    const std::string Encoded = encodeBase64(Payload);
    TestBase64Decode(Encoded, Payload);
  }

  // Malformed inputs paired with the message the decoder must report.
  struct RejectedCase {
    llvm::StringRef Encoded;
    llvm::StringRef Message;
  };
  const RejectedCase Rejected[] = {
      {"f", "Base64 encoded strings must be a multiple of 4 bytes in length"},
      {"=abc", "Invalid Base64 character 0x3d at index 0"},
      {"a=bc", "Invalid Base64 character 0x3d at index 1"},
      {"ab=c", "Invalid Base64 character 0x3d at index 2"},
      {"fun!", "Invalid Base64 character 0x21 at index 3"},
  };

  for (const RejectedCase &Case : Rejected)
    TestBase64Decode(Case.Encoded, "", Case.Message);
}
104+
56105
TEST(Base64Test, RoundTrip) {
57106
using byte = unsigned char;
58107
const byte Arr0[] = {0x1};

0 commit comments

Comments
 (0)