// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include namespace zen { /** The table used to encode a 6 bit value as an ascii character */ static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; /** The table used to convert an ascii character into a 6 bit value */ static const uint8_t DecodingAlphabet[256] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, // 0x20-0x2f 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x30-0x3f 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, // 0x40-0x4f 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x50-0x5f 0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, // 0x60-0x6f 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x70-0x7f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x80-0x8f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x90-0x9f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xa0-0xaf 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xb0-0xbf 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xc0-0xcf 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xd0-0xdf 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff }; template uint32_t Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest) { CharType* EncodedBytes = Dest; // Loop through the buffer converting 3 bytes of binary data at a time while (Length >= 3) { uint8_t A = *Source++; uint8_t B = *Source++; uint8_t C = *Source++; Length -= 3; // The algorithm takes 24 bits of data (3 bytes) and breaks it into 4 6bit chunks represented as ascii uint32_t ByteTriplet = A << 16 | B << 8 | C; // Use the 6bit block to find the representation ascii character for it EncodedBytes[3] = EncodingAlphabet[ByteTriplet & 0x3F]; ByteTriplet >>= 6; EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F]; ByteTriplet >>= 6; EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F]; ByteTriplet >>= 6; EncodedBytes[0] = EncodingAlphabet[ByteTriplet & 0x3F]; // Now we can append this buffer to our destination string EncodedBytes += 4; } // Since this algorithm operates on blocks, we may need to pad the last chunks if (Length > 0) { uint8_t A = *Source++; uint8_t B = 0; uint8_t C = 0; // Grab the second character if it is a 2 uint8_t finish if (Length == 2) { B = *Source; } uint32_t ByteTriplet = A << 16 | B << 8 | C; // Pad with = to make a 4 uint8_t chunk EncodedBytes[3] = '='; ByteTriplet >>= 6; // If there's only one 1 uint8_t left in the source, then you need 2 pad chars if (Length == 1) { EncodedBytes[2] = '='; } else { EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F]; } // Now encode the remaining bits the same way ByteTriplet >>= 6; EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F]; ByteTriplet >>= 6; EncodedBytes[0] = 
EncodingAlphabet[ByteTriplet & 0x3F]; EncodedBytes += 4; } // Add a null terminator *EncodedBytes = 0; return uint32_t(EncodedBytes - Dest); } template uint32_t Base64::Encode(const uint8_t* Source, uint32_t Length, char* Dest); template uint32_t Base64::Encode(const uint8_t* Source, uint32_t Length, wchar_t* Dest); template bool Base64::Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength) { // Length must be a multiple of 4 if (Length % 4 != 0) { OutLength = 0; return false; } uint8_t* DecodedBytes = Dest; // Process 4 encoded characters at a time, producing 3 decoded bytes while (Length > 0) { // Count padding characters at the end uint32_t PadCount = 0; if (Source[3] == '=') { PadCount++; if (Source[2] == '=') { PadCount++; } } // Look up each character in the decoding table uint8_t A = DecodingAlphabet[static_cast(Source[0])]; uint8_t B = DecodingAlphabet[static_cast(Source[1])]; uint8_t C = (PadCount >= 2) ? 0 : DecodingAlphabet[static_cast(Source[2])]; uint8_t D = (PadCount >= 1) ? 0 : DecodingAlphabet[static_cast(Source[3])]; // Check for invalid characters (0xFF means not in the base64 alphabet) if (A == 0xFF || B == 0xFF || C == 0xFF || D == 0xFF) { OutLength = 0; return false; } // Reconstruct the 24-bit value from 4 6-bit chunks uint32_t ByteTriplet = (A << 18) | (B << 12) | (C << 6) | D; // Extract the 3 bytes *DecodedBytes++ = static_cast(ByteTriplet >> 16); if (PadCount < 2) { *DecodedBytes++ = static_cast((ByteTriplet >> 8) & 0xFF); } if (PadCount < 1) { *DecodedBytes++ = static_cast(ByteTriplet & 0xFF); } Source += 4; Length -= 4; } OutLength = uint32_t(DecodedBytes - Dest); return true; } template bool Base64::Decode(const char* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); template bool Base64::Decode(const wchar_t* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); ////////////////////////////////////////////////////////////////////////// // // Testing related code follows... 
//

#if ZEN_WITH_TESTS

using namespace std::string_literals;

TEST_CASE("Base64")
{
	// Encodes Input and returns the base64 text, checking Encode's reported length against GetEncodedDataSize
	auto EncodeString = [](std::string_view Input) -> std::string {
		const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Input.size()));
		// Encode() writes a null terminator after the text; give it a real slot instead of writing into
		// std::string's internal terminator (altering data()[size()] is not allowed)
		std::string	   Result(size_t(EncodedSize) + 1, '\0');
		const uint32_t Written =
			Base64::Encode(reinterpret_cast<const uint8_t*>(Input.data()), uint32_t(Input.size()), Result.data());
		CHECK(Written == EncodedSize);
		Result.resize(Written);
		return Result;
	};

	// Decodes Input, expecting success, and returns the decoded bytes as a string
	auto DecodeString = [](std::string_view Input) -> std::string {
		std::string Result;
		Result.resize(Base64::GetMaxDecodedDataSize(uint32_t(Input.size())));
		uint32_t DecodedLength = 0;
		bool	 Success =
			Base64::Decode(Input.data(), uint32_t(Input.size()), reinterpret_cast<uint8_t*>(Result.data()), DecodedLength);
		CHECK(Success);
		Result.resize(DecodedLength);
		return Result;
	};

	// Vectors from RFC 4648 section 10
	SUBCASE("Encode")
	{
		CHECK(EncodeString("") == ""s);
		CHECK(EncodeString("f") == "Zg=="s);
		CHECK(EncodeString("fo") == "Zm8="s);
		CHECK(EncodeString("foo") == "Zm9v"s);
		CHECK(EncodeString("foob") == "Zm9vYg=="s);
		CHECK(EncodeString("fooba") == "Zm9vYmE="s);
		CHECK(EncodeString("foobar") == "Zm9vYmFy"s);
	}

	SUBCASE("Decode")
	{
		CHECK(DecodeString("") == ""s);
		CHECK(DecodeString("Zg==") == "f"s);
		CHECK(DecodeString("Zm8=") == "fo"s);
		CHECK(DecodeString("Zm9v") == "foo"s);
		CHECK(DecodeString("Zm9vYg==") == "foob"s);
		CHECK(DecodeString("Zm9vYmE=") == "fooba"s);
		CHECK(DecodeString("Zm9vYmFy") == "foobar"s);
	}

	SUBCASE("RoundTrip")
	{
		auto RoundTrip = [&](const std::string& Input) {
			std::string Encoded = EncodeString(Input);
			std::string Decoded = DecodeString(Encoded);
			CHECK(Decoded == Input);
		};

		RoundTrip("Hello, World!");
		RoundTrip("Base64 encoding test with various lengths");
		RoundTrip("A");
		RoundTrip("AB");
		RoundTrip("ABC");
		RoundTrip("ABCD");
		RoundTrip("\x00\x01\x02\xff\xfe\xfd"s);
	}

	SUBCASE("BinaryRoundTrip")
	{
		// Test with all byte values 0-255
		uint8_t AllBytes[256];
		for (int i = 0; i < 256; ++i)
		{
			AllBytes[i] = static_cast<uint8_t>(i);
		}

		char		   Encoded[Base64::GetEncodedDataSize(256) + 1];
		const uint32_t EncodedLength = Base64::Encode(AllBytes, 256, Encoded);
		CHECK(EncodedLength == uint32_t(strlen(Encoded)));

		uint8_t	 Decoded[256];
		uint32_t DecodedLength = 0;
		bool	 Success	   = Base64::Decode(Encoded, uint32_t(strlen(Encoded)), Decoded, DecodedLength);
		CHECK(Success);
		CHECK(DecodedLength == 256);
		CHECK(memcmp(AllBytes, Decoded, 256) == 0);
	}

	SUBCASE("WideChar")
	{
		// Exercises the wchar_t instantiations of Encode/Decode
		const uint8_t Input[] = {'f', 'o', 'o', 'b', 'a', 'r'};

		wchar_t Encoded[Base64::GetEncodedDataSize(6) + 1];
		CHECK(Base64::Encode(Input, 6u, Encoded) == 8u);
		CHECK(std::wstring(Encoded) == L"Zm9vYmFy");

		uint8_t	 Decoded[6];
		uint32_t DecodedLength = 0;
		CHECK(Base64::Decode(Encoded, 8u, Decoded, DecodedLength));
		CHECK(DecodedLength == 6);
		CHECK(memcmp(Input, Decoded, 6) == 0);
	}

	SUBCASE("DecodeInvalidInput")
	{
		uint8_t	 Dest[64];
		uint32_t OutLength = 0;

		// Length not a multiple of 4; OutLength must be reset on failure
		OutLength = 123;
		CHECK_FALSE(Base64::Decode("abc", 3u, Dest, OutLength));
		CHECK(OutLength == 0);

		// Invalid character
		OutLength = 123;
		CHECK_FALSE(Base64::Decode("ab!d", 4u, Dest, OutLength));
		CHECK(OutLength == 0);

		// '=' is only valid as trailing padding of at most two characters
		CHECK_FALSE(Base64::Decode("Zg=!", 4u, Dest, OutLength));
		CHECK_FALSE(Base64::Decode("Z===", 4u, Dest, OutLength));
		CHECK_FALSE(Base64::Decode("====", 4u, Dest, OutLength));
	}

	SUBCASE("EncodedDataSize")
	{
		CHECK(Base64::GetEncodedDataSize(0) == 0);
		CHECK(Base64::GetEncodedDataSize(1) == 4);
		CHECK(Base64::GetEncodedDataSize(2) == 4);
		CHECK(Base64::GetEncodedDataSize(3) == 4);
		CHECK(Base64::GetEncodedDataSize(4) == 8);
		CHECK(Base64::GetEncodedDataSize(5) == 8);
		CHECK(Base64::GetEncodedDataSize(6) == 8);
	}

	SUBCASE("MaxDecodedDataSize")
	{
		CHECK(Base64::GetMaxDecodedDataSize(0) == 0);
		CHECK(Base64::GetMaxDecodedDataSize(4) == 3);
		CHECK(Base64::GetMaxDecodedDataSize(8) == 6);
		CHECK(Base64::GetMaxDecodedDataSize(12) == 9);
	}
}

#endif

}  // namespace zen