aboutsummaryrefslogtreecommitdiff
path: root/src/zencore/base64.cpp
blob: fdf5f2d66a8e98452981c7702d7f11b31d220cce (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
// Copyright Epic Games, Inc. All Rights Reserved.

#include <zencore/base64.h>
#include <zencore/string.h>
#include <zencore/testing.h>

#include <string>

namespace zen {

/** Maps a 6 bit value (0..63) to the ascii character that represents it: A-Z, a-z, 0-9, '+', '/' */
static const uint8_t EncodingAlphabet[64] = {
	// 0..25: upper-case letters
	'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
	'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
	// 26..51: lower-case letters
	'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
	'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
	// 52..61: digits
	'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
	// 62..63: symbols
	'+', '/'};

/**
 * The table used to convert an ascii character into a 6 bit value.
 *
 * Indexed by the 8 bit character code. Entries hold the 6 bit value (0x00-0x3F) for characters
 * that belong to the base64 alphabet and 0xFF for every other code. Note that the padding
 * character '=' (0x3D) is also 0xFF here, so the decoder has to recognize padding itself.
 */
static const uint8_t DecodingAlphabet[256] = {
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x00-0x0f
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x10-0x1f
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F,	 // 0x20-0x2f: '+' is 0x3E, '/' is 0x3F
	0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x30-0x3f: '0'-'9'
	0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E,	 // 0x40-0x4f: 'A'-'O'
	0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x50-0x5f: 'P'-'Z'
	0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,	 // 0x60-0x6f: 'a'-'o'
	0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x70-0x7f: 'p'-'z'
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x80-0x8f
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0x90-0x9f
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0xa0-0xaf
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0xb0-0xbf
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0xc0-0xcf
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0xd0-0xdf
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,	 // 0xe0-0xef
	0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF	 // 0xf0-0xff
};

/**
 * Encodes a binary buffer as base64 text.
 *
 * @param Source	Bytes to encode; Length bytes are read
 * @param Length	Number of bytes in Source
 * @param Dest		Receives the encoded characters followed by a null terminator. Four characters are
 *					written per started group of three source bytes, so the caller must provide room
 *					for that many characters plus one for the terminator
 * @return			Number of characters written, not counting the null terminator
 */
template<typename CharType>
uint32_t
Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest)
{
	CharType* Out = Dest;

	// Complete groups: 3 source bytes (24 bits) become 4 characters of 6 bits each
	for (; Length >= 3; Length -= 3, Source += 3, Out += 4)
	{
		const uint32_t Bits = (uint32_t(Source[0]) << 16) | (uint32_t(Source[1]) << 8) | uint32_t(Source[2]);

		Out[0] = EncodingAlphabet[(Bits >> 18) & 0x3F];
		Out[1] = EncodingAlphabet[(Bits >> 12) & 0x3F];
		Out[2] = EncodingAlphabet[(Bits >> 6) & 0x3F];
		Out[3] = EncodingAlphabet[Bits & 0x3F];
	}

	// Trailing group of 1 or 2 bytes: missing bytes count as zero bits and the unused
	// output positions are filled with '=' so the result is still a full 4 character chunk
	if (Length != 0)
	{
		const bool	   HasSecondByte = (Length == 2);
		const uint32_t First		 = Source[0];
		const uint32_t Second		 = HasSecondByte ? Source[1] : 0;
		const uint32_t Bits			 = (First << 16) | (Second << 8);

		Out[0] = EncodingAlphabet[(Bits >> 18) & 0x3F];
		Out[1] = EncodingAlphabet[(Bits >> 12) & 0x3F];
		// A single leftover byte only fills the first two characters, so it needs two pad chars
		if (HasSecondByte)
		{
			Out[2] = EncodingAlphabet[(Bits >> 6) & 0x3F];
		}
		else
		{
			Out[2] = '=';
		}
		Out[3] = '=';

		Out += 4;
	}

	// Null terminate; the terminator is not part of the returned count
	*Out = 0;

	return uint32_t(Out - Dest);
}

// Explicit instantiations of Encode for the char and wchar_t character types.
// NOTE(review): the template body is defined in this source file, so other translation units
// presumably rely on these two instantiations - confirm against base64.h before adding new types.
template uint32_t Base64::Encode<char>(const uint8_t* Source, uint32_t Length, char* Dest);
template uint32_t Base64::Encode<wchar_t>(const uint8_t* Source, uint32_t Length, wchar_t* Dest);

/**
 * Decodes base64 text into binary data.
 *
 * @param Source	Encoded characters; Length characters are read
 * @param Length	Number of characters in Source. Must be a multiple of 4
 * @param Dest		Receives the decoded bytes. Up to 3 bytes are written for every 4 input characters
 * @param OutLength Receives the number of bytes written to Dest on success, or 0 on failure
 * @return			true on success. false if Length is not a multiple of 4, if a character outside
 *					the base64 alphabet is found, or if '=' padding appears anywhere other than at
 *					the end of the final group of 4 characters
 *
 * On failure Dest may already hold bytes decoded from earlier groups.
 */
template<typename CharType>
bool
Base64::Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength)
{
	// Length must be a multiple of 4
	if (Length % 4 != 0)
	{
		OutLength = 0;
		return false;
	}

	// Maps one encoded character to its 6 bit value, or 0xFF if it is not in the base64 alphabet.
	// The range check is done on the widened value *before* indexing the table: narrowing a wide
	// character straight to uint8_t would let code units above 0xFF alias valid characters
	// (e.g. 0x141 would be treated as 'A'). Negative char values widen to large numbers and are
	// rejected as well, which matches the 0xFF entries in the upper half of the table.
	const auto DecodeChar = [](CharType Ch) -> uint8_t {
		const uint32_t Code = static_cast<uint32_t>(Ch);
		return (Code <= 0xFF) ? DecodingAlphabet[Code] : uint8_t(0xFF);
	};

	uint8_t* DecodedBytes = Dest;

	// Process 4 encoded characters at a time, producing up to 3 decoded bytes
	while (Length > 0)
	{
		// Count padding characters at the end of this group
		uint32_t PadCount = 0;
		if (Source[3] == '=')
		{
			PadCount++;
			if (Source[2] == '=')
			{
				PadCount++;
			}
		}

		// Padding is only valid in the final group; accepting it earlier would silently
		// decode malformed input such as "Zg==Zg=="
		if (PadCount != 0 && Length != 4)
		{
			OutLength = 0;
			return false;
		}

		// Look up each character in the decoding table; padded positions contribute zero bits
		const uint8_t A = DecodeChar(Source[0]);
		const uint8_t B = DecodeChar(Source[1]);
		const uint8_t C = (PadCount >= 2) ? uint8_t(0) : DecodeChar(Source[2]);
		const uint8_t D = (PadCount >= 1) ? uint8_t(0) : DecodeChar(Source[3]);

		// Check for invalid characters (0xFF means not in the base64 alphabet).
		// This also rejects '=' in the first two positions, or a '=' followed by a non-pad character.
		if (A == 0xFF || B == 0xFF || C == 0xFF || D == 0xFF)
		{
			OutLength = 0;
			return false;
		}

		// Reconstruct the 24-bit value from 4 6-bit chunks
		const uint32_t ByteTriplet = (uint32_t(A) << 18) | (uint32_t(B) << 12) | (uint32_t(C) << 6) | uint32_t(D);

		// Extract the bytes that carry real data: 3 without padding, 2 with one pad, 1 with two pads
		*DecodedBytes++ = static_cast<uint8_t>(ByteTriplet >> 16);
		if (PadCount < 2)
		{
			*DecodedBytes++ = static_cast<uint8_t>((ByteTriplet >> 8) & 0xFF);
		}
		if (PadCount < 1)
		{
			*DecodedBytes++ = static_cast<uint8_t>(ByteTriplet & 0xFF);
		}

		Source += 4;
		Length -= 4;
	}

	OutLength = uint32_t(DecodedBytes - Dest);
	return true;
}

// Explicit instantiations of Decode for the char and wchar_t character types.
// NOTE(review): the template body is defined in this source file, so other translation units
// presumably rely on these two instantiations - confirm against base64.h before adding new types.
template bool Base64::Decode<char>(const char* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength);
template bool Base64::Decode<wchar_t>(const wchar_t* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength);

//////////////////////////////////////////////////////////////////////////
//
// Testing related code follows...
//

#if ZEN_WITH_TESTS

using namespace std::string_literals;

TEST_CASE("Base64")
{
	// Encodes Input into a string sized by GetEncodedDataSize(). The null terminator that
	// Encode() appends lands on the string's own terminator slot.
	auto EncodeString = [](std::string_view Input) -> std::string {
		const uint32_t InputSize = uint32_t(Input.size());
		std::string	   Encoded(Base64::GetEncodedDataSize(InputSize), '\0');
		Base64::Encode(reinterpret_cast<const uint8_t*>(Input.data()), InputSize, Encoded.data());
		return Encoded;
	};

	// Decodes Input, checks that decoding succeeded and trims the result to the decoded length
	auto DecodeString = [](std::string_view Input) -> std::string {
		const uint32_t InputSize = uint32_t(Input.size());
		std::string	   Decoded(Base64::GetMaxDecodedDataSize(InputSize), '\0');
		uint32_t	   DecodedSize = 0;
		const bool	   Success	   = Base64::Decode(Input.data(), InputSize, reinterpret_cast<uint8_t*>(Decoded.data()), DecodedSize);
		CHECK(Success);
		Decoded.resize(DecodedSize);
		return Decoded;
	};

	// Plain text / encoded text pairs shared by the Encode and Decode subcases
	struct TestVector
	{
		std::string_view Plain;
		std::string_view Encoded;
	};

	const TestVector Vectors[] = {
		{"", ""},
		{"f", "Zg=="},
		{"fo", "Zm8="},
		{"foo", "Zm9v"},
		{"foob", "Zm9vYg=="},
		{"fooba", "Zm9vYmE="},
		{"foobar", "Zm9vYmFy"},
	};

	SUBCASE("Encode")
	{
		for (const TestVector& Vector : Vectors)
		{
			CHECK(EncodeString(Vector.Plain) == std::string(Vector.Encoded));
		}
	}

	SUBCASE("Decode")
	{
		for (const TestVector& Vector : Vectors)
		{
			CHECK(DecodeString(Vector.Encoded) == std::string(Vector.Plain));
		}
	}

	SUBCASE("RoundTrip")
	{
		// The last entry uses the std::string literal suffix so the embedded null byte is kept
		const std::string Inputs[] = {
			"Hello, World!",
			"Base64 encoding test with various lengths",
			"A",
			"AB",
			"ABC",
			"ABCD",
			"\x00\x01\x02\xff\xfe\xfd"s,
		};

		for (const std::string& Input : Inputs)
		{
			const std::string Encoded = EncodeString(Input);
			const std::string Decoded = DecodeString(Encoded);
			CHECK(Decoded == Input);
		}
	}

	SUBCASE("BinaryRoundTrip")
	{
		// Every possible byte value, in ascending order
		uint8_t AllBytes[256];
		for (uint32_t Value = 0; Value < 256; ++Value)
		{
			AllBytes[Value] = static_cast<uint8_t>(Value);
		}

		// One extra character for the null terminator written by Encode()
		char Encoded[Base64::GetEncodedDataSize(256) + 1];
		Base64::Encode(AllBytes, 256, Encoded);

		uint8_t	   Decoded[256];
		uint32_t   DecodedLength = 0;
		const bool Success		 = Base64::Decode(Encoded, uint32_t(strlen(Encoded)), Decoded, DecodedLength);
		CHECK(Success);
		CHECK(DecodedLength == 256);
		CHECK(memcmp(AllBytes, Decoded, 256) == 0);
	}

	SUBCASE("DecodeInvalidInput")
	{
		uint8_t	 Dest[64];
		uint32_t OutLength = 0;

		// Three characters: not a whole group of 4
		CHECK_FALSE(Base64::Decode("abc", 3u, Dest, OutLength));

		// '!' is not part of the base64 alphabet
		CHECK_FALSE(Base64::Decode("ab!d", 4u, Dest, OutLength));
	}

	SUBCASE("EncodedDataSize")
	{
		// Expected encoded size for input sizes 0..6
		const uint32_t ExpectedSizes[] = {0, 4, 4, 4, 8, 8, 8};

		for (uint32_t InputSize = 0; InputSize < 7; ++InputSize)
		{
			CHECK(Base64::GetEncodedDataSize(InputSize) == ExpectedSizes[InputSize]);
		}
	}

	SUBCASE("MaxDecodedDataSize")
	{
		// 0, 4, 8 and 12 encoded characters hold at most 0, 3, 6 and 9 bytes
		for (uint32_t GroupCount = 0; GroupCount < 4; ++GroupCount)
		{
			CHECK(Base64::GetMaxDecodedDataSize(GroupCount * 4) == GroupCount * 3);
		}
	}
}

#endif

}  // namespace zen