aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients/httpclientcurlhelpers.h
blob: 0605a30f681d65e52f4dfafd51905f026873e290 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
// Copyright Epic Games, Inc. All Rights Reserved.

#pragma once

// Shared helpers for curl-based HTTP client implementations (sync and async).
// This is an internal header, not part of the public API.

#include <zencore/string.h>

#include <zenhttp/httpclient.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <curl/curl.h>
ZEN_THIRD_PARTY_INCLUDES_END

#include <optional>
#include <string>
#include <utility>
#include <vector>

namespace zen {

//////////////////////////////////////////////////////////////////////////
//
// Error mapping

// Translates a libcurl result code into the client's HttpClientErrorCode.
// Codes without an explicit entry in the table map to kOtherError.
inline HttpClientErrorCode
MapCurlError(CURLcode Code)
{
	struct CodeMapping
	{
		CURLcode			CurlCode;
		HttpClientErrorCode ClientCode;
	};

	static constexpr CodeMapping Mappings[] = {
		{CURLE_OK, HttpClientErrorCode::kOK},
		{CURLE_COULDNT_CONNECT, HttpClientErrorCode::kConnectionFailure},
		{CURLE_COULDNT_RESOLVE_HOST, HttpClientErrorCode::kHostResolutionFailure},
		{CURLE_COULDNT_RESOLVE_PROXY, HttpClientErrorCode::kProxyResolutionFailure},
		{CURLE_RECV_ERROR, HttpClientErrorCode::kNetworkReceiveError},
		{CURLE_SEND_ERROR, HttpClientErrorCode::kNetworkSendFailure},
		{CURLE_OPERATION_TIMEDOUT, HttpClientErrorCode::kOperationTimedOut},
		{CURLE_SSL_CONNECT_ERROR, HttpClientErrorCode::kSSLConnectError},
		{CURLE_SSL_CERTPROBLEM, HttpClientErrorCode::kSSLCertificateError},
		{CURLE_PEER_FAILED_VERIFICATION, HttpClientErrorCode::kSSLCACertError},
		// The remaining SSL failures share one generic client code.
		{CURLE_SSL_CIPHER, HttpClientErrorCode::kGenericSSLError},
		{CURLE_SSL_ENGINE_NOTFOUND, HttpClientErrorCode::kGenericSSLError},
		{CURLE_SSL_ENGINE_SETFAILED, HttpClientErrorCode::kGenericSSLError},
		{CURLE_ABORTED_BY_CALLBACK, HttpClientErrorCode::kRequestCancelled},
	};

	for (const CodeMapping& Entry : Mappings)
	{
		if (Entry.CurlCode == Code)
		{
			return Entry.ClientCode;
		}
	}

	return HttpClientErrorCode::kOtherError;
}

//////////////////////////////////////////////////////////////////////////
//
// Curl callback data structures and callbacks

// User data handed to CurlWriteCallback. Both members are non-owning pointers.
struct CurlWriteCallbackData
{
	// String that CurlWriteCallback appends each received chunk to.
	std::string* Body{nullptr};

	// Optional abort predicate; CurlWriteCallback skips it when the pointer is null
	// or the pointed-to function is empty.
	std::function<bool()>* CheckIfAbortFunction{nullptr};
};

// Write callback with the (Ptr, Size, Nmemb, UserData) signature libcurl uses for body data.
// UserData must point at a CurlWriteCallbackData. The chunk is appended to its Body string
// and the chunk size is returned, unless the optional abort predicate reports true, in
// which case 0 is returned to make curl stop the transfer.
inline size_t
CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
{
	auto* Sink = static_cast<CurlWriteCallbackData*>(UserData);

	// The predicate is optional twice over: the pointer may be null and the
	// std::function it points at may be empty.
	if (std::function<bool()>* AbortCheck = Sink->CheckIfAbortFunction; AbortCheck != nullptr && *AbortCheck)
	{
		if ((*AbortCheck)())
		{
			return 0;  // Signal abort to curl
		}
	}

	const size_t ChunkBytes = Size * Nmemb;
	Sink->Body->append(Ptr, ChunkBytes);
	return ChunkBytes;
}

// User data handed to CurlHeaderCallback.
struct CurlHeaderCallbackData
{
	// Non-owning pointer to the list that CurlHeaderCallback appends parsed
	// (key, value) pairs to, in the order the header lines arrive.
	std::vector<std::pair<std::string, std::string>>* Headers{nullptr};
};

// Parses a single raw header line into a (key, value) pair.
//
// Strips the trailing CR/LF characters, splits on the first colon, and trims optional
// whitespace (spaces and horizontal tabs) from both ends of the key and of the value.
// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
//
// The returned views alias the characters of `Line`; copy them if they need to outlive
// the buffer `Line` refers to.
inline std::optional<std::pair<std::string_view, std::string_view>>
ParseHeaderLine(std::string_view Line)
{
	// Drop the line terminator. npos means the line was empty or consisted only of CR/LF.
	const size_t LastContentPos = Line.find_last_not_of("\r\n");
	if (LastContentPos == std::string_view::npos)
	{
		return std::nullopt;
	}
	Line = Line.substr(0, LastContentPos + 1);

	const size_t ColonPos = Line.find(':');
	if (ColonPos == std::string_view::npos)
	{
		return std::nullopt;
	}

	// Header fields may be padded with spaces or tabs on either side; an all-blank
	// field collapses to an empty view.
	const auto TrimBlanks = [](std::string_view Text) {
		const size_t First = Text.find_first_not_of(" \t");
		if (First == std::string_view::npos)
		{
			return std::string_view{};
		}
		const size_t Last = Text.find_last_not_of(" \t");
		return Text.substr(First, Last - First + 1);
	};

	return std::pair{TrimBlanks(Line.substr(0, ColonPos)), TrimBlanks(Line.substr(ColonPos + 1))};
}

// Header callback with the (Buffer, Size, Nmemb, UserData) signature libcurl uses for
// header data. UserData must point at a CurlHeaderCallbackData. Lines that
// ParseHeaderLine accepts are copied into the Headers list; every other line is ignored.
// Always reports the full line as consumed.
inline size_t
CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
{
	const size_t LineBytes = Size * Nmemb;

	const auto Parsed = ParseHeaderLine(std::string_view(Buffer, LineBytes));
	if (Parsed.has_value())
	{
		// The parsed views point into curl's buffer, so copy them into owned strings.
		auto* Sink = static_cast<CurlHeaderCallbackData*>(UserData);
		Sink->Headers->emplace_back(std::string(Parsed->first), std::string(Parsed->second));
	}

	return LineBytes;
}

// User data handed to CurlReadCallback: a non-owning view of the bytes to send plus
// a cursor tracking how much has already been handed out.
struct CurlReadCallbackData
{
	// Start of the source bytes.
	const uint8_t* DataPtr{nullptr};

	// Total number of source bytes.
	size_t DataSize{0};

	// Number of bytes already copied out; advanced by CurlReadCallback.
	size_t Offset{0};

	// Optional abort predicate; CurlReadCallback skips it when the pointer is null
	// or the pointed-to function is empty.
	std::function<bool()>* CheckIfAbortFunction{nullptr};
};

// Read callback with the (Buffer, Size, Nmemb, UserData) signature libcurl uses to pull
// data to send. UserData must point at a CurlReadCallbackData. Copies up to Size * Nmemb
// of the not-yet-sent bytes into Buffer, advances the cursor, and returns the number of
// bytes copied (0 once everything has been handed out). Returns CURL_READFUNC_ABORT when
// the optional abort predicate reports true.
inline size_t
CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
{
	auto* Source = static_cast<CurlReadCallbackData*>(UserData);

	std::function<bool()>* const AbortCheck = Source->CheckIfAbortFunction;
	if (AbortCheck != nullptr && *AbortCheck && (*AbortCheck)())
	{
		return CURL_READFUNC_ABORT;
	}

	const size_t Capacity  = Size * Nmemb;
	const size_t Available = Source->DataSize - Source->Offset;
	const size_t Count	   = std::min(Capacity, Available);

	if (Count == 0)
	{
		return 0;
	}

	memcpy(Buffer, Source->DataPtr + Source->Offset, Count);
	Source->Offset += Count;
	return Count;
}

//////////////////////////////////////////////////////////////////////////
//
// URL and header construction

// Appends Input to Out, percent-encoding every byte that is not an ASCII letter,
// a digit, or one of "-_.~". Encoded bytes are written as '%' followed by two
// upper-case hex digits.
inline void
AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
{
	static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
	static constexpr char	  HexDigits[] = "0123456789ABCDEF";

	for (char C : Input)
	{
		if (!Unreserved.Contains(C))
		{
			// Go through uint8_t so bytes >= 0x80 index the hex table correctly
			// even where char is signed.
			const uint8_t Byte		= static_cast<uint8_t>(C);
			const char	  Escape[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
			Out.Append(std::string_view(Escape, 3));
			continue;
		}

		Out.Append(C);
	}
}

// Appends BaseUrl and ResourcePath to Url verbatim, followed by Parameters as a query
// string: the first pair is introduced by '?', later ones by '&', and keys and values
// are percent-encoded with AppendUrlEncoded. Nothing is appended after the path when
// Parameters is empty.
inline void
BuildUrlWithParameters(StringBuilderBase&			  Url,
					   std::string_view				  BaseUrl,
					   std::string_view				  ResourcePath,
					   const HttpClient::KeyValueMap& Parameters)
{
	Url.Append(BaseUrl);
	Url.Append(ResourcePath);

	bool IsFirst = true;
	for (const auto& [Key, Value] : *Parameters)
	{
		Url.Append(IsFirst ? '?' : '&');
		IsFirst = false;

		AppendUrlEncoded(Url, Key);
		Url.Append('=');
		AppendUrlEncoded(Url, Value);
	}
}

// Returns the ("Content-Type", <text form of ContentType>) header pair, where the text
// form is whatever MapContentTypeToString yields for the given type.
inline std::pair<std::string, std::string>
HeaderContentType(ZenContentType ContentType)
{
	std::string HeaderValue(MapContentTypeToString(ContentType));
	return {std::string("Content-Type"), std::move(HeaderValue)};
}

// Builds the curl header list for a request.
//
// Headers are appended in this order:
//   1. every entry of AdditionalHeader as "Key: Value"
//   2. "UE-Session: <SessionId>" when SessionId is not empty
//   3. "Authorization: <AccessToken>" when AccessToken holds a value
//   4. every entry of ExtraHeaders as "Key: Value"
//
// Returns the head of the list, or nullptr when no header was added. The list is
// allocated by libcurl, which documents curl_slist_free_all() as the way to release it;
// presumably that is the caller's job - confirm against the call sites.
inline curl_slist*
BuildHeaderList(const HttpClient::KeyValueMap&							AdditionalHeader,
				std::string_view										SessionId,
				const std::optional<std::string>&						AccessToken,
				const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
{
	curl_slist* Headers = nullptr;

	// curl_slist_append() returns nullptr on failure and does not free the list it was
	// given. Only adopt the result when it is valid; assigning it unconditionally would
	// leak the list built so far and discard every header already added.
	const auto AppendLine = [&Headers](const char* Line) {
		if (curl_slist* Extended = curl_slist_append(Headers, Line))
		{
			Headers = Extended;
		}
	};

	for (const auto& [Key, Value] : *AdditionalHeader)
	{
		ExtendableStringBuilder<64> HeaderLine;
		HeaderLine << Key << ": " << Value;
		AppendLine(HeaderLine.c_str());
	}

	if (!SessionId.empty())
	{
		ExtendableStringBuilder<64> SessionHeader;
		SessionHeader << "UE-Session: " << SessionId;
		AppendLine(SessionHeader.c_str());
	}

	if (AccessToken.has_value())
	{
		ExtendableStringBuilder<128> AuthHeader;
		AuthHeader << "Authorization: " << AccessToken.value();
		AppendLine(AuthHeader.c_str());
	}

	for (const auto& [Key, Value] : ExtraHeaders)
	{
		ExtendableStringBuilder<128> HeaderLine;
		HeaderLine << Key << ": " << Value;
		AppendLine(HeaderLine.c_str());
	}

	return Headers;
}

// Converts an ordered list of (key, value) header pairs into a KeyValueMap.
// When a key occurs more than once, the value of its last occurrence is kept.
inline HttpClient::KeyValueMap
BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
{
	HttpClient::KeyValueMap Result;

	for (const std::pair<std::string, std::string>& Header : Headers)
	{
		Result->insert_or_assign(Header.first, Header.second);
	}

	return Result;
}

// Looks for the first header whose name compares equal to "Content-Type" under
// StrCaseCompare and sets the buffer's content type from its parsed value.
// The buffer is left untouched when no such header is present.
inline void
ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
{
	for (const std::pair<std::string, std::string>& Header : Headers)
	{
		if (StrCaseCompare(Header.first, "Content-Type") != 0)
		{
			continue;
		}

		// Only the first match is applied.
		Buffer.SetContentType(ParseContentType(Header.second));
		return;
	}
}

}  // namespace zen