aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2025-10-06 22:33:00 +0200
committerStefan Boberg <[email protected]>2025-10-06 22:33:00 +0200
commit1383dbdc563d90c170ab30ba622ee44e2e37e723 (patch)
tree59777db60000fe2ab2334f05776fb9ded4ca41fb /src/zenhttp
parentMerge branch 'main' into sb/rpc-analysis (diff)
parent5.7.6 (diff)
downloadzen-1383dbdc563d90c170ab30ba622ee44e2e37e723.tar.xz
zen-1383dbdc563d90c170ab30ba622ee44e2e37e723.zip
Merge remote-tracking branch 'origin/main' into sb/rpc-analysis
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/auth/authmgr.cpp106
-rw-r--r--src/zenhttp/auth/authservice.cpp2
-rw-r--r--src/zenhttp/auth/oidc.cpp40
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp474
-rw-r--r--src/zenhttp/clients/httpclientcommon.h147
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp1035
-rw-r--r--src/zenhttp/clients/httpclientcpr.h151
-rw-r--r--src/zenhttp/httpclient.cpp898
-rw-r--r--src/zenhttp/httpclientauth.cpp212
-rw-r--r--src/zenhttp/httpserver.cpp314
-rw-r--r--src/zenhttp/include/zenhttp/auth/oidc.h3
-rw-r--r--src/zenhttp/include/zenhttp/cprutils.h86
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h82
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h116
-rw-r--r--src/zenhttp/include/zenhttp/httpclientauth.h36
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h25
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h164
-rw-r--r--src/zenhttp/packageformat.cpp936
-rw-r--r--src/zenhttp/servers/httpasio.cpp253
-rw-r--r--src/zenhttp/servers/httpmulti.cpp7
-rw-r--r--src/zenhttp/servers/httpparser.cpp169
-rw-r--r--src/zenhttp/servers/httpparser.h103
-rw-r--r--src/zenhttp/servers/httpplugin.cpp114
-rw-r--r--src/zenhttp/servers/httpsys.cpp119
-rw-r--r--src/zenhttp/servers/httptracer.cpp37
-rw-r--r--src/zenhttp/servers/httptracer.h26
-rw-r--r--src/zenhttp/transports/asiotransport.cpp2
-rw-r--r--src/zenhttp/transports/dlltransport.cpp92
-rw-r--r--src/zenhttp/transports/dlltransport.h2
-rw-r--r--src/zenhttp/transports/winsocktransport.cpp26
-rw-r--r--src/zenhttp/xmake.lua2
-rw-r--r--src/zenhttp/zenhttp.cpp3
32 files changed, 4576 insertions, 1206 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp
index 18568a21d..209276621 100644
--- a/src/zenhttp/auth/authmgr.cpp
+++ b/src/zenhttp/auth/authmgr.cpp
@@ -2,12 +2,14 @@
#include "zenhttp/auth/authmgr.h"
+#include <zencore/basicfile.h>
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compactbinaryutil.h>
#include <zencore/crypto.h>
#include <zencore/filesystem.h>
#include <zencore/logging.h>
+#include <zencore/trace.h>
#include <zenhttp/auth/oidc.h>
#include <condition_variable>
@@ -28,6 +30,8 @@ namespace details {
const AesIV128Bit& IV,
std::optional<std::string>& Reason)
{
+ ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile");
+
FileContents Result = ReadFile(Path);
if (Result.ErrorCode)
@@ -61,6 +65,8 @@ namespace details {
const AesIV128Bit& IV,
std::optional<std::string>& Reason)
{
+ ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile");
+
if (FileData.GetSize() == 0)
{
return;
@@ -76,7 +82,7 @@ namespace details {
return;
}
- WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize()));
+ TemporaryFile::SafeWriteFile(Path, EncryptedView);
}
} // namespace details
@@ -99,11 +105,7 @@ public:
virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
{
- if (OpenIdProviderExist(Params.Name))
- {
- ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
- return;
- }
+ ZEN_TRACE_CPU("AuthMgr::AddOpenIdProvider");
if (Params.Name.empty())
{
@@ -111,8 +113,26 @@ public:
return;
}
- std::unique_ptr<OidcClient> Client =
- std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+ {
+ std::unique_lock _(m_ProviderMutex);
+ if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end())
+ {
+ OpenIdProvider& ExistingProvider = *It->second;
+ if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url)
+ {
+ ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name);
+ return;
+ }
+ else
+ {
+ m_OpenIdProviders.erase(It);
+ m_OpenIdTokens.erase(std::string(Params.Name));
+ ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name);
+ }
+ }
+ }
+
+ RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}));
if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
{
@@ -146,6 +166,8 @@ public:
virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
{
+ ZEN_TRACE_CPU("AuthMgr::AddOpenIdToken");
+
if (Params.ProviderName.empty())
{
ZEN_WARN("trying add OpenID token with invalid provider name");
@@ -208,29 +230,45 @@ public:
}
private:
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ RefPtr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ TimePoint ExpireTime{};
+ };
+
bool OpenIdProviderExist(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
-
return m_OpenIdProviders.contains(std::string(ProviderName));
}
- OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ OpenIdProvider GetOpenIdProvider(std::string_view ProviderName)
{
std::unique_lock _(m_ProviderMutex);
- return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ return *m_OpenIdProviders[std::string(ProviderName)];
}
OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
{
- if (OpenIdProviderExist(ProviderName) == false)
+ ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken");
+
+ RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient;
+ if (!Client)
{
return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
}
- OidcClient& Client = GetOpenIdClient(ProviderName);
-
- return Client.RefreshToken(RefreshToken);
+ return Client->RefreshToken(RefreshToken);
}
void Shutdown()
@@ -241,6 +279,7 @@ private:
void LoadState()
{
+ ZEN_TRACE_CPU("AuthMgrImpl::LoadState");
try
{
std::optional<std::string> Reason;
@@ -258,15 +297,14 @@ private:
return;
}
- const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
-
- if (ValidationError != CbValidateError::None)
+ CbValidateError ValidationError;
+ if (CbObject AuthState = ValidateAndReadCompactBinaryObject(std::move(Buffer), ValidationError);
+ ValidationError != CbValidateError::None)
{
- ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ ZEN_WARN("load serialized state FAILED, reason '{}'", ToString(ValidationError));
return;
}
-
- if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ else
{
for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
{
@@ -295,7 +333,7 @@ private:
}
}
}
- catch (std::exception& Err)
+ catch (const std::exception& Err)
{
ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());
@@ -313,6 +351,8 @@ private:
void SaveState()
{
+ ZEN_TRACE_CPU("AuthMgr::SaveState");
+
try
{
CbObjectWriter AuthState;
@@ -352,7 +392,7 @@ private:
AuthState.EndArray();
}
- std::filesystem::create_directories(m_Config.RootDirectory);
+ CreateDirectories(m_Config.RootDirectory);
std::optional<std::string> Reason;
@@ -367,7 +407,7 @@ private:
ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value());
}
}
- catch (std::exception& Err)
+ catch (const std::exception& Err)
{
ZEN_WARN("serialize state FAILED, reason '{}'", Err.what());
}
@@ -474,22 +514,6 @@ private:
}
};
- struct OpenIdProvider
- {
- std::string Name;
- std::string Url;
- std::string ClientId;
- std::unique_ptr<OidcClient> HttpClient;
- };
-
- struct OpenIdToken
- {
- std::string IdentityToken;
- std::string RefreshToken;
- std::string AccessToken;
- TimePoint ExpireTime{};
- };
-
using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp
index 6ed587770..f89ca91da 100644
--- a/src/zenhttp/auth/authservice.cpp
+++ b/src/zenhttp/auth/authservice.cpp
@@ -56,7 +56,7 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
if (Ok)
{
- ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest);
+ ServerRequest.WriteResponse(HttpResponseCode::OK);
}
else
{
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp
index 318110c7d..38e7586ad 100644
--- a/src/zenhttp/auth/oidc.cpp
+++ b/src/zenhttp/auth/oidc.cpp
@@ -1,9 +1,9 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "zenhttp/auth/oidc.h"
+#include <zenhttp/httpclient.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <cpr/cpr.h>
#include <fmt/format.h>
#include <json11.hpp>
ZEN_THIRD_PARTY_INCLUDES_END
@@ -41,27 +41,21 @@ OidcClient::OidcClient(const OidcClient::Options& Options)
OidcClient::InitResult
OidcClient::Initialize()
{
- ExtendableStringBuilder<256> Uri;
- Uri << m_BaseUrl << "/.well-known/openid-configuration"sv;
+ HttpClient Http{m_BaseUrl};
+ HttpClient::Response Response = Http.Get("/.well-known/openid-configuration"sv);
- cpr::Session Session;
-
- Session.SetOption(cpr::Url{Uri.c_str()});
-
- cpr::Response Response = Session.Get();
-
- if (Response.error)
+ if (!Response)
{
- return {.Reason = std::move(Response.error.message)};
+ return {.Reason = Response.ErrorMessage("")};
}
- if (Response.status_code != 200)
+ if (Response.StatusCode != HttpResponseCode::OK)
{
- return {.Reason = std::move(Response.reason)};
+ return {.Reason = std::string{ToString(Response.StatusCode)}};
}
std::string JsonError;
- json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError);
if (JsonError.empty() == false)
{
@@ -89,26 +83,24 @@ OidcClient::RefreshToken(std::string_view RefreshToken)
{
const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
- cpr::Session Session;
+ HttpClient Http{m_Config.TokenEndpoint};
- Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()});
- Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}});
- Session.SetBody(cpr::Body{Body.data(), Body.size()});
+ HttpClient::KeyValueMap Headers{{"Content-Type", "application/x-www-form-urlencoded"}};
- cpr::Response Response = Session.Post();
+ HttpClient::Response Response = Http.Post("", IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}), Headers);
- if (Response.error)
+ if (!Response)
{
- return {.Reason = std::move(Response.error.message)};
+ return {.Reason = std::string{Response.ErrorMessage("")}};
}
- if (Response.status_code != 200)
+ if (Response.StatusCode != HttpResponseCode::OK)
{
- return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)};
+ return {.Reason = fmt::format("{} ({})", ToString(Response.StatusCode), Response.AsText())};
}
std::string JsonError;
- json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError);
if (JsonError.empty() == false)
{
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
new file mode 100644
index 000000000..8e5136dff
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -0,0 +1,474 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcommon.h"
+
+#include <fmt/format.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory/memory.h>
+#include <zencore/windows.h>
+#include <gsl/gsl-lite.hpp>
+
+#if ZEN_WITH_TESTS
+# include <zencore/basicfile.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif // ZEN_WITH_TESTS
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+# include <fcntl.h>
+# include <sys/stat.h>
+# include <unistd.h>
+#endif
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace detail {
+
+ static std::atomic_uint32_t TempFileBaseIndex;
+
+ TempPayloadFile::TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {}
+ TempPayloadFile::~TempPayloadFile()
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::Close");
+ try
+ {
+ if (m_FileHandle)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ // Mark file for deletion when final handle is closed
+ FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
+
+ SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+ BOOL Success = CloseHandle(m_FileHandle);
+#else
+ std::error_code Ec;
+ std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Error reported on get file path from handle {} for temp payload unlink operation, reason '{}'",
+ m_FileHandle,
+ Ec.message());
+ }
+ else
+ {
+ unlink(FilePath.c_str());
+ }
+ int Fd = int(uintptr_t(m_FileHandle));
+ bool Success = (close(Fd) == 0);
+#endif
+ if (!Success)
+ {
+ ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString());
+ }
+
+ m_FileHandle = nullptr;
+ }
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what());
+ }
+ }
+
+ std::error_code TempPayloadFile::Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize)
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::Open");
+ ZEN_ASSERT(m_FileHandle == nullptr);
+
+ std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) |
+ detail::TempFileBaseIndex.fetch_add(1);
+
+ std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex);
+#if ZEN_PLATFORM_WINDOWS
+ LPCWSTR lpFileName = FileName.c_str();
+ const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE);
+ const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
+ LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr;
+ const DWORD dwCreationDisposition = CREATE_ALWAYS;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL;
+ const HANDLE hTemplateFile = nullptr;
+ const HANDLE FileHandle = CreateFile(lpFileName,
+ dwDesiredAccess,
+ dwShareMode,
+ lpSecurityAttributes,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC;
+ int Fd = open(FileName.c_str(), OpenFlags, 0666);
+ if (Fd < 0)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+ fchmod(Fd, 0666);
+
+ void* FileHandle = (void*)(uintptr_t(Fd));
+#endif // ZEN_PLATFORM_WINDOWS
+ m_FileHandle = FileHandle;
+
+ PrepareFileForScatteredWrite(m_FileHandle, FinalSize);
+
+ return {};
+ }
+
+ std::error_code TempPayloadFile::Write(std::string_view DataString)
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::Write");
+ const uint8_t* DataPtr = (const uint8_t*)DataString.data();
+ size_t DataSize = DataString.size();
+ if (DataSize >= CacheBufferSize)
+ {
+ std::error_code Ec = Flush();
+ if (Ec)
+ {
+ return Ec;
+ }
+ return AppendData(DataPtr, DataSize);
+ }
+ size_t CopySize = Min(DataSize, CacheBufferSize - m_CacheBufferOffset);
+ memcpy(&m_CacheBuffer[m_CacheBufferOffset], DataPtr, CopySize);
+ m_CacheBufferOffset += CopySize;
+ DataSize -= CopySize;
+ if (m_CacheBufferOffset == CacheBufferSize)
+ {
+ AppendData(m_CacheBuffer, CacheBufferSize);
+ if (DataSize > 0)
+ {
+ ZEN_ASSERT(DataSize < CacheBufferSize);
+ memcpy(m_CacheBuffer, DataPtr + CopySize, DataSize);
+ }
+ m_CacheBufferOffset = DataSize;
+ }
+ else
+ {
+ ZEN_ASSERT(DataSize == 0);
+ }
+ return {};
+ }
+
+ IoBuffer TempPayloadFile::DetachToIoBuffer()
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::DetachToIoBuffer");
+ if (std::error_code Ec = Flush(); Ec)
+ {
+ ThrowSystemError(Ec.value(), Ec.message());
+ }
+ ZEN_ASSERT(m_FileHandle != nullptr);
+ void* FileHandle = m_FileHandle;
+ IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true);
+ Buffer.SetDeleteOnClose(true);
+ m_FileHandle = 0;
+ m_WriteOffset = 0;
+ return Buffer;
+ }
+
+ IoBuffer TempPayloadFile::BorrowIoBuffer()
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::BorrowIoBuffer");
+ if (std::error_code Ec = Flush(); Ec)
+ {
+ ThrowSystemError(Ec.value(), Ec.message());
+ }
+ ZEN_ASSERT(m_FileHandle != nullptr);
+ void* FileHandle = m_FileHandle;
+ IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, 0, m_WriteOffset);
+ return Buffer;
+ }
+
+ void TempPayloadFile::ResetWritePos(uint64_t WriteOffset)
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::ResetWritePos");
+ Flush();
+ m_WriteOffset = WriteOffset;
+ }
+
+ std::error_code TempPayloadFile::Flush()
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::Flush");
+ if (m_CacheBufferOffset == 0)
+ {
+ return {};
+ }
+ std::error_code Res = AppendData(m_CacheBuffer, m_CacheBufferOffset);
+ m_CacheBufferOffset = 0;
+ return Res;
+ }
+
+ std::error_code TempPayloadFile::AppendData(const void* Data, uint64_t Size)
+ {
+ ZEN_TRACE_CPU("TempPayloadFile::AppendData");
+ ZEN_ASSERT(m_FileHandle != nullptr);
+ const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
+
+ while (Size)
+ {
+ const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize);
+ uint64_t NumberOfBytesWritten = 0;
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32);
+
+ DWORD dwNumberOfBytesWritten = 0;
+
+ BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl);
+ if (Success)
+ {
+ NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten);
+ }
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset);
+ bool Success = (BytesWritten > 0);
+ if (Success)
+ {
+ NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten);
+ }
+#endif
+
+ if (!Success)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+
+ Size -= NumberOfBytesWritten;
+ m_WriteOffset += NumberOfBytesWritten;
+ Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten;
+ }
+ return {};
+ }
+
+ BufferedReadFileStream::BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize)
+ : m_FileHandle(FileHandle)
+ , m_FileSize(FileSize)
+ , m_FileEnd(FileOffset + FileSize)
+ , m_BufferSize(Min(BufferSize, FileSize))
+ , m_FileOffset(FileOffset)
+ {
+ }
+
+ BufferedReadFileStream::~BufferedReadFileStream() { Memory::Free(m_Buffer); }
+ void BufferedReadFileStream::Read(void* Data, uint64_t Size)
+ {
+ ZEN_ASSERT(Data != nullptr);
+ if (Size > m_BufferSize)
+ {
+ Read(Data, Size, m_FileOffset);
+ m_FileOffset += Size;
+ return;
+ }
+ uint8_t* WritePtr = ((uint8_t*)Data);
+ uint64_t Begin = m_FileOffset;
+ uint64_t End = m_FileOffset + Size;
+ ZEN_ASSERT(m_FileOffset >= m_BufferStart);
+ if (m_FileOffset < m_BufferEnd)
+ {
+ ZEN_ASSERT(m_Buffer != nullptr);
+ uint64_t Count = Min(m_BufferEnd, End) - m_FileOffset;
+ memcpy(WritePtr + Begin - m_FileOffset, m_Buffer + Begin - m_BufferStart, Count);
+ Begin += Count;
+ if (Begin == End)
+ {
+ m_FileOffset = End;
+ return;
+ }
+ }
+ if (End == m_FileEnd)
+ {
+ Read(WritePtr + Begin - m_FileOffset, End - Begin, Begin);
+ }
+ else
+ {
+ if (!m_Buffer)
+ {
+ m_BufferSize = Min(m_FileEnd - m_FileOffset, m_BufferSize);
+ m_Buffer = (uint8_t*)Memory::Alloc(gsl::narrow<size_t>(m_BufferSize));
+ }
+ m_BufferStart = Begin;
+ m_BufferEnd = Min(Begin + m_BufferSize, m_FileEnd);
+ Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart);
+ uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart;
+ memcpy(WritePtr + Begin - m_FileOffset, m_Buffer, Count);
+ ZEN_ASSERT(Begin + Count == End);
+ }
+ m_FileOffset = End;
+ }
+
+ void BufferedReadFileStream::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset)
+ {
+ const uint64_t MaxChunkSize = 1u * 1024 * 1024;
+ std::error_code Ec;
+ ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec);
+
+ if (Ec)
+ {
+ std::error_code DummyEc;
+ throw std::system_error(
+ Ec,
+ fmt::format("HttpClient::BufferedReadFileStream ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})",
+ FileOffset,
+ BytesToRead,
+ PathFromHandle(m_FileHandle, DummyEc).generic_string(),
+ m_FileSize));
+ }
+ }
+
+ CompositeBufferReadStream::CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize)
+ : m_Data(Data)
+ , m_BufferSize(BufferSize)
+ , m_SegmentIndex(0)
+ , m_BytesLeftInSegment(0)
+ {
+ }
+ uint64_t CompositeBufferReadStream::Read(void* Data, uint64_t Size)
+ {
+ uint64_t Result = 0;
+ uint8_t* WritePtr = (uint8_t*)Data;
+ while ((Size > 0) && (m_SegmentIndex < m_Data.GetSegments().size()))
+ {
+ if (m_BytesLeftInSegment == 0)
+ {
+ const SharedBuffer& Segment = m_Data.GetSegments()[m_SegmentIndex];
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Segment.AsIoBuffer().GetFileReference(FileRef))
+ {
+ m_SegmentDiskBuffer = std::make_unique<BufferedReadFileStream>(FileRef.FileHandle,
+ FileRef.FileChunkOffset,
+ FileRef.FileChunkSize,
+ m_BufferSize);
+ }
+ else
+ {
+ m_SegmentMemoryBuffer = Segment.GetView();
+ }
+ m_BytesLeftInSegment = Segment.GetSize();
+ }
+ uint64_t BytesToRead = Min(m_BytesLeftInSegment, Size);
+ if (m_SegmentDiskBuffer)
+ {
+ m_SegmentDiskBuffer->Read(WritePtr, BytesToRead);
+ }
+ else
+ {
+ ZEN_ASSERT_SLOW(m_SegmentMemoryBuffer.GetSize() >= BytesToRead);
+ memcpy(WritePtr, m_SegmentMemoryBuffer.GetData(), BytesToRead);
+ m_SegmentMemoryBuffer.MidInline(BytesToRead);
+ }
+ WritePtr += BytesToRead;
+ Size -= BytesToRead;
+ Result += BytesToRead;
+
+ m_BytesLeftInSegment -= BytesToRead;
+ if (m_BytesLeftInSegment == 0)
+ {
+ m_SegmentDiskBuffer.reset();
+ m_SegmentMemoryBuffer.Reset();
+ m_SegmentIndex++;
+ }
+ }
+ return Result;
+ }
+
+} // namespace detail
+
+} // namespace zen
+
+#if ZEN_WITH_TESTS
+namespace zen {
+
+namespace testutil {
+ IoHash HashComposite(const CompositeBuffer& Payload)
+ {
+ IoHashStream Hasher;
+ const uint64_t PayloadSize = Payload.GetSize();
+ std::vector<uint8_t> Buffer(64u * 1024u);
+ detail::CompositeBufferReadStream Stream(Payload, 137u * 1024u);
+ for (uint64_t Offset = 0; Offset < PayloadSize;)
+ {
+ uint64_t Count = Min(64u * 1024u, PayloadSize - Offset);
+ Stream.Read(Buffer.data(), Count);
+ Hasher.Append(Buffer.data(), Count);
+ Offset += Count;
+ }
+ return Hasher.GetHash();
+ };
+
+ IoHash HashFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize)
+ {
+ IoHashStream Hasher;
+ std::vector<uint8_t> Buffer(64u * 1024u);
+ detail::BufferedReadFileStream Stream(FileHandle, FileOffset, FileSize, 137u * 1024u);
+ for (uint64_t Offset = 0; Offset < FileSize;)
+ {
+ uint64_t Count = Min(64u * 1024u, FileSize - Offset);
+ Stream.Read(Buffer.data(), Count);
+ Hasher.Append(Buffer.data(), Count);
+ Offset += Count;
+ }
+ return Hasher.GetHash();
+ }
+
+} // namespace testutil
+
+TEST_CASE("BufferedReadFileStream")
+{
+ ScopedTemporaryDirectory TmpDir;
+
+ IoBuffer DiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer1");
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ CHECK(DiskBuffer.GetFileReference(FileRef));
+ CHECK_EQ(IoHash::HashBuffer(DiskBuffer), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize));
+
+ IoBuffer Partial(DiskBuffer, 37 * 1024, 512 * 1024);
+ CHECK(Partial.GetFileReference(FileRef));
+ CHECK_EQ(IoHash::HashBuffer(Partial), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize));
+
+ IoBuffer SmallDiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(63 * 1024)), TmpDir.Path() / "diskbuffer2");
+ CHECK(SmallDiskBuffer.GetFileReference(FileRef));
+ CHECK_EQ(IoHash::HashBuffer(SmallDiskBuffer),
+ testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize));
+}
+
+TEST_CASE("CompositeBufferReadStream")
+{
+ ScopedTemporaryDirectory TmpDir;
+
+ IoBuffer MemoryBuffer1 = CreateRandomBlob(64);
+ CHECK_EQ(IoHash::HashBuffer(MemoryBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer1))));
+
+ IoBuffer MemoryBuffer2 = CreateRandomBlob(561 * 1024);
+ CHECK_EQ(IoHash::HashBuffer(MemoryBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer2))));
+
+ IoBuffer DiskBuffer1 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(267 * 3 * 1024)), TmpDir.Path() / "diskbuffer1");
+ CHECK_EQ(IoHash::HashBuffer(DiskBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer1))));
+
+ IoBuffer DiskBuffer2 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(3 * 1024)), TmpDir.Path() / "diskbuffer2");
+ CHECK_EQ(IoHash::HashBuffer(DiskBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer2))));
+
+ IoBuffer DiskBuffer3 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer3");
+ CHECK_EQ(IoHash::HashBuffer(DiskBuffer3), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer3))));
+
+ CompositeBuffer Data(SharedBuffer(std::move(MemoryBuffer1)),
+ SharedBuffer(std::move(DiskBuffer1)),
+ SharedBuffer(std::move(DiskBuffer2)),
+ SharedBuffer(std::move(MemoryBuffer2)),
+ SharedBuffer(std::move(DiskBuffer3)));
+ CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
+}
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h
new file mode 100644
index 000000000..9060cde48
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcommon.h
@@ -0,0 +1,147 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compositebuffer.h>
+#include <zencore/trace.h>
+
+#include <zenhttp/httpclient.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+class HttpClientBase
+{
+public:
+ HttpClientBase(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {});
+ virtual ~HttpClientBase() = 0;
+
+ using Response = HttpClient::Response;
+ using KeyValueMap = HttpClient::KeyValueMap;
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) = 0;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) = 0;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) = 0;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) = 0;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}) = 0;
+
+ LoggerRef Log() { return m_Log; }
+ std::string_view GetBaseUri() const { return m_BaseUri; }
+ std::string_view GetSessionId() const { return m_SessionId; }
+
+ bool Authenticate();
+
+protected:
+ LoggerRef m_Log;
+ std::string m_BaseUri;
+ std::string m_SessionId;
+ const HttpClientSettings m_ConnectionSettings;
+
+ const std::optional<HttpClientAccessToken> GetAccessToken();
+
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+};
+
+namespace detail {
+
+ class TempPayloadFile
+ {
+ public:
+ TempPayloadFile(const TempPayloadFile&) = delete;
+ TempPayloadFile& operator=(const TempPayloadFile&) = delete;
+
+ TempPayloadFile();
+ ~TempPayloadFile();
+
+ std::error_code Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize);
+ std::error_code Write(std::string_view DataString);
+ IoBuffer DetachToIoBuffer();
+ IoBuffer BorrowIoBuffer();
+ inline uint64_t GetSize() const { return m_WriteOffset; }
+ void ResetWritePos(uint64_t WriteOffset);
+
+ private:
+ std::error_code Flush();
+ std::error_code AppendData(const void* Data, uint64_t Size);
+
+ void* m_FileHandle;
+ std::uint64_t m_WriteOffset;
+ static constexpr uint64_t CacheBufferSize = 512u * 1024u;
+ uint8_t m_CacheBuffer[CacheBufferSize];
+ std::uint64_t m_CacheBufferOffset = 0;
+ };
+
+ class BufferedReadFileStream
+ {
+ public:
+ BufferedReadFileStream(const BufferedReadFileStream&) = delete;
+ BufferedReadFileStream& operator=(const BufferedReadFileStream&) = delete;
+
+ BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize);
+ ~BufferedReadFileStream();
+
+ void Read(void* Data, uint64_t Size);
+
+ private:
+ void Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset);
+
+ void* m_FileHandle = nullptr;
+ const uint64_t m_FileSize = 0;
+ const uint64_t m_FileEnd = 0;
+ uint64_t m_BufferSize = 0;
+ uint8_t* m_Buffer = nullptr;
+ uint64_t m_BufferStart = 0;
+ uint64_t m_BufferEnd = 0;
+ uint64_t m_FileOffset = 0;
+ };
+
+ class CompositeBufferReadStream
+ {
+ public:
+ CompositeBufferReadStream(const CompositeBufferReadStream&) = delete;
+ CompositeBufferReadStream& operator=(const CompositeBufferReadStream&) = delete;
+
+ CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize);
+ uint64_t Read(void* Data, uint64_t Size);
+
+ private:
+ const CompositeBuffer& m_Data;
+ const uint64_t m_BufferSize;
+ size_t m_SegmentIndex;
+ std::unique_ptr<BufferedReadFileStream> m_SegmentDiskBuffer;
+ MemoryView m_SegmentMemoryBuffer;
+ uint64_t m_BytesLeftInSegment;
+ };
+
+} // namespace detail
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
new file mode 100644
index 000000000..568106887
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -0,0 +1,1035 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcpr.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zenhttp/packageformat.h>
+
+namespace zen {
+
+HttpClientBase*
+CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings)
+{
+ return new CprHttpClient(BaseUri, ConnectionSettings);
+}
+
+static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+
+// If we want to support different HTTP client implementations then we'll need to make this more abstract
+
+HttpClientError::ResponseClass
+HttpClientError::GetResponseClass() const
+{
+ if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
+ {
+ switch ((cpr::ErrorCode)m_Error)
+ {
+ case cpr::ErrorCode::CONNECTION_FAILURE:
+ return ResponseClass::kHttpCantConnectError;
+ case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
+ return ResponseClass::kHttpNoHost;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return ResponseClass::kHttpTimeout;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return ResponseClass::kHttpSLLError;
+ default:
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ else if (IsHttpSuccessCode(m_ResponseCode))
+ {
+ return ResponseClass::kSuccess;
+ }
+ else
+ {
+ switch (m_ResponseCode)
+ {
+ case HttpResponseCode::Unauthorized:
+ return ResponseClass::kHttpUnauthorized;
+ case HttpResponseCode::NotFound:
+ return ResponseClass::kHttpNotFound;
+ case HttpResponseCode::Forbidden:
+ return ResponseClass::kHttpForbidden;
+ case HttpResponseCode::Conflict:
+ return ResponseClass::kHttpConflict;
+ case HttpResponseCode::InternalServerError:
+ return ResponseClass::kHttpInternalServerError;
+ case HttpResponseCode::ServiceUnavailable:
+ return ResponseClass::kHttpServiceUnavailable;
+ case HttpResponseCode::BadGateway:
+ return ResponseClass::kHttpBadGateway;
+ case HttpResponseCode::GatewayTimeout:
+ return ResponseClass::kHttpGatewayTimeout;
+ default:
+ if (m_ResponseCode >= HttpResponseCode::InternalServerError)
+ {
+ return ResponseClass::kHttpOtherServerError;
+ }
+ else
+ {
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// CPR helpers
+
+static cpr::Body
+AsCprBody(const CbObject& Obj)
+{
+ return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize());
+}
+
+static cpr::Body
+AsCprBody(const IoBuffer& Obj)
+{
+ return cpr::Body((const char*)Obj.GetData(), Obj.GetSize());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClient::Response
+ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload)
+{
+ // This ends up doing a memcpy, would be good to get rid of it by streaming results
+ // into buffer directly
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size());
+
+ if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
+ {
+ const HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed};
+}
+
+static HttpClient::Response
+CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {})
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
+ if (HttpResponse.error)
+ {
+ if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT &&
+ HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse);
+ }
+
+ // Client side failure code
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()),
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed,
+ .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
+ .ErrorMessage = HttpResponse.error.message}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
+ {
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
+ .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
+ .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
+ .ElapsedSeconds = HttpResponse.elapsed};
+ }
+ else
+ {
+ return ResponseWithPayload(
+ SessionId,
+ HttpResponse,
+ WorkResponseCode,
+ Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()));
+ }
+}
+
+static bool
+ShouldRetry(const cpr::Response& Response)
+{
+ switch (Response.error.code)
+ {
+ case cpr::ErrorCode::OK:
+ break;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch ((HttpResponseCode)Response.status_code)
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+};
+
+static bool
+ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+ IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size());
+
+ if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end())
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second);
+ if (!ExpectedContentSize.has_value())
+ {
+ Response.error =
+ cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second));
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Response.error = cpr::Error(
+ /*CURLE_READ_ERROR*/ 26,
+ fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second));
+ return false;
+ }
+ }
+
+ if (Response.status_code == (long)HttpResponseCode::PartialContent)
+ {
+ return true;
+ }
+
+ if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end())
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26,
+ fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString()));
+ return false;
+ }
+ }
+ }
+
+ if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end())
+ {
+ if (ContentType->second == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize))
+ {
+ return true;
+ }
+ else
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation");
+ return false;
+ }
+ }
+ if (ContentType->second == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error)));
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+static cpr::Response
+DoWithRetry(
+ std::string_view SessionId,
+ std::function<cpr::Response()>&& Func,
+ uint8_t RetryCount,
+ std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; })
+{
+ uint8_t Attempt = 0;
+ cpr::Response Result = Func();
+ while (Attempt < RetryCount)
+ {
+ if (!ShouldRetry(Result))
+ {
+ if (Result.error || !IsHttpSuccessCode(Result.status_code))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1);
+ Result = Func();
+ }
+ return Result;
+}
+
+static cpr::Response
+DoWithRetry(std::string_view SessionId,
+ std::function<cpr::Response()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile,
+ uint8_t RetryCount)
+{
+ uint8_t Attempt = 0;
+ cpr::Response Result = Func();
+ while (Attempt < RetryCount)
+ {
+ if (!ShouldRetry(Result))
+ {
+ if (Result.error || !IsHttpSuccessCode(Result.status_code))
+ {
+ break;
+ }
+ if (ValidatePayload(Result, PayloadFile))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1);
+ Result = Func();
+ }
+ return Result;
+}
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings)
+: HttpClientBase(BaseUri, Connectionsettings)
+{
+}
+
+CprHttpClient::~CprHttpClient()
+{
+ ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto CprSession : m_Sessions)
+ {
+ delete CprSession;
+ }
+ m_Sessions.clear();
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CprHttpClient::Session
+CprHttpClient::AllocSession(const std::string_view BaseUrl,
+ const std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken)
+{
+ ZEN_TRACE_CPU("CprHttpClient::AllocSession");
+ cpr::Session* CprSession = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ CprSession = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (CprSession == nullptr)
+ {
+ CprSession = new cpr::Session();
+ CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout);
+ CprSession->SetTimeout(ConnectionSettings.Timeout);
+ if (ConnectionSettings.AssumeHttp2)
+ {
+ CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE});
+ }
+ }
+
+ if (!AdditionalHeader->empty())
+ {
+ CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end()));
+ }
+ if (!SessionId.empty())
+ {
+ CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
+ }
+ if (AccessToken)
+ {
+ CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
+ }
+ if (!Parameters->empty())
+ {
+ cpr::Parameters Tmp;
+ for (auto It = Parameters->begin(); It != Parameters->end(); It++)
+ {
+ Tmp.Add({It->first, It->second});
+ }
+ CprSession->SetParameters(Tmp);
+ }
+ else
+ {
+ CprSession->SetParameters({});
+ }
+
+ ExtendableStringBuilder<128> UrlBuffer;
+ UrlBuffer << BaseUrl << ResourcePath;
+ CprSession->SetUrl(UrlBuffer.c_str());
+
+ return Session(this, CprSession);
+}
+
+void
+CprHttpClient::ReleaseSession(cpr::Session* CprSession)
+{
+ ZEN_TRACE_CPU("CprHttpClient::ReleaseSession");
+ CprSession->SetUrl({});
+ CprSession->SetHeader({});
+ CprSession->SetBody({});
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); });
+}
+
+CprHttpClient::Response
+CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::TransactPackage");
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++HttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}});
+ Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (FilterResponse.status_code == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ else
+ {
+ // This should be an error -- server asked to have something we can't find
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}});
+ Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (!IsHttpSuccessCode(FilterResponse.status_code))
+ {
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code)};
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
+
+ if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
+ {
+ HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CprHttpClient::Response
+CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())});
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Put");
+
+ return CommonResponse(m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri,
+ Url,
+ m_ConnectionSettings,
+ {{"Content-Length", "0"}},
+ Parameters,
+ m_SessionId,
+ GetAccessToken());
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Get");
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ return Sess.Get();
+ },
+ m_ConnectionSettings.RetryCount,
+ [](cpr::Response& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }));
+}
+
+CprHttpClient::Response
+CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Head");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ return Sess.Head();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Delete");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ return Sess.Delete();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostNoPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ uint64_t Offset = 0;
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+ auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, Payload.GetSize() - Offset);
+ Buffer.Read(buffer, size);
+ Offset += size;
+ return true;
+ };
+ return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ }
+ Sess->SetBody(AsCprBody(Payload));
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+
+ Sess->SetBody(AsCprBody(Payload));
+ Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)});
+ return Sess.Post();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CprHttpClient::Response
+CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+ auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) {
+ size = Reader.Read(buffer, size);
+ return true;
+ };
+ return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())});
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ uint64_t Offset = 0;
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+ auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, Payload.GetSize() - Offset);
+ Buffer.Read(buffer, size);
+ Offset += size;
+ return true;
+ };
+ return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ }
+ Sess->SetBody(AsCprBody(Payload));
+ return Sess.Put();
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Sess->UpdateHeader({HeaderContentType(ContentType)});
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+ auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) {
+ size = Reader.Read(buffer, size);
+ return true;
+ };
+ return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ },
+ m_ConnectionSettings.RetryCount));
+}
+
+CprHttpClient::Response
+CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CprHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+ cpr::Response Response = DoWithRetry(
+ m_SessionId,
+ [&]() {
+ auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> {
+ size_t DelimiterPos = header.find(':');
+ if (DelimiterPos != std::string::npos)
+ {
+ std::string Key = header.substr(0, DelimiterPos);
+ constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters);
+ Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters);
+
+ std::string Value = header.substr(DelimiterPos + 1);
+ Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters);
+ Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters);
+
+ return std::make_pair(Key, Value);
+ }
+ return std::make_pair(header, "");
+ };
+
+ auto DownloadCallback = [&](std::string data, intptr_t) {
+ if (PayloadFile)
+ {
+ ZEN_ASSERT(PayloadString.empty());
+ std::error_code Ec = PayloadFile->Write(data);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ return false;
+ }
+ }
+ else
+ {
+ PayloadString.append(data);
+ }
+ return true;
+ };
+
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ size_t RangeStartPos = RangeIt->second.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos);
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart =
+ ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength = RequestedRangeEnd.value() - 1;
+ }
+ }
+ }
+ }
+ }
+
+ cpr::Response Response;
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ std::pair<std::string, std::string> Header = GetHeader(header);
+ if (Header.first == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > 1024 * 1024)
+ {
+ PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ TempFolderPath.string(),
+ Ec.message());
+ PayloadFile.reset();
+ }
+ }
+ else
+ {
+ PayloadString.reserve(ContentLength.value());
+ }
+ }
+ }
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+ return 1;
+ };
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ }
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const cpr::Response& Response) -> bool {
+ if (Response.header.find("Content-Range") != Response.header.end())
+ {
+ return true;
+ }
+ if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end())
+ {
+ return It->second == "bytes"sv;
+ }
+ return false;
+ };
+
+ auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool {
+ if (ShouldRetry(Response))
+ {
+ return SupportsRanges(Response);
+ }
+ return false;
+ };
+
+ if (ShouldResume(Response))
+ {
+ auto It = Response.header.find("Content-Length");
+ if (It != Response.header.end())
+ {
+ std::vector<std::pair<std::string, std::string>> ReceivedHeaders;
+
+ auto HeaderCallback = [&](std::string header, intptr_t) {
+ std::pair<std::string, std::string> Header = GetHeader(header);
+ if (!Header.first.empty())
+ {
+ ReceivedHeaders.emplace_back(std::move(Header));
+ }
+
+ if (Header.first == "Content-Range"sv)
+ {
+ if (Header.second.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Header.second.find('-', 6);
+ if (RangeStartEnd != std::string::npos)
+ {
+ const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+ if (Start.value() == DownloadedSize)
+ {
+ return 1;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (PayloadFile)
+ {
+ PayloadFile->ResetWritePos(Start.value());
+ }
+ else
+ {
+ PayloadString = PayloadString.substr(0, Start.value());
+ }
+ return 1;
+ }
+ }
+ }
+ return 0;
+ }
+ return 1;
+ };
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ // If we didn't make any progress, abort
+ break;
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback});
+ for (const std::pair<std::string, std::string>& H : ReceivedHeaders)
+ {
+ Response.header.insert_or_assign(H.first, H.second);
+ }
+ ReceivedHeaders.clear();
+ } while (ShouldResume(Response));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Response.text = std::move(PayloadString);
+ }
+ return Response;
+ },
+ PayloadFile,
+ m_ConnectionSettings.RetryCount);
+
+ return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+}
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h
new file mode 100644
index 000000000..ed9d10c27
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcpr.h
@@ -0,0 +1,151 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httpclientcommon.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/cprutils.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/body.h>
+#include <cpr/session.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CprHttpClient : public HttpClientBase
+{
+public:
+ CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {});
+ ~CprHttpClient();
+
+ // HttpClientBase
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url,
+ CbPackage Package,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+private:
+ struct Session
+ {
+ Session(CprHttpClient* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {}
+ ~Session() { Outer->ReleaseSession(CprSession); }
+
+ inline cpr::Session* operator->() const { return CprSession; }
+ inline cpr::Response Get()
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Get");
+ cpr::Response Result = CprSession->Get();
+ ZEN_TRACE("GET {}", Result);
+ return Result;
+ }
+ inline cpr::Response Download(cpr::WriteCallback&& Write, std::optional<cpr::HeaderCallback>&& Header = {})
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Download");
+ if (Header)
+ {
+ CprSession->SetHeaderCallback(std::move(Header.value()));
+ }
+ cpr::Response Result = CprSession->Download(Write);
+ ZEN_TRACE("GET {}", Result);
+ CprSession->SetHeaderCallback({});
+ CprSession->SetWriteCallback({});
+ return Result;
+ }
+ inline cpr::Response Head()
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Head");
+ cpr::Response Result = CprSession->Head();
+ ZEN_TRACE("HEAD {}", Result);
+ return Result;
+ }
+ inline cpr::Response Put(std::optional<cpr::ReadCallback>&& Read = {})
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Put");
+ if (Read)
+ {
+ CprSession->SetReadCallback(std::move(Read.value()));
+ }
+ cpr::Response Result = CprSession->Put();
+ ZEN_TRACE("PUT {}", Result);
+ CprSession->SetReadCallback({});
+ return Result;
+ }
+ inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {})
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Post");
+ if (Read)
+ {
+ CprSession->SetReadCallback(std::move(Read.value()));
+ }
+ cpr::Response Result = CprSession->Post();
+ ZEN_TRACE("POST {}", Result);
+ CprSession->SetReadCallback({});
+ return Result;
+ }
+ inline cpr::Response Delete()
+ {
+ ZEN_TRACE_CPU("HttpClient::Impl::Delete");
+ cpr::Response Result = CprSession->Delete();
+ ZEN_TRACE("DELETE {}", Result);
+ return Result;
+ }
+
+ LoggerRef Log() { return Outer->Log(); }
+
+ private:
+ CprHttpClient* Outer;
+ cpr::Session* CprSession;
+
+ Session(Session&&) = delete;
+ Session& operator=(Session&&) = delete;
+ };
+
+ Session AllocSession(const std::string_view BaseUrl,
+ const std::string_view Url,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ const std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken);
+
+ RwLock m_SessionLock;
+ std::vector<cpr::Session*> m_Sessions;
+
+ void ReleaseSession(cpr::Session*);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index a29a08a3c..c5c808c23 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -1,849 +1,383 @@
// Copyright Epic Games, Inc. All Rights Reserved.
+#include <zenhttp/formatters.h>
#include <zenhttp/httpclient.h>
#include <zenhttp/httpserver.h>
+#include <zenhttp/packageformat.h>
-#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
#include <zencore/compositebuffer.h>
#include <zencore/except.h>
#include <zencore/filesystem.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
+#include <zencore/memory/memory.h>
#include <zencore/session.h>
#include <zencore/sharedbuffer.h>
#include <zencore/stream.h>
-#include <zencore/testing.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
-#include <zenhttp/formatters.h>
-#include <zenutil/packageformat.h>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <cpr/cpr.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
-# include <fcntl.h>
-# include <sys/stat.h>
-# include <unistd.h>
-#endif
+#include "clients/httpclientcommon.h"
-static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif // ZEN_WITH_TESTS
namespace zen {
+extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings);
+
using namespace std::literals;
//////////////////////////////////////////////////////////////////////////
-//
-// CPR helpers
-cpr::Body
-AsCprBody(const CbObject& Obj)
-{
- return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize());
-}
-
-cpr::Body
-AsCprBody(const IoBuffer& Obj)
+HttpClientBase::HttpClientBase(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings)
+: m_Log(zen::logging::Get(ConnectionSettings.LogCategory))
+, m_BaseUri(BaseUri)
+, m_ConnectionSettings(ConnectionSettings)
{
- return cpr::Body((const char*)Obj.GetData(), Obj.GetSize());
+ if (ConnectionSettings.SessionId == Oid::Zero)
+ {
+ m_SessionId = GetSessionIdString();
+ }
+ else
+ {
+ m_SessionId = ConnectionSettings.SessionId.ToString();
+ }
}
-cpr::Body
-AsCprBody(const CompositeBuffer& Buffers)
+HttpClientBase::~HttpClientBase()
{
- SharedBuffer Buffer = Buffers.Flatten();
-
- // This is super inefficient, should be fixed
- std::string String{(const char*)Buffer.GetData(), Buffer.GetSize()};
- return cpr::Body{std::move(String)};
}
-//////////////////////////////////////////////////////////////////////////
-
-HttpClient::Response
-ResponseWithPayload(cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload)
+bool
+HttpClientBase::Authenticate()
{
- // This ends up doing a memcpy, would be good to get rid of it by streaming results
- // into buffer directly
- IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size());
-
- if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end())
- {
- const HttpContentType ContentType = ParseContentType(It->second);
-
- ResponseBuffer.SetContentType(ContentType);
- }
-
- if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ ZEN_TRACE_CPU("HttpClientBase::Authenticate");
+ std::optional<HttpClientAccessToken> Token = GetAccessToken();
+ if (!Token)
{
- ZEN_WARN("HttpClient request failed: {}", HttpResponse);
+ return false;
}
-
- return HttpClient::Response{.StatusCode = WorkResponseCode,
- .ResponsePayload = std::move(ResponseBuffer),
- .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
- .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
- .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ return Token->IsValid();
}
-HttpClient::Response
-CommonResponse(cpr::Response&& HttpResponse, IoBuffer&& Payload = {})
+const std::optional<HttpClientAccessToken>
+HttpClientBase::GetAccessToken()
{
- const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code);
- if (HttpResponse.error)
+ ZEN_TRACE_CPU("HttpClientBase::GetAccessToken");
+ if (!m_ConnectionSettings.AccessTokenProvider.has_value())
{
- ZEN_WARN("HttpClient client error: {}", HttpResponse);
-
- // Client side failure code
- return HttpClient::Response{
- .StatusCode = WorkResponseCode,
- .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()),
- .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
- .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
- .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed,
- .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
- .ErrorMessage = HttpResponse.error.message}};
+ return {};
}
-
- if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
{
- return HttpClient::Response{.StatusCode = WorkResponseCode,
- .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()),
- .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
- .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
- .ElapsedSeconds = HttpResponse.elapsed};
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (m_CachedAccessToken.IsValid())
+ {
+ return m_CachedAccessToken;
+ }
}
- else
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (m_CachedAccessToken.IsValid())
{
- return ResponseWithPayload(HttpResponse, WorkResponseCode, std::move(Payload));
+ return m_CachedAccessToken;
}
+ m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()();
+ return m_CachedAccessToken;
}
//////////////////////////////////////////////////////////////////////////
-struct HttpClient::Impl : public RefCounted
+CbObject
+HttpClient::Response::AsObject() const
{
- Impl(LoggerRef Log);
- ~Impl();
-
- // Session allocation
-
- struct Session
+ if (ResponsePayload)
{
- Session(Impl* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {}
- ~Session() { Outer->ReleaseSession(CprSession); }
-
- inline cpr::Session* operator->() const { return CprSession; }
- inline cpr::Response Get()
- {
- cpr::Response Result = CprSession->Get();
- ZEN_TRACE("GET {}", Result);
- return Result;
- }
- inline cpr::Response Download(cpr::WriteCallback&& write)
- {
- cpr::Response Result = CprSession->Download(write);
- ZEN_TRACE("GET {}", Result);
- return Result;
- }
- inline cpr::Response Head()
- {
- cpr::Response Result = CprSession->Head();
- ZEN_TRACE("HEAD {}", Result);
- return Result;
- }
- inline cpr::Response Put()
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(IoBuffer(ResponsePayload), ValidationError);
+ ValidationError == CbValidateError::None)
{
- cpr::Response Result = CprSession->Put();
- ZEN_TRACE("PUT {}", Result);
- return Result;
+ return ResponseObject;
}
- inline cpr::Response Post()
- {
- cpr::Response Result = CprSession->Post();
- ZEN_TRACE("POST {}", Result);
- return Result;
- }
- inline cpr::Response Delete()
- {
- cpr::Response Result = CprSession->Delete();
- ZEN_TRACE("DELETE {}", Result);
- return Result;
- }
-
- LoggerRef Logger() { return Outer->Logger(); }
-
- private:
- Impl* Outer;
- cpr::Session* CprSession;
-
- Session(Session&&) = delete;
- Session& operator=(Session&&) = delete;
- };
-
- Session AllocSession(const std::string_view BaseUrl,
- const std::string_view Url,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters);
-
- LoggerRef Logger() { return m_Log; }
-
-private:
- LoggerRef m_Log;
- RwLock m_SessionLock;
- std::vector<cpr::Session*> m_Sessions;
-
- void ReleaseSession(cpr::Session*);
-};
+ }
-HttpClient::Impl::Impl(LoggerRef Log) : m_Log(Log)
-{
+ return {};
}
-HttpClient::Impl::~Impl()
+CbPackage
+HttpClient::Response::AsPackage() const
{
- m_SessionLock.WithExclusiveLock([&] {
- for (auto CprSession : m_Sessions)
- {
- delete CprSession;
- }
- m_Sessions.clear();
- });
+ // TODO: sanity checks and error handling
+ if (ResponsePayload)
+ {
+ return ParsePackageMessage(ResponsePayload);
+ }
+
+ return {};
}
-HttpClient::Impl::Session
-HttpClient::Impl::AllocSession(const std::string_view BaseUrl,
- const std::string_view ResourcePath,
- const HttpClientSettings& ConnectionSettings,
- const KeyValueMap& AdditionalHeader,
- const KeyValueMap& Parameters)
+std::string_view
+HttpClient::Response::AsText() const
{
- bool IsNew = false;
- cpr::Session* CprSession = nullptr;
- m_SessionLock.WithExclusiveLock([&] {
- if (m_Sessions.empty())
- {
- CprSession = new cpr::Session();
- IsNew = true;
- }
- else
- {
- CprSession = m_Sessions.back();
- m_Sessions.pop_back();
- }
- });
-
- if (IsNew)
+ if (ResponsePayload)
{
- CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout);
- CprSession->SetTimeout(ConnectionSettings.Timeout);
- if (ConnectionSettings.AssumeHttp2)
- {
- CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE});
- }
+ return std::string_view(reinterpret_cast<const char*>(ResponsePayload.GetData()), ResponsePayload.GetSize());
}
- if (!AdditionalHeader->empty())
- {
- CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end()));
- }
- else
- {
- CprSession->SetHeader({});
- }
- if (!Parameters->empty())
- {
- cpr::Parameters Tmp;
- for (auto It = Parameters->begin(); It != Parameters->end(); It++)
- {
- Tmp.Add({It->first, It->second});
- }
- CprSession->SetParameters(Tmp);
- }
- else
+ return {};
+}
+
+std::string
+HttpClient::Response::ToText() const
+{
+ if (!ResponsePayload)
+ return {};
+
+ switch (ResponsePayload.GetContentType())
{
- CprSession->SetParameters({});
- }
+ case ZenContentType::kCbObject:
+ {
+ zen::ExtendableStringBuilder<1024> ObjStr;
+ zen::CbObject Object{SharedBuffer(ResponsePayload)};
+ zen::CompactBinaryToJson(Object, ObjStr);
+ return ObjStr.ToString();
+ }
+ break;
- ExtendableStringBuilder<128> UrlBuffer;
- UrlBuffer << BaseUrl << ResourcePath;
- CprSession->SetUrl(UrlBuffer.c_str());
+ case ZenContentType::kCSS:
+ case ZenContentType::kHTML:
+ case ZenContentType::kJavaScript:
+ case ZenContentType::kJSON:
+ case ZenContentType::kText:
+ case ZenContentType::kYAML:
+ return std::string{AsText()};
- return Session(this, CprSession);
+ default:
+ return "<unhandled content format>";
+ }
}
-void
-HttpClient::Impl::ReleaseSession(cpr::Session* CprSession)
+bool
+HttpClient::Response::IsSuccess() const noexcept
{
- CprSession->SetUrl({});
- CprSession->SetHeader({});
- CprSession->SetBody({});
- m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); });
+ return !Error && IsHttpSuccessCode(StatusCode);
}
-namespace detail {
-
- static std::atomic_uint32_t TempFileBaseIndex;
-
-} // namespace detail
-
-class TempPayloadFile
+std::string
+HttpClient::Response::ErrorMessage(std::string_view Prefix) const
{
-public:
- TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {}
- ~TempPayloadFile()
+ if (Error.has_value())
{
- try
- {
- if (m_FileHandle)
- {
-#if ZEN_PLATFORM_WINDOWS
- // Mark file for deletion when final handle is closed
- FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
-
- SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
- BOOL Success = CloseHandle(m_FileHandle);
-#else
- std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle);
- unlink(FilePath.c_str());
- int Fd = int(uintptr_t(m_FileHandle));
- bool Success = (close(Fd) == 0);
-#endif
- if (!Success)
- {
- ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString());
- }
-
- m_FileHandle = nullptr;
- }
- }
- catch (std::exception& Ex)
- {
- ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what());
- }
+ return fmt::format("{}: {}", Prefix, Error->ErrorMessage);
}
-
- std::error_code Open(const std::filesystem::path& TempFolderPath)
+ else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode)
{
- ZEN_ASSERT(m_FileHandle == nullptr);
-
- std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) |
- detail::TempFileBaseIndex.fetch_add(1);
-
- std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex);
-#if ZEN_PLATFORM_WINDOWS
- LPCWSTR lpFileName = FileName.c_str();
- const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE);
- const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
- LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr;
- const DWORD dwCreationDisposition = CREATE_ALWAYS;
- const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL;
- const HANDLE hTemplateFile = nullptr;
- const HANDLE FileHandle = CreateFile(lpFileName,
- dwDesiredAccess,
- dwShareMode,
- lpSecurityAttributes,
- dwCreationDisposition,
- dwFlagsAndAttributes,
- hTemplateFile);
-
- if (FileHandle == INVALID_HANDLE_VALUE)
- {
- return MakeErrorCodeFromLastError();
- }
-#else // ZEN_PLATFORM_WINDOWS
- int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC;
- int Fd = open(FileName.c_str(), OpenFlags, 0666);
- if (Fd < 0)
- {
- return MakeErrorCodeFromLastError();
- }
- fchmod(Fd, 0666);
-
- void* FileHandle = (void*)(uintptr_t(Fd));
-#endif // ZEN_PLATFORM_WINDOWS
- m_FileHandle = FileHandle;
-
- return {};
+ std::string TextResponse = ToText();
+ return fmt::format("{}{}HTTP error {} {}{}",
+ Prefix,
+ Prefix.empty() ? ""sv : ": "sv,
+ (int)StatusCode,
+ zen::ToString(StatusCode),
+ TextResponse.empty() ? ""sv : fmt::format(" ({})", TextResponse));
}
-
- std::error_code Write(std::string_view DataString)
+ else
{
- ZEN_ASSERT(m_FileHandle != nullptr);
- const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
- const void* Data = DataString.data();
- std::size_t Size = DataString.size();
-
- while (Size)
- {
- const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize);
- uint64_t NumberOfBytesWritten = 0;
-#if ZEN_PLATFORM_WINDOWS
- OVERLAPPED Ovl{};
-
- Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu);
- Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32);
-
- DWORD dwNumberOfBytesWritten = 0;
-
- BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl);
- if (Success)
- {
- NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten);
- }
-#else
- static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
- int Fd = int(uintptr_t(m_FileHandle));
- int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset);
- bool Success = (BytesWritten > 0);
- if (Success)
- {
- NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten);
- }
-#endif
-
- if (!Success)
- {
- return MakeErrorCodeFromLastError();
- }
-
- Size -= NumberOfBytesWritten;
- m_WriteOffset += NumberOfBytesWritten;
- Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten;
- }
- return {};
+ return fmt::format("{}{}unknown error", Prefix, Prefix.empty() ? ""sv : ": "sv);
}
+}
- IoBuffer DetachToIoBuffer()
+void
+HttpClient::Response::ThrowError(std::string_view ErrorPrefix)
+{
+ if (!IsSuccess())
{
- ZEN_ASSERT(m_FileHandle != nullptr);
- void* FileHandle = m_FileHandle;
- IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true);
- Buffer.SetDeleteOnClose(true);
- m_FileHandle = 0;
- m_WriteOffset = 0;
- return Buffer;
+ throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode);
}
-
-private:
- void* m_FileHandle;
- std::uint64_t m_WriteOffset;
-};
+}
//////////////////////////////////////////////////////////////////////////
-HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings)
-: m_Log(zen::logging::Get(Connectionsettings.LogCategory))
-, m_BaseUri(BaseUri)
-, m_ConnectionSettings(Connectionsettings)
-, m_Impl(new Impl(m_Log))
+HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings)
+: m_BaseUri(BaseUri)
+, m_ConnectionSettings(ConnectionSettings)
{
- StringBuilder<32> SessionId;
- GetSessionId().ToString(SessionId);
- m_SessionId = SessionId;
+ m_SessionId = GetSessionIdString();
+
+ m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings);
}
HttpClient::~HttpClient()
{
+ delete m_Inner;
}
-HttpClient::Response
-HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+void
+HttpClient::SetSessionId(const Oid& SessionId)
{
- ZEN_TRACE_CPU("HttpClient::TransactPackage");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- // First, list of offered chunks for filtering on the server end
-
- std::vector<IoHash> AttachmentsToSend;
- std::span<const CbAttachment> Attachments = Package.GetAttachments();
-
- const uint32_t RequestId = ++HttpClientRequestIdCounter;
- auto RequestIdString = fmt::to_string(RequestId);
-
- if (Attachments.empty() == false)
- {
- CbObjectWriter Writer;
- Writer.BeginArray("offer");
-
- for (const CbAttachment& Attachment : Attachments)
- {
- Writer.AddHash(Attachment.GetHash());
- }
-
- Writer.EndArray();
-
- BinaryWriter MemWriter;
- Writer.Save(MemWriter);
-
- Sess->UpdateHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
- Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
-
- cpr::Response FilterResponse = Sess.Post();
-
- if (FilterResponse.status_code == 200)
- {
- IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
- CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer);
-
- for (CbFieldView& Entry : ResponseObject["need"])
- {
- ZEN_ASSERT(Entry.IsHash());
- AttachmentsToSend.push_back(Entry.AsHash());
- }
- }
- }
-
- // Prepare package for send
-
- CbPackage SendPackage;
- SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
-
- for (const IoHash& AttachmentCid : AttachmentsToSend)
- {
- const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
-
- if (Attachment)
- {
- SendPackage.AddAttachment(*Attachment);
- }
- else
- {
- // This should be an error -- server asked to have something we can't find
- }
- }
-
- // Transmit package payload
-
- CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
- SharedBuffer FlatMessage = Message.Flatten();
-
- Sess->UpdateHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
- Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
-
- cpr::Response FilterResponse = Sess.Post();
-
- if (!IsHttpSuccessCode(FilterResponse.status_code))
+ if (SessionId == Oid::Zero)
{
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code)};
+ m_SessionId = GetSessionIdString();
}
-
- IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
-
- if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
+ else
{
- HttpContentType ContentType = ParseContentType(It->second);
-
- ResponseBuffer.SetContentType(ContentType);
+ m_SessionId = SessionId.ToString();
}
-
- return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
}
-//////////////////////////////////////////////////////////////////////////
-//
-// Standard HTTP verbs
-//
-
HttpClient::Response
-HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::Put");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
- Sess->SetBody(AsCprBody(Payload));
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}});
-
- return CommonResponse(Sess.Put());
+ return m_Inner->Put(Url, Payload, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+HttpClient::Put(std::string_view Url, const HttpClient::KeyValueMap& Parameters)
{
- ZEN_TRACE_CPU("HttpClient::Get");
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters);
-
- return CommonResponse(Sess.Get());
+ return m_Inner->Put(Url, Parameters);
}
HttpClient::Response
-HttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+HttpClient::Get(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters)
{
- ZEN_TRACE_CPU("HttpClient::Head");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- return CommonResponse(Sess.Head());
+ return m_Inner->Get(Url, AdditionalHeader, Parameters);
}
HttpClient::Response
-HttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+HttpClient::Head(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::Delete");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- return CommonResponse(Sess.Delete());
+ return m_Inner->Head(Url, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+HttpClient::Delete(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::PostNoPayload");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters);
-
- return CommonResponse(Sess.Post());
+ return m_Inner->Delete(Url, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters)
{
- ZEN_TRACE_CPU("HttpClient::PostWithPayload");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- Sess->SetBody(AsCprBody(Payload));
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}});
-
- return CommonResponse(Sess.Post());
+ return m_Inner->Post(Url, AdditionalHeader, Parameters);
}
HttpClient::Response
-HttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::PostObjectPayload");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- Sess->SetBody(AsCprBody(Payload));
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ZenContentType::kCbObject))}});
-
- return CommonResponse(Sess.Post());
+ return m_Inner->Post(Url, Payload, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::PostPackage");
-
- CompositeBuffer Message = zen::FormatPackageMessageBuffer(Pkg);
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
- Sess->SetBody(AsCprBody(Message));
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ZenContentType::kCbPackage))}});
-
- return CommonResponse(Sess.Post());
+ return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::Upload");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}});
-
- uint64_t Offset = 0;
- if (Payload.IsWholeFile())
- {
- auto ReadCallback = [&Payload, &Offset](char* buffer, size_t& size, intptr_t) {
- size = Min<size_t>(size, Payload.GetSize() - Offset);
- IoBuffer PayloadRange = IoBuffer(Payload, Offset, size);
- MutableMemoryView Data(buffer, size);
- Data.CopyFrom(PayloadRange.GetView());
- Offset += size;
- return true;
- };
- Sess->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
- }
- else
- {
- Sess->SetBody(AsCprBody(Payload));
- }
- return CommonResponse(Sess.Put());
+ return m_Inner->Post(Url, Payload, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url, CbPackage Payload, const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::Upload");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
- Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ContentType))}});
-
- uint64_t SizeLeft = Payload.GetSize();
- CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
- auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
- size = Min<size_t>(size, SizeLeft);
- MutableMemoryView Data(buffer, size);
- Payload.CopyTo(Data, BufferIt);
- SizeLeft -= size;
- return true;
- };
- Sess->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
-
- return CommonResponse(Sess.Put());
+ return m_Inner->Post(Url, Payload, AdditionalHeader);
}
HttpClient::Response
-HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+HttpClient::Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const HttpClient::KeyValueMap& AdditionalHeader)
{
- ZEN_TRACE_CPU("HttpClient::Download");
-
- Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {});
-
- std::string PayloadString;
- std::unique_ptr<TempPayloadFile> PayloadFile;
-
- cpr::Response Response = Sess.Download(cpr::WriteCallback{[&](std::string data, intptr_t) {
- if (!PayloadFile && (PayloadString.length() + data.length()) > (1024 * 1024))
- {
- PayloadFile = std::make_unique<TempPayloadFile>();
- std::error_code Ec = PayloadFile->Open(TempFolderPath);
- if (Ec)
- {
- ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", TempFolderPath.string(), Ec.message());
- return false;
- }
- PayloadFile->Write(PayloadString);
- PayloadString.clear();
- }
- if (PayloadFile)
- {
- std::error_code Ec = PayloadFile->Write(data);
- if (Ec)
- {
- ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
- TempFolderPath.string(),
- Ec.message());
- return false;
- }
- }
- else
- {
- PayloadString.append(data);
- }
- return true;
- }});
-
- if (!PayloadString.empty())
- {
- Response.text = std::move(PayloadString);
- }
-
- return CommonResponse(std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{});
+ return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader);
}
-//////////////////////////////////////////////////////////////////////////
-
-CbObject
-HttpClient::Response::AsObject() const
+HttpClient::Response
+HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader)
{
- // TODO: sanity check the payload format etc
-
- if (ResponsePayload)
- {
- return LoadCompactBinaryObject(ResponsePayload);
- }
-
- return {};
+ return m_Inner->Upload(Url, Payload, AdditionalHeader);
}
-CbPackage
-HttpClient::Response::AsPackage() const
+HttpClient::Response
+HttpClient::Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const HttpClient::KeyValueMap& AdditionalHeader)
{
- // TODO: sanity checks and error handling
- if (ResponsePayload)
- {
- return ParsePackageMessage(ResponsePayload);
- }
-
- return {};
+ return m_Inner->Upload(Url, Payload, ContentType, AdditionalHeader);
}
-std::string_view
-HttpClient::Response::AsText() const
+HttpClient::Response
+HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const HttpClient::KeyValueMap& AdditionalHeader)
{
- if (ResponsePayload)
- {
- return std::string_view(reinterpret_cast<const char*>(ResponsePayload.GetData()), ResponsePayload.GetSize());
- }
-
- return {};
+ return m_Inner->Download(Url, TempFolderPath, AdditionalHeader);
}
-std::string
-HttpClient::Response::ToText() const
+HttpClient::Response
+HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const HttpClient::KeyValueMap& AdditionalHeader)
{
- if (!ResponsePayload)
- return {};
-
- switch (ResponsePayload.GetContentType())
- {
- case ZenContentType::kCbObject:
- {
- zen::ExtendableStringBuilder<1024> ObjStr;
- zen::CbObject Object{SharedBuffer(ResponsePayload)};
- zen::CompactBinaryToJson(Object, ObjStr);
- return ObjStr.ToString();
- }
- break;
-
- case ZenContentType::kCSS:
- case ZenContentType::kHTML:
- case ZenContentType::kJavaScript:
- case ZenContentType::kJSON:
- case ZenContentType::kText:
- case ZenContentType::kYAML:
- return std::string{AsText()};
-
- default:
- return "<unhandled content format>";
- }
+ return m_Inner->TransactPackage(Url, Package, AdditionalHeader);
}
bool
-HttpClient::Response::IsSuccess() const noexcept
+HttpClient::Authenticate()
{
- return !Error && IsHttpSuccessCode(StatusCode);
+ return m_Inner->Authenticate();
}
-std::string
-HttpClient::Response::ErrorMessage(std::string_view Prefix) const
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("responseformat")
{
- if (Error.has_value())
- {
- return fmt::format("{}: {}", Prefix, Error->ErrorMessage);
- }
- else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode)
+ using namespace std::literals;
+
+ SUBCASE("identity")
{
- return fmt::format("{}: HTTP error {} {} ({})", Prefix, (int)StatusCode, zen::ToString(StatusCode), AsText());
+ BodyLogFormatter _{"abcd"};
+ CHECK_EQ(_.GetText(), "abcd"sv);
}
- else
+
+ SUBCASE("very long")
{
- return fmt::format("{}: {}", Prefix, "unknown error");
+ std::string_view LongView =
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz";
+
+ BodyLogFormatter _{LongView};
+
+ CHECK(_.GetText().size() < LongView.size());
+ CHECK(_.GetText().starts_with("[truncated"sv));
}
-}
-void
-HttpClient::Response::ThrowError(std::string_view ErrorPrefix)
-{
- if (!IsSuccess())
+ SUBCASE("invalid text")
{
- throw std::runtime_error(ErrorMessage(ErrorPrefix));
- }
-}
+ std::string_view BadText = "totobaba\xff\xfe";
-//////////////////////////////////////////////////////////////////////////
+ BodyLogFormatter _{BadText};
-#if ZEN_WITH_TESTS
+ CHECK_EQ(_.GetText(), "totobaba");
+ }
+}
TEST_CASE("httpclient")
{
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
new file mode 100644
index 000000000..8754c57d6
--- /dev/null
+++ b/src/zenhttp/httpclientauth.cpp
@@ -0,0 +1,212 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclientauth.h>
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/process.h>
+#include <zencore/scopeguard.h>
+#include <zencore/timer.h>
+#include <zencore/uid.h>
+#include <zenhttp/auth/authmgr.h>
+#include <zenhttp/httpclient.h>
+
+#include <ctime>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# define timegm _mkgmtime
+#endif // ZEN_PLATFORM_WINDOWS
+
+namespace zen { namespace httpclientauth {
+
+ using namespace std::literals;
+
+ std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token)
+ {
+ return [Token]() { return Token; };
+ }
+
+ std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token)
+ {
+ return CreateFromStaticToken(
+ HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = HttpClientAccessToken::TimePoint::max()});
+ }
+
+ std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params)
+ {
+ OAuthClientCredentialsParams OAuthParams(Params);
+ return [OAuthParams]() {
+ using namespace std::chrono;
+
+ std::string Body = fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}"sv,
+ OAuthParams.ClientId,
+ OAuthParams.ClientSecret);
+
+ HttpClient Http{OAuthParams.Url};
+
+ IoBuffer Payload{IoBuffer::Wrap, Body.data(), Body.size()};
+
+ // TODO: ensure this gets the right Content-Type passed along
+
+ HttpClient::Response Response = Http.Post("", Payload, {{"Content-Type", "application/x-www-form-urlencoded"}});
+
+ if (!Response || Response.StatusCode != HttpResponseCode::OK)
+ {
+ ZEN_WARN("Failed fetching OAuth access token {}. Reason: '{}'", OAuthParams.Url, Response.ErrorMessage(""));
+ return HttpClientAccessToken{};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ ZEN_WARN("Unable to parse OAuth json response from {}. Reason: '{}'", OAuthParams.Url, JsonError);
+ return HttpClientAccessToken{};
+ }
+
+ std::string Token = Json["access_token"].string_value();
+ int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value());
+ HttpClientAccessToken::TimePoint ExpireTime = HttpClientAccessToken::Clock::now() + seconds(ExpiresInSeconds);
+
+ return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
+ };
+ }
+
+ std::function<HttpClientAccessToken()> CreateFromOpenIdProvider(AuthMgr& AuthManager, std::string_view OpenIdProvider)
+ {
+ return [&AuthManager = AuthManager, OpenIdProvider = std::string(OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider);
+ return HttpClientAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ };
+ }
+
+ std::function<HttpClientAccessToken()> CreateFromDefaultOpenIdProvider(AuthMgr& AuthManager)
+ {
+ return CreateFromOpenIdProvider(AuthManager, "Default"sv);
+ }
+
+ static HttpClientAccessToken GetOidcTokenFromExe(const std::filesystem::path& OidcExecutablePath,
+ std::string_view CloudHost,
+ bool Unattended,
+ bool Quiet,
+ bool Hidden)
+ {
+ Stopwatch Timer;
+
+ CreateProcOptions ProcOptions;
+ if (Quiet)
+ {
+ ProcOptions.StdoutFile = std::filesystem::temp_directory_path() / fmt::format(".zen-auth-output-{}", Oid::NewOid());
+ }
+ if (Hidden)
+ {
+ ProcOptions.Flags |= CreateProcOptions::Flag_NoConsole;
+ }
+
+ const std::filesystem::path AuthTokenPath(std::filesystem::temp_directory_path() / fmt::format(".zen-auth-{}", Oid::NewOid()));
+ auto _ = MakeGuard([AuthTokenPath, &ProcOptions]() {
+ RemoveFile(AuthTokenPath);
+ if (!ProcOptions.StdoutFile.empty())
+ {
+ RemoveFile(ProcOptions.StdoutFile);
+ }
+ });
+
+ const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}",
+ OidcExecutablePath,
+ CloudHost,
+ AuthTokenPath,
+ Unattended ? "true"sv : "false"sv);
+ ZEN_DEBUG("Running: {}", ProcArgs);
+ ProcessHandle Proc;
+ Proc.Initialize(CreateProc(OidcExecutablePath, ProcArgs, ProcOptions));
+ if (!Proc.IsValid())
+ {
+ throw std::runtime_error(fmt::format("failed to launch '{}'", OidcExecutablePath));
+ }
+
+ int ExitCode = Proc.WaitExitCode();
+
+ auto EndTime = std::chrono::system_clock::now();
+
+ if (ExitCode == 0)
+ {
+ IoBuffer Body = IoBufferBuilder::MakeFromFile(AuthTokenPath);
+ std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize());
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(JsonText, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ ZEN_WARN("Unable to parse Oidcs json response from {}. Reason: '{}'", AuthTokenPath, JsonError);
+ return HttpClientAccessToken{};
+ }
+ std::string Token = Json["Token"].string_value();
+ std::string ExpiresAtUTCString = Json["ExpiresAtUtc"].string_value();
+ ZEN_ASSERT(!ExpiresAtUTCString.empty());
+
+ int Year = 0;
+ int Month = 0;
+ int Day = 0;
+ int Hour = 0;
+ int Minute = 0;
+ int Second = 0;
+ int Millisecond = 0;
+ sscanf(ExpiresAtUTCString.c_str(), "%d-%d-%dT%d:%d:%d.%dZ", &Year, &Month, &Day, &Hour, &Minute, &Second, &Millisecond);
+
+ std::tm Time = {
+ Second,
+ Minute,
+ Hour,
+ Day,
+ Month - 1,
+ Year - 1900,
+ };
+
+ time_t UTCTime = timegm(&Time);
+ HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime);
+ ExpireTime += std::chrono::microseconds(Millisecond);
+
+ return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime};
+ }
+ else
+ {
+ ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode);
+ }
+ return HttpClientAccessToken{};
+ }
+
+ std::optional<std::function<HttpClientAccessToken()>> CreateFromOidcTokenExecutable(const std::filesystem::path& OidcExecutablePath,
+ std::string_view CloudHost,
+ bool Quiet,
+ bool Unattended,
+ bool Hidden)
+ {
+ HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ if (InitialToken.IsValid())
+ {
+ return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath),
+ CloudHost = std::string(CloudHost),
+ Quiet,
+ Hidden,
+ InitialToken]() mutable {
+ if (InitialToken.IsValid())
+ {
+ HttpClientAccessToken Result = InitialToken;
+ InitialToken = {};
+ return Result;
+ }
+ return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, /* Unattended */ true, Quiet, Hidden);
+ };
+ }
+ return {};
+ }
+
+}} // namespace zen::httpclientauth
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index 3270855ad..2c063d646 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -18,19 +18,22 @@
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
#include <zencore/iobuffer.h>
#include <zencore/logging.h>
#include <zencore/stream.h>
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
-#include <zenutil/packageformat.h>
+#include <zenhttp/packageformat.h>
#include <charconv>
#include <mutex>
#include <span>
#include <string_view>
+#include <EASTL/fixed_vector.h>
+
namespace zen {
using namespace std::literals;
@@ -94,6 +97,7 @@ MapContentTypeToString(HttpContentType ContentType)
static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv);
static constinit uint32_t HashJson = HashStringDjb2("json"sv);
static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv);
+static constinit uint32_t HashApplicationProblemJson = HashStringDjb2("application/problem+json"sv);
static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv);
static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv);
static constinit uint32_t HashText = HashStringDjb2("text/plain"sv);
@@ -132,6 +136,7 @@ struct HashedTypeEntry
{HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer},
{HashJson, HttpContentType::kJSON},
{HashApplicationJson, HttpContentType::kJSON},
+ {HashApplicationProblemJson, HttpContentType::kJSON},
{HashYaml, HttpContentType::kYAML},
{HashTextYaml, HttpContentType::kYAML},
{HashText, HttpContentType::kText},
@@ -156,7 +161,14 @@ ParseContentTypeImpl(const std::string_view& ContentTypeString)
{
if (!ContentTypeString.empty())
{
- const uint32_t CtHash = HashStringDjb2(ContentTypeString);
+ size_t ContentEnd = ContentTypeString.find(';');
+ if (ContentEnd == std::string_view::npos)
+ {
+ ContentEnd = ContentTypeString.length();
+ }
+ std::string_view ContentString(ContentTypeString.substr(0, ContentEnd));
+
+ const uint32_t CtHash = HashStringDjb2(ContentString);
if (auto It = std::lower_bound(std::begin(TypeHashTable),
std::end(TypeHashTable),
@@ -468,6 +480,11 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data)
ExtendableStringBuilder<1024> Sb;
WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView());
}
+ else if (m_AcceptType == HttpContentType::kYAML)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kYAML, Data.ToYaml(Sb).ToView());
+ }
else
{
SharedBuffer Buf = Data.GetBuffer();
@@ -484,6 +501,11 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array)
ExtendableStringBuilder<1024> Sb;
WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView());
}
+ else if (m_AcceptType == HttpContentType::kYAML)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kYAML, Array.ToYaml(Sb).ToView());
+ }
else
{
SharedBuffer Buf = Array.GetBuffer();
@@ -510,7 +532,7 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType
{
std::span<const SharedBuffer> Segments = Payload.GetSegments();
- std::vector<IoBuffer> Buffers;
+ eastl::fixed_vector<IoBuffer, 64> Buffers;
Buffers.reserve(Segments.size());
for (auto& Segment : Segments)
@@ -518,7 +540,41 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType
Buffers.push_back(Segment.AsIoBuffer());
}
- WriteResponse(ResponseCode, ContentType, Buffers);
+ WriteResponse(ResponseCode, ContentType, std::span<IoBuffer>(begin(Buffers), end(Buffers)));
+}
+
+std::string
+HttpServerRequest::Decode(std::string_view PercentEncodedString)
+{
+ size_t Length = PercentEncodedString.length();
+ std::string Decoded;
+ Decoded.reserve(Length);
+ size_t Offset = 0;
+ while (Offset < Length)
+ {
+ char C = PercentEncodedString[Offset];
+ if (C == '%' && (Offset <= (Length - 3)))
+ {
+ std::string_view CharHash(&PercentEncodedString[Offset + 1], 2);
+ uint8_t DecodedChar = 0;
+ if (ParseHexBytes(CharHash, &DecodedChar))
+ {
+ Decoded.push_back((char)DecodedChar);
+ Offset += 3;
+ }
+ else
+ {
+ Decoded.push_back(C);
+ Offset++;
+ }
+ }
+ else
+ {
+ Decoded.push_back(C);
+ Offset++;
+ }
+ }
+ return Decoded;
}
HttpServerRequest::QueryParams
@@ -610,9 +666,13 @@ HttpServerRequest::ReadPayloadObject()
}
return CbObject();
}
- return LoadCompactBinaryObject(std::move(Payload));
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(Payload), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ return ResponseObject;
+ }
}
-
return {};
}
@@ -732,120 +792,131 @@ HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcA
//////////////////////////////////////////////////////////////////////////
-enum class HttpServerClass
-{
- kHttpAsio,
- kHttpSys,
- kHttpPlugin,
- kHttpMulti,
- kHttpNull
-};
-
Ref<HttpServer>
-CreateHttpServerClass(HttpServerClass Class, const HttpServerConfig& Config)
+CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig& Config)
{
- switch (Class)
+ if (ServerClass == "asio"sv)
{
- default:
- case HttpServerClass::kHttpAsio:
- ZEN_INFO("using asio HTTP server implementation");
- return CreateHttpAsioServer(Config.ForceLoopback, Config.ThreadCount);
-
- case HttpServerClass::kHttpMulti:
- {
- ZEN_INFO("using multi HTTP server implementation");
- Ref<HttpMultiServer> Server{new HttpMultiServer()};
-
- // This is hardcoded for now, but should be configurable in the future
- Server->AddServer(CreateHttpServerClass(HttpServerClass::kHttpSys, Config));
- Server->AddServer(CreateHttpServerClass(HttpServerClass::kHttpPlugin, Config));
+ ZEN_INFO("using asio HTTP server implementation")
+ return CreateHttpAsioServer(Config.ForceLoopback, Config.ThreadCount);
+ }
+#if ZEN_WITH_HTTPSYS
+ else if (ServerClass == "httpsys"sv)
+ {
+ ZEN_INFO("using http.sys server implementation")
+ return Ref<HttpServer>(CreateHttpSysServer({.ThreadCount = Config.ThreadCount,
+ .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount,
+ .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled,
+ .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled,
+ .IsDedicatedServer = Config.IsDedicatedServer,
+ .ForceLoopback = Config.ForceLoopback}));
+ }
+#endif
+ else if (ServerClass == "null"sv)
+ {
+ ZEN_INFO("using null HTTP server implementation")
+ return Ref<HttpServer>(new HttpNullServer);
+ }
+ else
+ {
+ ZEN_WARN("unknown HTTP server implementation '{}', falling back to default", ServerClass)
- return Server;
- }
+#if ZEN_WITH_HTTPSYS
+ return CreateHttpServerClass("httpsys"sv, Config);
+#else
+ return CreateHttpServerClass("asio"sv, Config);
+#endif
+ }
+}
#if ZEN_WITH_PLUGINS
- case HttpServerClass::kHttpPlugin:
- {
- ZEN_INFO("using plugin HTTP server implementation");
- Ref<HttpPluginServer> Server{CreateHttpPluginServer()};
+Ref<HttpServer>
+CreateHttpServerPlugin(const HttpServerPluginConfig& PluginConfig)
+{
+ const std::string& PluginName = PluginConfig.PluginName;
- // This is hardcoded for now, but should be configurable in the future
+ ZEN_INFO("using '{}' plugin HTTP server implementation", PluginName)
+ if (PluginName.starts_with("builtin:"sv))
+ {
# if 0
- Ref<TransportPlugin> WinsockPlugin{CreateSocketTransportPlugin()};
- WinsockPlugin->Configure("port", "8558");
- Server->AddPlugin(WinsockPlugin);
-# endif
+ Ref<TransportPlugin> Plugin = {};
+ if (PluginName == "builtin:winsock"sv)
+ {
+ Plugin = CreateSocketTransportPlugin();
+ }
+ else if (PluginName == "builtin:asio"sv)
+ {
+ Plugin = CreateAsioTransportPlugin();
+ }
+ else
+ {
+ ZEN_WARN("Unknown builtin plugin '{}'", PluginName)
+ return {};
+ }
-# if 0
- Ref<TransportPlugin> AsioPlugin{CreateAsioTransportPlugin()};
- AsioPlugin->Configure("port", "8558");
- Server->AddPlugin(AsioPlugin);
-# endif
+ ZEN_ASSERT(!Plugin.IsNull());
-# if 1
- Ref<DllTransportPlugin> DllPlugin{CreateDllTransportPlugin()};
- DllPlugin->LoadDll("winsock");
- DllPlugin->ConfigureDll("winsock", "port", "8558");
- Server->AddPlugin(DllPlugin);
-# endif
+ for (const std::pair<std::string, std::string>& Option : PluginConfig.PluginOptions)
+ {
+ Plugin->Configure(Option.first.c_str(), Option.second.c_str());
+ }
- return Server;
- }
-#endif
+ Ref<HttpPluginServer> Server{CreateHttpPluginServer()};
+ Server->AddPlugin(Plugin);
+ return Server;
+# else
+ ZEN_WARN("Builtin plugin '{}' is not supported", PluginName)
+ return {};
+# endif
+ }
-#if ZEN_WITH_HTTPSYS
- case HttpServerClass::kHttpSys:
- ZEN_INFO("using http.sys server implementation");
- return Ref<HttpServer>(CreateHttpSysServer({.ThreadCount = Config.ThreadCount,
- .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount,
- .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled,
- .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled,
- .IsDedicatedServer = Config.IsDedicatedServer,
- .ForceLoopback = Config.ForceLoopback}));
-#endif
+ Ref<DllTransportPlugin> DllPlugin{CreateDllTransportPlugin()};
+ if (!DllPlugin->LoadDll(PluginName))
+ {
+ return {};
+ }
- case HttpServerClass::kHttpNull:
- ZEN_INFO("using null HTTP server implementation");
- return Ref<HttpServer>(new HttpNullServer);
+ for (const std::pair<std::string, std::string>& Option : PluginConfig.PluginOptions)
+ {
+ DllPlugin->ConfigureDll(PluginName, Option.first.c_str(), Option.second.c_str());
}
+
+ Ref<HttpPluginServer> Server{CreateHttpPluginServer()};
+ Server->AddPlugin(DllPlugin);
+ return Server;
}
+#endif
Ref<HttpServer>
CreateHttpServer(const HttpServerConfig& Config)
{
using namespace std::literals;
- HttpServerClass Class = HttpServerClass::kHttpNull;
-
-#if ZEN_WITH_HTTPSYS
- Class = HttpServerClass::kHttpSys;
-#else
- Class = HttpServerClass::kHttpAsio;
-#endif
-
- if (Config.ServerClass == "asio"sv)
- {
- Class = HttpServerClass::kHttpAsio;
- }
- else if (Config.ServerClass == "httpsys"sv)
- {
- Class = HttpServerClass::kHttpSys;
- }
- else if (Config.ServerClass == "plugin"sv)
- {
- Class = HttpServerClass::kHttpPlugin;
- }
- else if (Config.ServerClass == "null"sv)
+#if ZEN_WITH_PLUGINS
+ if (Config.PluginConfigs.empty())
{
- Class = HttpServerClass::kHttpNull;
+ return CreateHttpServerClass(Config.ServerClass, Config);
}
- else if (Config.ServerClass == "multi"sv)
+ else
{
- Class = HttpServerClass::kHttpMulti;
- }
+ Ref<HttpMultiServer> Server{new HttpMultiServer()};
+ Server->AddServer(CreateHttpServerClass(Config.ServerClass, Config));
- return CreateHttpServerClass(Class, Config);
+ for (const HttpServerPluginConfig& PluginConfig : Config.PluginConfigs)
+ {
+ Ref<HttpServer> PluginServer = CreateHttpServerPlugin(PluginConfig);
+ if (!PluginServer.IsNull())
+ {
+ Server->AddServer(PluginServer);
+ }
+ }
+
+ return Server;
+ }
+#else
+ return CreateHttpServerClass(Config.ServerClass, Config);
+#endif
}
//////////////////////////////////////////////////////////////////////////
@@ -865,42 +936,51 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
if (PackageHandlerRef)
{
- CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload());
-
- std::vector<IoHash> OfferCids;
-
- for (auto& CidEntry : OfferMessage["offer"])
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject OfferMessage = ValidateAndReadCompactBinaryObject(IoBuffer(Request.ReadPayload()), ValidationError);
+ ValidationError == CbValidateError::None)
{
- if (!CidEntry.IsHash())
+ std::vector<IoHash> OfferCids;
+
+ for (auto& CidEntry : OfferMessage["offer"])
{
- // Should yield bad request response?
+ if (!CidEntry.IsHash())
+ {
+ // Should yield bad request response?
+
+ ZEN_WARN("found invalid entry in offer");
- ZEN_WARN("found invalid entry in offer");
+ continue;
+ }
- continue;
+ OfferCids.push_back(CidEntry.AsHash());
}
- OfferCids.push_back(CidEntry.AsHash());
- }
+ ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size());
- ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size());
+ PackageHandlerRef->FilterOffer(OfferCids);
- PackageHandlerRef->FilterOffer(OfferCids);
+ ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size());
- ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size());
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("need");
- CbObjectWriter ResponseWriter;
- ResponseWriter.BeginArray("need");
+ for (const IoHash& Cid : OfferCids)
+ {
+ ResponseWriter.AddHash(Cid);
+ }
+
+ ResponseWriter.EndArray();
- for (const IoHash& Cid : OfferCids)
+ // Emit filter response
+ Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ else
{
- ResponseWriter.AddHash(Cid);
+ Request.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Invalid request payload: '{}'", ToString(ValidationError)));
}
-
- ResponseWriter.EndArray();
-
- // Emit filter response
- Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
return true;
}
}
diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h
index f43ae3cd7..6f9c3198e 100644
--- a/src/zenhttp/include/zenhttp/auth/oidc.h
+++ b/src/zenhttp/include/zenhttp/auth/oidc.h
@@ -2,13 +2,14 @@
#pragma once
+#include <zenbase/refcount.h>
#include <zencore/string.h>
#include <vector>
namespace zen {
-class OidcClient
+class OidcClient : public RefCounted
{
public:
struct Options
diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h
new file mode 100644
index 000000000..a3b870c0f
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/cprutils.h
@@ -0,0 +1,86 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+#include <zenhttp/formatters.h>
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpcommon.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/response.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+template<>
+struct fmt::formatter<cpr::Response>
+{
+ constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); }
+
+ template<typename FormatContext>
+ auto format(const cpr::Response& Response, FormatContext& Ctx) const -> decltype(Ctx.out())
+ {
+ using namespace std::literals;
+
+ zen::NiceTimeSpanMs NiceResponseTime(uint64_t(Response.elapsed * 1000));
+
+ if (zen::IsHttpSuccessCode(Response.status_code))
+ {
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Error: '{}' ({}), Bytes: {}/{} (Up/Down), Elapsed: {}",
+ Response.url.str(),
+ Response.status_code,
+ Response.error.message,
+ int(Response.error.code),
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ NiceResponseTime.c_str());
+ }
+ else
+ {
+ const auto It = Response.header.find("Content-Type");
+ const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv;
+
+ if (ContentType == "application/x-ue-cb"sv)
+ {
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ zen::CbObjectView Obj(Body.Data());
+ zen::ExtendableStringBuilder<256> Sb;
+ std::string_view Json = Obj.ToJson(Sb).ToView();
+
+ return fmt::format_to(
+ Ctx.out(),
+ "Url: {}, Status: {}, Error: '{}' ({}). Bytes: {}/{} (Up/Down), Elapsed: {}, Response: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.error.message,
+ int(Response.error.code),
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ NiceResponseTime.c_str(),
+ Json,
+ Response.reason);
+ }
+ else
+ {
+ zen::BodyLogFormatter Body(Response.text);
+
+ return fmt::format_to(
+ Ctx.out(),
+ "Url: {}, Status: {}, Error: '{}' ({}), Bytes: {}/{} (Up/Down), Elapsed: {}, Response: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.error.message,
+ int(Response.error.code),
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ NiceResponseTime.c_str(),
+ Body.GetText(),
+ Response.reason);
+ }
+ }
+ }
+};
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index d45f5fbb2..0af31fa30 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -7,77 +7,63 @@
#include <zencore/iobuffer.h>
#include <zencore/string.h>
#include <zenhttp/httpclient.h>
+#include <zenhttp/httpcommon.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <cpr/cpr.h>
#include <fmt/format.h>
ZEN_THIRD_PARTY_INCLUDES_END
-template<>
-struct fmt::formatter<cpr::Response>
+namespace zen {
+
+struct BodyLogFormatter
{
- constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); }
+private:
+ std::string_view ResponseText;
+ zen::ExtendableStringBuilder<128> ModifiedResponse;
- template<typename FormatContext>
- auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out())
+public:
+ explicit BodyLogFormatter(std::string_view InResponseText) : ResponseText(InResponseText)
{
using namespace std::literals;
- if (Response.status_code == 200 || Response.status_code == 201)
+ const int TextSizeLimit = 1024;
+
+ // Trim invalid UTF8
+
+ auto InvalidIt = zen::FindFirstInvalidUtf8Byte(ResponseText);
+
+ if (InvalidIt != end(ResponseText))
{
- return fmt::format_to(Ctx.out(),
- "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
- Response.url.str(),
- Response.status_code,
- Response.uploaded_bytes,
- Response.downloaded_bytes,
- Response.elapsed);
+ ResponseText = ResponseText.substr(0, InvalidIt - begin(ResponseText));
}
- else
+
+ if (ResponseText.empty())
+ {
+ ResponseText = "<suppressed non-text response>"sv;
+ }
+
+ if (ResponseText.size() > TextSizeLimit)
{
- const auto It = Response.header.find("Content-Type");
- const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv;
-
- if (ContentType == "application/x-ue-cb"sv)
- {
- zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
- zen::CbObjectView Obj(Body.Data());
- zen::ExtendableStringBuilder<256> Sb;
- std::string_view Json = Obj.ToJson(Sb).ToView();
-
- return fmt::format_to(Ctx.out(),
- "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'",
- Response.url.str(),
- Response.status_code,
- Response.uploaded_bytes,
- Response.downloaded_bytes,
- Response.elapsed,
- Json,
- Response.reason);
- }
- else
- {
- return fmt::format_to(Ctx.out(),
- "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'",
- Response.url.str(),
- Response.status_code,
- Response.uploaded_bytes,
- Response.downloaded_bytes,
- Response.elapsed,
- Response.text,
- Response.reason);
- }
+ const auto TruncatedString = "[truncated response] "sv;
+ ModifiedResponse.Append(TruncatedString);
+ ModifiedResponse.Append(ResponseText.data(), TextSizeLimit - TruncatedString.size());
+
+ ResponseText = ModifiedResponse;
}
}
+
+ inline std::string_view GetText() const { return ResponseText; }
};
+} // namespace zen
+
template<>
struct fmt::formatter<zen::HttpClient::Response>
{
constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); }
template<typename FormatContext>
- auto format(const zen::HttpClient::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out())
+ auto format(const zen::HttpClient::Response& Response, FormatContext& Ctx) const -> decltype(Ctx.out())
{
using namespace std::literals;
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 9de5c7cce..c1fc1efa6 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -6,9 +6,11 @@
#include <zencore/iobuffer.h>
#include <zencore/logbase.h>
+#include <zencore/thread.h>
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <functional>
#include <optional>
#include <unordered_map>
@@ -27,20 +29,91 @@ class CompositeBuffer;
*/
+struct HttpClientAccessToken
+{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+
+ static constexpr int64_t ExpireMarginInSeconds = 60 * 5;
+
+ std::string Value;
+ TimePoint ExpireTime;
+
+ bool IsValid() const
+ {
+ return Value.empty() == false &&
+ ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count();
+ }
+};
+
struct HttpClientSettings
{
- std::string LogCategory = "httpclient";
- std::chrono::milliseconds ConnectTimeout{3000};
- std::chrono::milliseconds Timeout{};
- bool AssumeHttp2 = false;
+ std::string LogCategory = "httpclient";
+ std::chrono::milliseconds ConnectTimeout{3000};
+ std::chrono::milliseconds Timeout{};
+ std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+ bool AssumeHttp2 = false;
+ bool AllowResume = false;
+ uint8_t RetryCount = 0;
+ Oid SessionId = Oid::Zero;
};
-class HttpClient
+class HttpClientError : public std::runtime_error
{
public:
- struct Settings
+ using _Mybase = runtime_error;
+
+ HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode)
+ : _Mybase(Message)
+ , m_Error(Error)
+ , m_ResponseCode(ResponseCode)
{
+ }
+
+ HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode)
+ : _Mybase(Message)
+ , m_Error(Error)
+ , m_ResponseCode(ResponseCode)
+ {
+ }
+
+ inline int GetInternalErrorCode() const { return m_Error; }
+ inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
+
+ enum class ResponseClass : std::int8_t
+ {
+ kSuccess = 0,
+
+ kHttpOtherClientError = 80,
+ kHttpCantConnectError = 81, // CONNECTION_FAILURE
+ kHttpNotFound = 66, // NotFound(404)
+ kHttpUnauthorized = 77, // Unauthorized(401),
+ kHttpSLLError =
+ 82, // SSL_CONNECT_ERROR, SSL_LOCAL_CERTIFICATE_ERROR, SSL_REMOTE_CERTIFICATE_ERROR, SSL_CACERT_ERROR, GENERIC_SSL_ERROR
+ kHttpForbidden = 83, // Forbidden(403)
+ kHttpTimeout = 84, // NETWORK_RECEIVE_ERROR, NETWORK_SEND_FAILURE, OPERATION_TIMEDOUT, RequestTimeout(408)
+ kHttpConflict = 85, // Conflict(409)
+ kHttpNoHost = 86, // HOST_RESOLUTION_FAILURE, PROXY_RESOLUTION_FAILURE
+
+ kHttpOtherServerError = 90,
+ kHttpInternalServerError = 91, // InternalServerError(500)
+ kHttpServiceUnavailable = 69, // ServiceUnavailable(503)
+ kHttpBadGateway = 92, // BadGateway(502)
+ kHttpGatewayTimeout = 93, // GatewayTimeout(504)
};
+
+ ResponseClass GetResponseClass() const;
+
+private:
+ const int m_Error = 0;
+ const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
+};
+
+class HttpClientBase;
+
+class HttpClient
+{
+public:
HttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {});
~HttpClient();
@@ -78,7 +151,7 @@ public:
HttpResponseCode StatusCode = HttpResponseCode::ImATeapot;
IoBuffer ResponsePayload; // Note: this also includes the content type
- // Contains the reponse headers
+ // Contains the response headers
KeyValueMap Header;
// The number of bytes sent as part of the request
@@ -121,41 +194,54 @@ public:
std::string ErrorMessage(std::string_view Prefix) const;
};
+ static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType)
+ {
+ return std::make_pair("Accept", MapContentTypeToString(ContentType));
+ }
+
[[nodiscard]] Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {});
+ [[nodiscard]] Response Put(std::string_view Url, const KeyValueMap& Parameters = {});
[[nodiscard]] Response Get(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, const KeyValueMap& Parameters = {});
[[nodiscard]] Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Post(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, const KeyValueMap& Parameters = {});
[[nodiscard]] Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {});
+ [[nodiscard]] Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {});
+ [[nodiscard]] Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response Upload(std::string_view Url,
const CompositeBuffer& Payload,
ZenContentType ContentType,
const KeyValueMap& AdditionalHeader = {});
+
[[nodiscard]] Response Download(std::string_view Url,
const std::filesystem::path& TempFolderPath,
const KeyValueMap& AdditionalHeader = {});
[[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {});
- static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType)
- {
- return std::make_pair("Accept", MapContentTypeToString(ContentType));
- }
-
- LoggerRef Logger() { return m_Log; }
+ LoggerRef Log() { return m_Log; }
std::string_view GetBaseUri() const { return m_BaseUri; }
+ std::string_view GetSessionId() const { return m_SessionId; }
+ void SetSessionId(const Oid& SessionId);
+
+ bool Authenticate();
private:
- struct Impl;
+ HttpClientBase* m_Inner;
LoggerRef m_Log;
std::string m_BaseUri;
std::string m_SessionId;
const HttpClientSettings m_ConnectionSettings;
- Ref<Impl> m_Impl;
};
void httpclient_forcelink(); // internal
diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h
new file mode 100644
index 000000000..26f31ed2a
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpclientauth.h
@@ -0,0 +1,36 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpclient.h>
+#include <optional>
+
+namespace zen {
+
+class AuthMgr;
+
+namespace httpclientauth {
+ std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token);
+
+ std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token);
+
+ struct OAuthClientCredentialsParams
+ {
+ std::string_view Url;
+ std::string_view ClientId;
+ std::string_view ClientSecret;
+ };
+
+ std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params);
+
+ std::function<HttpClientAccessToken()> CreateFromOpenIdProvider(AuthMgr& AuthManager, std::string_view OpenIdProvider);
+ std::function<HttpClientAccessToken()> CreateFromDefaultOpenIdProvider(AuthMgr& AuthManager);
+
+ std::optional<std::function<HttpClientAccessToken()>> CreateFromOidcTokenExecutable(const std::filesystem::path& OidcExecutablePath,
+ std::string_view CloudHost,
+ bool Quiet,
+ bool Unattended,
+ bool Hidden);
+} // namespace httpclientauth
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 1089dd221..03e547bf3 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -62,6 +62,8 @@ public:
}
};
+ static std::string Decode(std::string_view PercentEncodedString);
+
virtual bool TryGetRanges(HttpRanges&) { return false; }
QueryParams GetQueryParams();
@@ -182,12 +184,19 @@ public:
virtual void Close() = 0;
};
+struct HttpServerPluginConfig
+{
+ std::string PluginName;
+ std::vector<std::pair<std::string, std::string>> PluginOptions;
+};
+
struct HttpServerConfig
{
- bool IsDedicatedServer = false; // Should be set to true for shared servers
- std::string ServerClass; // Choice of HTTP server implementation
- bool ForceLoopback = false;
- unsigned int ThreadCount = 0;
+ bool IsDedicatedServer = false; // Should be set to true for shared servers
+ std::string ServerClass; // Choice of HTTP server implementation
+ std::vector<HttpServerPluginConfig> PluginConfigs;
+ bool ForceLoopback = false;
+ unsigned int ThreadCount = 0;
struct
{
@@ -206,7 +215,7 @@ class HttpRouterRequest
public:
HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
- ZENCORE_API std::string GetCapture(uint32_t Index) const;
+ std::string_view GetCapture(uint32_t Index) const;
inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
private:
@@ -218,12 +227,14 @@ private:
friend class HttpRequestRouter;
};
-inline std::string
+inline std::string_view
HttpRouterRequest::GetCapture(uint32_t Index) const
{
ZEN_ASSERT(Index < m_Match.size());
- return m_Match[Index];
+ const auto& Match = m_Match[Index];
+
+ return std::string_view(&*Match.first, Match.second - Match.first);
}
/** HTTP request router helper
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
new file mode 100644
index 000000000..c90b840da
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -0,0 +1,164 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+
+#include <functional>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+class IoBuffer;
+class CbPackage;
+class CompositeBuffer;
+
+/** _____ _ _____ _
+ / ____| | | __ \ | |
+ | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___
+ | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \
+ | |____| |_) | | | (_| | (__| < (_| | (_| | __/
+ \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___|
+ __/ |
+ |___/
+
+ Structures and code related to handling CbPackage transactions
+
+ CbPackage instances are marshaled across the wire using a distinct message
+ format. We don't use the CbPackage serialization format provided by the
+ CbPackage implementation itself since that does not provide much flexibility
+ in how the attachment payloads are transmitted. The scheme below separates
+ metadata cleanly from payloads and this enables us to more efficiently
+ transmit them either via sendfile/TransmitFile like mechanisms, or by
+ reference/memory mapping in the local case.
+ */
+
+struct CbPackageHeader
+{
+ uint32_t HeaderMagic;
+ uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document?
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+};
+
+static_assert(sizeof(CbPackageHeader) == 16);
+
+enum : uint32_t
+{
+ kCbPkgMagic = 0xaa77aacc
+};
+
+struct CbAttachmentEntry
+{
+ uint64_t PayloadSize; // Size of the associated payload data in the message
+ uint32_t Flags; // See flags below
+ IoHash AttachmentHash; // Content Id for the attachment
+
+ enum
+ {
+ kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format
+ kIsObject = (1u << 1), // Is compact binary object
+ kIsError = (1u << 2), // Is error (compact binary formatted) object
+ kIsLocalRef = (1u << 3), // Is "local reference"
+ };
+};
+
+struct CbAttachmentReferenceHeader
+{
+ uint64_t PayloadByteOffset = 0;
+ uint64_t PayloadByteSize = ~0u;
+ uint16_t AbsolutePathLength = 0;
+
+ // This header will be followed by UTF8 encoded absolute path to backing file
+};
+
+static_assert(sizeof(CbAttachmentEntry) == 32);
+
+enum class FormatFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0),
+ kDenyPartialLocalReferences = (1u << 1)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags);
+
+enum class RpcAcceptOptions : uint16_t
+{
+ kNone = 0,
+ kAllowLocalReferences = (1u << 0),
+ kAllowPartialLocalReferences = (1u << 1),
+ kAllowPartialCacheChunks = (1u << 2)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+ return IoBuffer{Size};
+ });
+bool IsPackageMessage(IoBuffer Payload);
+
+bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, void* TargetProcessHandle = nullptr);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* TargetProcessHandle = nullptr);
+
+/** Streaming reader for compact binary packages
+
+ The goal is to ultimately support zero-copy I/O, but for now there'll be some
+ copying involved on some platforms at least.
+
+ This approach to deserializing CbPackage data is more efficient than
+ `ParsePackageMessage` since it does not require the entire message to
+ be resident in a memory buffer
+
+ */
+class CbPackageReader
+{
+public:
+ CbPackageReader();
+ ~CbPackageReader();
+
+ void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+
+ /** Process compact binary package data stream
+
+ The data stream must be in the serialization format produced by FormatPackageMessage
+
+ \return How many bytes must be fed to this function in the next call
+ */
+ uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes);
+
+ void Finalize();
+ const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; }
+ CbObject GetRootObject() { return m_RootObject; }
+ std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; }
+
+private:
+ enum class State
+ {
+ kInitialState,
+ kReadingHeader,
+ kReadingAttachmentEntries,
+ kReadingBuffers
+ } m_CurrentState = State::kInitialState;
+
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
+ std::vector<IoBuffer> m_PayloadBuffers;
+ std::vector<CbAttachmentEntry> m_AttachmentEntries;
+ std::vector<CbAttachment> m_Attachments;
+ CbObject m_RootObject;
+ CbPackageHeader m_PackageHeader;
+
+ IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer);
+};
+
+void forcelink_packageformat();
+
+} // namespace zen
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
new file mode 100644
index 000000000..708238224
--- /dev/null
+++ b/src/zenhttp/packageformat.cpp
@@ -0,0 +1,936 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/packageformat.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+#include <zencore/trace.h>
+
+#include <span>
+#include <vector>
+
+#include <EASTL/fixed_vector.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+const std::string_view HandlePrefix(":?#:");
+
+typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t;
+
+IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle);
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, void* TargetProcessHandle)
+{
+ return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessHandle);
+}
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, void* TargetProcessHandle)
+{
+ return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessHandle);
+}
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle)
+{
+ auto Vec = FormatPackageMessageInternal(Data, Flags, TargetProcessHandle);
+ return std::vector<IoBuffer>(begin(Vec), end(Vec));
+}
+
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle)
+{
+ auto Vec = FormatPackageMessageInternal(Data, Flags, TargetProcessHandle);
+ return CompositeBuffer(std::span{begin(Vec), end(Vec)});
+}
+
+static void
+MarshalLocal(CbAttachmentEntry*& AttachmentInfo,
+ const std::string& Path8,
+ CbAttachmentReferenceHeader& LocalRef,
+ const IoHash& AttachmentHash,
+ bool IsCompressed,
+ IoBufferVec_t& ResponseBuffers)
+{
+ IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size());
+
+ CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>();
+ *RefHdr++ = LocalRef;
+ memcpy(RefHdr, Path8.data(), Path8.size());
+
+ *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(),
+ .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef,
+ .AttachmentHash = AttachmentHash};
+
+ ResponseBuffers.emplace_back(std::move(RefBuffer));
+};
+
+static bool
+IsLocalRef(tsl::robin_map<void*, std::string>& FileNameMap,
+ std::vector<void*>& DuplicatedHandles,
+ const CompositeBuffer& AttachmentBinary,
+ bool DenyPartialLocalReferences,
+ void* TargetProcessHandle,
+ CbAttachmentReferenceHeader& LocalRef,
+ std::string& Path8)
+{
+ const SharedBuffer& Segment = AttachmentBinary.GetSegments().front();
+ IoBufferFileReference Ref;
+ const IoBuffer& SegmentBuffer = Segment.AsIoBuffer();
+
+ if (!SegmentBuffer.GetFileReference(Ref))
+ {
+ return false;
+ }
+
+ if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile())
+ {
+ return false;
+ }
+
+ if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end())
+ {
+ Path8 = It->second;
+ }
+ else
+ {
+ bool UseFilePath = true;
+#if ZEN_PLATFORM_WINDOWS
+ if (TargetProcessHandle != nullptr)
+ {
+ HANDLE TargetHandle = INVALID_HANDLE_VALUE;
+ BOOL OK = ::DuplicateHandle(GetCurrentProcess(),
+ Ref.FileHandle,
+ (HANDLE)TargetProcessHandle,
+ &TargetHandle,
+ FILE_GENERIC_READ,
+ FALSE,
+ 0);
+ if (OK)
+ {
+ DuplicatedHandles.push_back((void*)TargetHandle);
+ Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle));
+ UseFilePath = false;
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ ZEN_UNUSED(TargetProcessHandle);
+ ZEN_UNUSED(DuplicatedHandles);
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to
+ // deal with access rights etc.
+#endif // ZEN_PLATFORM_WINDOWS
+ if (UseFilePath)
+ {
+ ExtendablePathBuilder<256> LocalRefFile;
+ std::error_code Ec;
+ std::filesystem::path FilePath = PathFromHandle(Ref.FileHandle, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to get path for file handle {} in IsLocalRef check, reason '{}'", Ref.FileHandle, Ec.message());
+ return false;
+ }
+ LocalRefFile.Append(std::filesystem::absolute(FilePath));
+ Path8 = LocalRefFile.ToUtf8();
+ }
+ FileNameMap.insert_or_assign(Ref.FileHandle, Path8);
+ }
+
+ LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size());
+ LocalRef.PayloadByteOffset = Ref.FileChunkOffset;
+ LocalRef.PayloadByteSize = Ref.FileChunkSize;
+
+ return true;
+};
+
+IoBufferVec_t
+FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle)
+{
+ ZEN_TRACE_CPU("FormatPackageMessage");
+
+ std::vector<void*> DuplicatedHandles;
+#if ZEN_PLATFORM_WINDOWS
+ auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() {
+ if (TargetProcessHandle == nullptr)
+ {
+ return;
+ }
+
+ for (void* DuplicatedHandle : DuplicatedHandles)
+ {
+ HANDLE ClosingHandle;
+ if (::DuplicateHandle((HANDLE)TargetProcessHandle,
+ (HANDLE)DuplicatedHandle,
+ GetCurrentProcess(),
+ &ClosingHandle,
+ 0,
+ FALSE,
+ DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE)
+ {
+ ::CloseHandle(ClosingHandle);
+ }
+ }
+ });
+#endif // ZEN_PLATFORM_WINDOWS
+
+ const std::span<const CbAttachment>& Attachments = Data.GetAttachments();
+ IoBufferVec_t ResponseBuffers;
+
+ ResponseBuffers.reserve(2 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each
+ // attachment is likely to consist of several buffers
+
+ IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)};
+ MutableMemoryView HeaderView = AttachmentMetadataBuffer.GetMutableView();
+ // Fixed size header
+
+ CbPackageHeader* Hdr = (CbPackageHeader*)HeaderView.GetData();
+ *Hdr = {.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())};
+ HeaderView.MidInline(sizeof(CbPackageHeader));
+
+ // Attachment metadata array
+ CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(HeaderView.GetData());
+ ResponseBuffers.emplace_back(std::move(AttachmentMetadataBuffer)); // Attachment metadata
+
+ // Root object
+
+ IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer();
+ ZEN_ASSERT(RootIoBuffer.GetSize() > 0);
+ *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()};
+ ResponseBuffers.emplace_back(std::move(RootIoBuffer)); // Root object
+
+ // Attachment payloads
+ tsl::robin_map<void*, std::string> FileNameMap;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ if (Attachment.IsNull())
+ {
+ ZEN_NOT_IMPLEMENTED("Null attachments are not supported");
+ }
+ else if (const CompressedBuffer& AttachmentBuffer = Attachment.AsCompressedBinary())
+ {
+ const CompositeBuffer& Compressed = AttachmentBuffer.GetCompressed();
+ IoHash AttachmentHash = Attachment.GetHash();
+
+ // If the data is either not backed by a file, or there are multiple
+ // fragments then we cannot marshal it by local reference. We might
+ // want/need to extend this in the future to allow multiple chunk
+ // segments to be marshaled at once
+
+ bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(FileNameMap,
+ DuplicatedHandles,
+ Compressed,
+ DenyPartialLocalReferences,
+ TargetProcessHandle,
+ LocalRef,
+ Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = true;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(AttachmentInfo, Path8, LocalRef, AttachmentHash, IsCompressed, ResponseBuffers);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(),
+ .Flags = CbAttachmentEntry::kIsCompressed,
+ .AttachmentHash = AttachmentHash};
+
+ std::span<const SharedBuffer> Segments = Compressed.GetSegments();
+ ResponseBuffers.reserve(ResponseBuffers.size() + Segments.size() - 1);
+ for (const SharedBuffer& Segment : Segments)
+ {
+ ZEN_ASSERT(Segment.GetSize() > 0);
+ ResponseBuffers.emplace_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else if (CbObject AttachmentObject = Attachment.AsObject())
+ {
+ IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer();
+ ZEN_ASSERT(ObjIoBuffer.GetSize() > 0);
+ *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(),
+ .Flags = CbAttachmentEntry::kIsObject,
+ .AttachmentHash = Attachment.GetHash()};
+ ResponseBuffers.emplace_back(std::move(ObjIoBuffer));
+ }
+ else if (const CompositeBuffer& AttachmentBinary = Attachment.AsCompositeBinary())
+ {
+ IoHash AttachmentHash = Attachment.GetHash();
+ bool MarshalByLocalRef =
+ EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(FileNameMap,
+ DuplicatedHandles,
+ AttachmentBinary,
+ DenyPartialLocalReferences,
+ TargetProcessHandle,
+ LocalRef,
+ Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = false;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(AttachmentInfo, Path8, LocalRef, AttachmentHash, IsCompressed, ResponseBuffers);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()};
+
+ std::span<const SharedBuffer> Segments = AttachmentBinary.GetSegments();
+ ResponseBuffers.reserve(ResponseBuffers.size() + Segments.size() - 1);
+ for (const SharedBuffer& Segment : Segments)
+ {
+ ZEN_ASSERT(Segment.GetSize() > 0);
+ ResponseBuffers.emplace_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Unknown attachment kind");
+ }
+ }
+ FileNameMap.clear();
+#if ZEN_PLATFORM_WINDOWS
+ DuplicatedHandles.clear();
+#endif // ZEN_PLATFORM_WINDOWS
+
+ return ResponseBuffers;
+}
+
+bool
+IsPackageMessage(IoBuffer Payload)
+{
+ if (Payload.GetSize() < sizeof(CbPackageHeader))
+ {
+ return false;
+ }
+
+ BinaryReader Reader(Payload);
+ const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
+ if (Hdr->HeaderMagic != kCbPkgMagic)
+ {
+ return false;
+ }
+
+ return true;
+}
+
+CbPackage
+ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+{
+ ZEN_TRACE_CPU("ParsePackageMessage");
+
+ if (Payload.GetSize() < sizeof(CbPackageHeader))
+ {
+ throw std::invalid_argument(fmt::format("invalid CbPackage, missing complete header (size {})", Payload.GetSize()));
+ }
+
+ BinaryReader Reader(Payload);
+
+ const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
+ if (Hdr->HeaderMagic != kCbPkgMagic)
+ {
+ throw std::invalid_argument(
+ fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic));
+ }
+ Reader.Skip(sizeof(CbPackageHeader));
+
+ const uint32_t ChunkCount = Hdr->AttachmentCount + 1;
+
+ if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount)
+ {
+ throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)",
+ sizeof(CbAttachmentEntry) * ChunkCount,
+ Reader.Remaining()));
+ }
+ const CbAttachmentEntry* AttachmentEntries =
+ reinterpret_cast<const CbAttachmentEntry*>(Reader.GetView(sizeof(CbAttachmentEntry) * ChunkCount).GetData());
+ Reader.Skip(sizeof(CbAttachmentEntry) * ChunkCount);
+
+ CbPackage Package;
+
+ std::vector<CbAttachment> Attachments;
+ Attachments.reserve(ChunkCount); // Guessing here...
+
+ tsl::robin_map<std::string, IoBuffer> PartialFileBuffers;
+
+ std::vector<std::pair<uint32_t, std::string>> MalformedAttachments;
+
+ for (uint32_t i = 0; i < ChunkCount; ++i)
+ {
+ const CbAttachmentEntry& Entry = AttachmentEntries[i];
+ const uint64_t AttachmentSize = Entry.PayloadSize;
+
+ if (Reader.Remaining() < AttachmentSize)
+ {
+ throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment data (need {} bytes, have {} bytes)",
+ AttachmentSize,
+ Reader.Remaining()));
+ }
+ const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize);
+ Reader.Skip(AttachmentSize);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+
+ IoBuffer FullFileBuffer;
+
+ std::filesystem::path Path(Utf8ToWide(PathView));
+ if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end())
+ {
+ FullFileBuffer = It->second;
+ }
+ else
+ {
+ if (PathView.starts_with(HandlePrefix))
+ {
+#if ZEN_PLATFORM_WINDOWS
+ std::string_view HandleString(PathView.substr(HandlePrefix.length()));
+ std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString));
+ if (HandleNumber.has_value())
+ {
+ HANDLE FileHandle = HANDLE(HandleNumber.value());
+ ULARGE_INTEGER liFileSize;
+ liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart);
+ if (liFileSize.LowPart != INVALID_FILE_SIZE)
+ {
+ FullFileBuffer =
+ IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart), /*IsWholeFile*/ true);
+ PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer);
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes
+ // and to deal with acceess rights etc.
+ ZEN_ASSERT(false);
+#endif // ZEN_PLATFORM_WINDOWS
+ }
+ else
+ {
+ FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
+ }
+ }
+
+ if (FullFileBuffer)
+ {
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ ? FullFileBuffer
+ : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference)));
+ if (CompBuf)
+ {
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ else
+ {
+ MalformedAttachments.push_back(std::make_pair(i,
+ fmt::format("Invalid format in '{}' (offset {}, size {}) for {}",
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ Entry.AttachmentHash)));
+ }
+ }
+ else
+ {
+ MalformedAttachments.push_back(std::make_pair(i,
+ fmt::format("Unable to resolve chunk at '{}' (offset {}, size {}) for {}",
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ Entry.AttachmentHash)));
+ }
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ CbObject AttachmentObject;
+
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (!CompBuf)
+ {
+ // First payload is always a compact binary object
+ MalformedAttachments.push_back(
+ std::make_pair(i,
+ fmt::format("Invalid format, expected compressed buffer for CbObject (size {}) for {}",
+ AttachmentBuffer.GetSize(),
+ Entry.AttachmentHash)));
+ }
+ else
+ {
+ CbValidateError ValidationError = CbValidateError::None;
+ AttachmentObject = ValidateAndReadCompactBinaryObject(std::move(CompBuf), ValidationError);
+ if (ValidationError != CbValidateError::None)
+ {
+ MalformedAttachments.push_back(std::make_pair(
+ i,
+ fmt::format("Invalid format, CbObject for {}. Reason '{}'", Entry.AttachmentHash, ToString(ValidationError))));
+ }
+ }
+
+ if (i == 0)
+ {
+ // First payload is always a compact binary object
+ Package.SetObject(AttachmentObject);
+ }
+ else
+ {
+ Attachments.emplace_back(CbAttachment(AttachmentObject, Entry.AttachmentHash));
+ }
+ }
+ else
+ {
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (CompBuf)
+ {
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ else
+ {
+ MalformedAttachments.push_back(
+ std::make_pair(i,
+ fmt::format("Invalid format, expected compressed buffer for attachment (size {}) for {}",
+ AttachmentBuffer.GetSize(),
+ Entry.AttachmentHash)));
+ }
+ }
+ }
+ else /* not compressed */
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ CbValidateError ValidationError = CbValidateError::None;
+ CbObject AttachmentObject = ValidateAndReadCompactBinaryObject(std::move(AttachmentBuffer), ValidationError);
+ if (ValidationError != CbValidateError::None)
+ {
+ MalformedAttachments.push_back(std::make_pair(
+ i,
+ fmt::format("Invalid format, CbObject for {}. Reason '{}'", Entry.AttachmentHash, ToString(ValidationError))));
+ }
+
+ if (i == 0)
+ {
+ Package.SetObject(AttachmentObject);
+ }
+ else
+ {
+ Attachments.emplace_back(CbAttachment(AttachmentObject, Entry.AttachmentHash));
+ }
+ }
+ else if (AttachmentSize > 0)
+ {
+ // Make a copy of the buffer so the attachments don't reference the entire payload
+ IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
+ ZEN_ASSERT(AttachmentBufferCopy);
+ ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
+ AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
+
+ Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ }
+ else
+ {
+ MalformedAttachments.push_back(
+ std::make_pair(i, fmt::format("Invalid format, attachment of size zero detected for {}", Entry.AttachmentHash)));
+ }
+ }
+ }
+ PartialFileBuffers.clear();
+
+ Package.AddAttachments(Attachments);
+
+ using namespace std::literals;
+
+ if (!MalformedAttachments.empty())
+ {
+ ExtendableStringBuilder<1024> SB;
+ SB << (uint64_t)MalformedAttachments.size() << " malformed attachments in package message:\n";
+ for (const auto& It : MalformedAttachments)
+ {
+ SB << " #"sv << It.first << ": " << It.second << "\n";
+ }
+ ZEN_WARN("{}", SB.ToView());
+ throw std::invalid_argument(SB.ToString());
+ }
+
+ return Package;
+}
+
+bool
+ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage)
+{
+ if (IsPackageMessage(Response))
+ {
+ OutPackage = ParsePackageMessage(Response);
+ return true;
+ }
+ return OutPackage.TryLoad(Response);
+}
+
+CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+{
+}
+
+CbPackageReader::~CbPackageReader()
+{
+}
+
+void
+CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer)
+{
+ m_CreateBuffer = CreateBuffer;
+}
+
+uint64_t
+CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
+{
+ ZEN_ASSERT(m_CurrentState != State::kReadingBuffers);
+
+ switch (m_CurrentState)
+ {
+ case State::kInitialState:
+ ZEN_ASSERT(Data == nullptr);
+ m_CurrentState = State::kReadingHeader;
+ return sizeof m_PackageHeader;
+
+ case State::kReadingHeader:
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
+ return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+
+ case State::kReadingAttachmentEntries:
+ ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
+ memcpy(m_AttachmentEntries.data(), Data, DataBytes);
+
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ // This preallocates memory for payloads but note that for the local references
+ // the caller will need to handle the payload differently (i.e it's a
+ // CbAttachmentReferenceHeader not the actual payload)
+
+ m_PayloadBuffers.emplace_back(IoBuffer{Entry.PayloadSize});
+ }
+
+ m_CurrentState = State::kReadingBuffers;
+ return 0;
+
+ default:
+ ZEN_ASSERT(false);
+ return 0;
+ }
+}
+
+IoBuffer
+CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
+{
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+
+ std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+
+ std::filesystem::path Path{PathView};
+
+ IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ if (!ChunkReference)
+ {
+ // Unable to open chunk reference
+
+ throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+
+ return ChunkReference;
+};
+
+void
+CbPackageReader::Finalize()
+{
+ if (m_AttachmentEntries.empty())
+ {
+ return;
+ }
+
+ m_Attachments.reserve(m_AttachmentEntries.size() - 1);
+
+ int CurrentAttachmentIndex = 0;
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+
+ if (CurrentAttachmentIndex == 0)
+ {
+ // Root object
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ CbValidateError ValidateError = CbValidateError::None;
+ m_RootObject = ValidateAndReadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer), ValidateError);
+ if (ValidateError != CbValidateError::None)
+ {
+ throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError)));
+ }
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ CbValidateError ValidateError = CbValidateError::None;
+ m_RootObject = ValidateAndReadCompactBinaryObject(std::move(Compressed), ValidateError);
+ if (ValidateError != CbValidateError::None)
+ {
+ throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError)));
+ }
+ }
+ }
+ else
+ {
+ CbValidateError ValidateError = CbValidateError::None;
+ m_RootObject = ValidateAndReadCompactBinaryObject(std::move(AttachmentBuffer), ValidateError);
+ if (ValidateError != CbValidateError::None)
+ {
+ throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError)));
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error("missing or invalid root object");
+ }
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ m_Attachments.emplace_back(CbAttachment(Compressed, Entry.AttachmentHash));
+ }
+ }
+ else
+ {
+ CompressedBuffer Compressed =
+ CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ m_Attachments.emplace_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash()));
+ }
+ }
+
+ ++CurrentAttachmentIndex;
+ }
+}
+
+/**
+ ______________________ _____________________________
+ \__ ___/\_ _____// _____/\__ ___/ _____/
+ | | | __)_ \_____ \ | | \_____ \
+ | | | \/ \ | | / \
+ |____| /_______ /_______ / |____| /_______ /
+ \/ \/ \/
+ */
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("CbPackage.Serialization")
+{
+ // Make a test package
+
+ CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))};
+ CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.EmptyObject")
+{
+ CbPackage Pkg;
+ Pkg.SetObject({});
+ std::vector<IoBuffer> Result = FormatPackageMessage(Pkg, nullptr);
+}
+
+TEST_CASE("CbPackage.LocalRef")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ {
+ IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd"));
+ IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh"));
+
+ WriteFile(Path1, Buffer1);
+ WriteFile(Path2, Buffer2);
+ }
+
+ // Make a test package
+
+ IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1);
+ IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2);
+
+ CbAttachment Attach1{SharedBuffer(FileBuffer1)};
+ CbAttachment Attach2{SharedBuffer(FileBuffer2)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+void
+forcelink_packageformat()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 9fca314b3..2023b6d98 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1,9 +1,11 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpasio.h"
+#include "httptracer.h"
#include <zencore/except.h>
#include <zencore/logging.h>
+#include <zencore/memory/llm.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zenhttp/httpserver.h>
@@ -31,6 +33,18 @@ ZEN_THIRD_PARTY_INCLUDES_END
# define ZEN_TRACE_VERBOSE(fmtstr, ...)
#endif
+namespace zen {
+
+const FLLMTag&
+GetHttpasioTag()
+{
+ static FLLMTag _("httpasio");
+
+ return _;
+}
+
+} // namespace zen
+
namespace zen::asio_http {
using namespace std::literals;
@@ -62,6 +76,7 @@ public:
HttpAsioServerImpl();
~HttpAsioServerImpl();
+ void Initialize(std::filesystem::path DataDir);
int Start(uint16_t Port, bool ForceLooopback, int ThreadCount);
void Stop();
void RegisterService(const char* UrlPath, HttpService& Service);
@@ -72,6 +87,9 @@ public:
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
std::vector<std::thread> m_ThreadPool;
+ LoggerRef m_RequestLog;
+ HttpServerTracer m_RequestTracer;
+
struct ServiceEntry
{
std::string ServiceUrlPath;
@@ -120,6 +138,8 @@ public:
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_TRACE_CPU("asio::InitializeForPayload");
m_ResponseCode = ResponseCode;
@@ -168,8 +188,8 @@ public:
}
m_ContentLength = LocalDataSize;
- auto Headers = GetHeaders();
- m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
+ std::string_view Headers = GetHeaders();
+ m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
}
uint16_t ResponseCode() const { return m_ResponseCode; }
@@ -179,6 +199,8 @@ public:
std::string_view GetHeaders()
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
<< "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
@@ -293,7 +315,9 @@ HttpServerConnection::TerminateConnection()
void
HttpServerConnection::EnqueueRead()
{
- if (m_RequestState == RequestState::kInitialRead)
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
+ if ((m_RequestState == RequestState::kInitialRead) || (m_RequestState == RequestState::kReadingMore))
{
m_RequestState = RequestState::kReadingMore;
}
@@ -313,17 +337,21 @@ HttpServerConnection::EnqueueRead()
void
HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
if (Ec)
{
- if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead)
+ switch (m_RequestState)
{
- ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message());
- return;
- }
- else
- {
- ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message());
- return TerminateConnection();
+ case RequestState::kDone:
+ case RequestState::kInitialRead:
+ case RequestState::kTerminated:
+ ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message());
+ return;
+
+ default:
+ ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message());
+ return TerminateConnection();
}
}
@@ -362,6 +390,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
void
HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
if (Ec)
{
ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message());
@@ -395,6 +425,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unu
void
HttpServerConnection::CloseConnection()
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated)
{
return;
@@ -418,6 +450,8 @@ HttpServerConnection::CloseConnection()
void
HttpServerConnection::HandleRequest()
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
if (!m_RequestData.IsKeepAlive())
{
m_RequestState = RequestState::kWritingFinal;
@@ -439,9 +473,29 @@ HttpServerConnection::HandleRequest()
{
ZEN_TRACE_CPU("asio::HandleRequest");
+ const uint32_t RequestNumber = m_RequestCounter.load(std::memory_order_relaxed);
+
HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body());
- ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed));
+ ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
+
+ const HttpVerb RequestVerb = Request.RequestVerb();
+ const std::string_view Uri = Request.RelativeUri();
+
+ if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace))
+ {
+ ZEN_LOG_TRACE(m_Server.m_RequestLog,
+ "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
+ m_ConnectionId,
+ ToString(RequestVerb),
+ Uri,
+ Request.ContentLength(),
+ ToString(Request.RequestContentType()),
+ ToString(Request.AcceptContentType()));
+
+ m_Server.m_RequestTracer.WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber),
+ std::vector<IoBuffer>{Request.ReadPayload()});
+ }
if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
@@ -449,7 +503,15 @@ HttpServerConnection::HandleRequest()
{
Service->HandleRequest(Request);
}
- catch (std::system_error& SystemError)
+ catch (const AssertException& AssertEx)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
+ }
+ catch (const std::system_error& SystemError)
{
// Drop any partially formatted response
Request.m_Response.reset();
@@ -460,23 +522,25 @@ HttpServerConnection::HandleRequest()
}
else
{
- ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what());
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
}
}
- catch (std::bad_alloc& BadAlloc)
+ catch (const std::bad_alloc& BadAlloc)
{
// Drop any partially formatted response
Request.m_Response.reset();
Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
// Drop any partially formatted response
Request.m_Response.reset();
- ZEN_ERROR("Caught exception while handling request: {}", ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
}
}
@@ -490,11 +554,11 @@ HttpServerConnection::HandleRequest()
Response->SuppressPayload();
}
- auto ResponseBuffers = Response->AsioBuffers();
+ const std::vector<asio::const_buffer>& ResponseBuffers = Response->AsioBuffers();
uint64_t ResponseLength = 0;
- for (auto& Buffer : ResponseBuffers)
+ for (const asio::const_buffer& Buffer : ResponseBuffers)
{
ResponseLength += Buffer.size();
}
@@ -573,29 +637,49 @@ struct HttpAcceptor
: m_Server(Server)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
+ , m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4())
{
m_Acceptor.set_option(asio::ip::v6_only(false));
#if ZEN_PLATFORM_WINDOWS
// Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
- typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address;
- m_Acceptor.set_option(excluse_address(true));
+ typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
+ m_Acceptor.set_option(exclusive_address(true));
+ m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
#else // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::socket_base::reuse_address(false));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false));
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
+ m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
asio::ip::address_v6 BindAddress = ForceLoopback ? asio::ip::address_v6::loopback() : asio::ip::address_v6::any();
uint16_t EffectivePort = BasePort;
+ if (BindAddress.is_loopback())
+ {
+ m_Acceptor.set_option(asio::ip::v6_only(true));
+ }
+
asio::error_code BindErrorCode;
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback())
{
// Access denied for a public port - lets try fall back to local port only
BindAddress = asio::ip::address_v6::loopback();
+ m_Acceptor.set_option(asio::ip::v6_only(true));
+ m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode == asio::error::address_in_use)
+ {
+ // Do a retry after a short sleep on same port just to be sure
+ ZEN_INFO("Desired port {} is in use, retrying", BasePort);
+ Sleep(100);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
// Sharing violation implies the port is being used by another process
@@ -613,9 +697,20 @@ struct HttpAcceptor
{
ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
}
- else if (BindAddress.is_loopback())
+ else
{
- ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", "[::1]", EffectivePort);
+ if (EffectivePort != BasePort)
+ {
+ ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort);
+ }
+ if (BindAddress.is_loopback())
+ {
+ m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode);
+ m_UseAlternateProtocolAcceptor = true;
+ ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts",
+ "localhost",
+ EffectivePort);
+ }
}
#if ZEN_PLATFORM_WINDOWS
@@ -635,31 +730,66 @@ struct HttpAcceptor
&OptionNumberOfBytesReturned,
0,
0);
+ if (m_UseAlternateProtocolAcceptor)
+ {
+ NativeSocket = m_AlternateProtocolAcceptor.native_handle();
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+ }
#endif
m_Acceptor.listen();
+ if (m_UseAlternateProtocolAcceptor)
+ {
+ m_AlternateProtocolAcceptor.listen();
+ }
ZEN_INFO("Started asio server at 'http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort);
}
+ ~HttpAcceptor()
+ {
+ m_Acceptor.close();
+ if (m_UseAlternateProtocolAcceptor)
+ {
+ m_AlternateProtocolAcceptor.close();
+ }
+ }
+
void Start()
{
- m_Acceptor.listen();
- InitAccept();
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
+ ZEN_ASSERT(!m_IsStopped);
+ InitAcceptInternal(m_Acceptor);
+ if (m_UseAlternateProtocolAcceptor)
+ {
+ InitAcceptInternal(m_AlternateProtocolAcceptor);
+ }
}
- void Stop() { m_IsStopped = true; }
+ void StopAccepting() { m_IsStopped = true; }
+
+ int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
- void InitAccept()
+private:
+ void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor)
{
auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
- m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
if (Ec)
{
ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- m_Acceptor.local_endpoint().address().to_string(),
- m_Acceptor.local_endpoint().port(),
+ Acceptor.local_endpoint().address().to_string(),
+ Acceptor.local_endpoint().port(),
Ec.message());
}
else
@@ -679,12 +809,12 @@ struct HttpAcceptor
if (!m_IsStopped.load())
{
- InitAccept();
+ InitAcceptInternal(Acceptor);
}
else
{
std::error_code CloseEc;
- m_Acceptor.close(CloseEc);
+ Acceptor.close(CloseEc);
if (CloseEc)
{
ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
@@ -693,12 +823,11 @@ struct HttpAcceptor
});
}
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
-
-private:
HttpAsioServerImpl& m_Server;
asio::io_service& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
+ asio::ip::tcp::acceptor m_AlternateProtocolAcceptor;
+ bool m_UseAlternateProtocolAcceptor{false};
std::atomic<bool> m_IsStopped{false};
};
@@ -787,6 +916,8 @@ HttpAsioServerRequest::ReadPayload()
void
HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(HttpContentType::kBinary));
@@ -798,6 +929,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
void
HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType));
@@ -807,6 +940,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
void
HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType));
@@ -819,6 +954,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
void
HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_ASSERT(!m_Response);
// Not one bit async, innit
@@ -833,7 +970,7 @@ HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges)
//////////////////////////////////////////////////////////////////////////
-HttpAsioServerImpl::HttpAsioServerImpl()
+HttpAsioServerImpl::HttpAsioServerImpl() : m_RequestLog(logging::Get("http_requests"))
{
}
@@ -841,9 +978,17 @@ HttpAsioServerImpl::~HttpAsioServerImpl()
{
}
+void
+HttpAsioServerImpl::Initialize(std::filesystem::path DataDir)
+{
+ m_RequestTracer.Initialize(DataDir);
+}
+
int
HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
ZEN_ASSERT(ThreadCount > 0);
ZEN_INFO("starting asio http with {} service threads", ThreadCount);
@@ -863,13 +1008,19 @@ HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount)
for (int i = 0; i < ThreadCount; ++i)
{
m_ThreadPool.emplace_back([this, Index = i + 1] {
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
SetCurrentThreadName(fmt::format("asio_io_{}", Index));
try
{
m_IoService.run();
}
- catch (std::exception& e)
+ catch (const AssertException& AssertEx)
+ {
+ ZEN_ERROR("Assert caught in asio event loop: {}", AssertEx.FullDescription());
+ }
+ catch (const std::exception& e)
{
ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what());
}
@@ -884,17 +1035,29 @@ HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount)
void
HttpAsioServerImpl::Stop()
{
- m_Acceptor->Stop();
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
+ if (m_Acceptor)
+ {
+ m_Acceptor->StopAccepting();
+ }
m_IoService.stop();
for (auto& Thread : m_ThreadPool)
{
- Thread.join();
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
}
+ m_ThreadPool.clear();
+ m_Acceptor.reset();
}
void
HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
std::string_view UrlPath(InUrlPath);
Service.SetUriPrefixLength(UrlPath.size());
if (!UrlPath.empty() && UrlPath.back() == '/')
@@ -909,6 +1072,8 @@ HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
HttpService*
HttpAsioServerImpl::RouteRequest(std::string_view Url)
{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
RwLock::SharedLockScope _(m_Lock);
HttpService* CandidateService = nullptr;
@@ -978,7 +1143,7 @@ HttpAsioServer::Close()
{
m_Impl->Stop();
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
ZEN_WARN("Caught exception stopping http asio server: {}", ex.what());
}
@@ -994,8 +1159,11 @@ HttpAsioServer::RegisterService(HttpService& Service)
int
HttpAsioServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
- ZEN_UNUSED(DataDir);
+ ZEN_TRACE_CPU("HttpAsioServer::Initialize");
+ m_Impl->Initialize(DataDir);
+
m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), m_ForceLoopback, m_ThreadCount);
+
return m_BasePort;
}
@@ -1052,6 +1220,9 @@ HttpAsioServer::RequestExit()
Ref<HttpServer>
CreateHttpAsioServer(bool ForceLoopback, unsigned int ThreadCount)
{
+ ZEN_TRACE_CPU("CreateHttpAsioServer");
+ ZEN_MEMSCOPE(GetHttpasioTag());
+
return Ref<HttpServer>{new HttpAsioServer(ForceLoopback, ThreadCount)};
}
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 2a6a90d2e..b8b7931a9 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -3,6 +3,7 @@
#include "httpmulti.h"
#include <zencore/logging.h>
+#include <zencore/trace.h>
#if ZEN_PLATFORM_WINDOWS
# include <conio.h>
@@ -30,6 +31,8 @@ HttpMultiServer::RegisterService(HttpService& Service)
int
HttpMultiServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
+ ZEN_TRACE_CPU("HttpMultiServer::Initialize");
+
ZEN_UNUSED(DataDir);
ZEN_ASSERT(!m_IsInitialized);
@@ -103,6 +106,10 @@ HttpMultiServer::RequestExit()
void
HttpMultiServer::Close()
{
+ for (auto& Server : m_Servers)
+ {
+ Server->Close();
+ }
}
void
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index c64134c95..93094e21b 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -6,6 +6,8 @@
#include <zencore/logging.h>
#include <zencore/string.h>
+#include <limits>
+
namespace zen {
using namespace std::literals;
@@ -69,23 +71,21 @@ HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize)
int
HttpRequestParser::OnUrl(const char* Data, size_t Bytes)
{
- if (!m_Url)
+ const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size();
+ if (RemainingBufferSpace < Bytes)
{
- ZEN_ASSERT_SLOW(m_UrlLength == 0);
- m_Url = m_HeaderCursor;
+ ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
}
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
-
- if (RemainingBufferSpace < Bytes)
+ if (m_UrlRange.Length == 0)
{
- ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ ZEN_ASSERT_SLOW(m_UrlRange.Offset == 0);
+ m_UrlRange.Offset = (uint32_t)m_HeaderData.size();
}
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_UrlLength += Bytes;
+ m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]);
+ m_UrlRange.Length += (uint32_t)Bytes;
return 0;
}
@@ -93,56 +93,70 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes)
int
HttpRequestParser::OnHeader(const char* Data, size_t Bytes)
{
- if (m_CurrentHeaderValueLength)
+ const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size();
+ if (RemainingBufferSpace < Bytes)
{
- AppendCurrentHeader();
+ ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
- m_CurrentHeaderNameLength = 0;
- m_CurrentHeaderValueLength = 0;
- m_CurrentHeaderName = m_HeaderCursor;
+ if (m_HeaderEntries.empty())
+ {
+ m_HeaderEntries.resize(1);
}
- else if (m_CurrentHeaderName == nullptr)
+ HeaderEntry* CurrentHeaderEntry = &m_HeaderEntries.back();
+ if (CurrentHeaderEntry->ValueRange.Length)
{
- m_CurrentHeaderName = m_HeaderCursor;
+ ParseCurrentHeader();
+ m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}});
+ CurrentHeaderEntry = &m_HeaderEntries.back();
}
-
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
- if (RemainingBufferSpace < Bytes)
+ else if (CurrentHeaderEntry->NameRange.Length == 0)
{
- ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}});
+ CurrentHeaderEntry = &m_HeaderEntries.back();
}
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_CurrentHeaderNameLength += Bytes;
+ m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]);
+ CurrentHeaderEntry->NameRange.Length += (uint32_t)Bytes;
return 0;
}
void
-HttpRequestParser::AppendCurrentHeader()
+HttpRequestParser::ParseCurrentHeader()
{
- std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength);
- std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength);
+ ZEN_ASSERT_SLOW(!m_HeaderEntries.empty());
+ const HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back();
+ const size_t CurrentHeaderCount = m_HeaderEntries.size();
+ const std::string_view HeaderName(GetHeaderSubString(CurrentHeaderEntry.NameRange));
+ if (CurrentHeaderCount > std::numeric_limits<int8_t>::max())
+ {
+ ZEN_WARN("HttpRequestParser parser only supports up to {} headers, can't store header '{}'. Dropping it.",
+ std::numeric_limits<int8_t>::max(),
+ HeaderName);
+ return;
+ }
+ const std::string_view HeaderValue(GetHeaderSubString(CurrentHeaderEntry.ValueRange));
- const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
+ const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
+ const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1);
if (HeaderHash == HashContentLength)
{
- m_ContentLengthHeaderIndex = (int8_t)m_Headers.size();
+ m_ContentLengthHeaderIndex = CurrentHeaderIndex;
}
else if (HeaderHash == HashAccept)
{
- m_AcceptHeaderIndex = (int8_t)m_Headers.size();
+ m_AcceptHeaderIndex = CurrentHeaderIndex;
}
else if (HeaderHash == HashContentType)
{
- m_ContentTypeHeaderIndex = (int8_t)m_Headers.size();
+ m_ContentTypeHeaderIndex = CurrentHeaderIndex;
}
else if (HeaderHash == HashSession)
{
- m_SessionId = Oid::FromHexString(HeaderValue);
+ m_SessionId = Oid::TryFromHexString(HeaderValue);
}
else if (HeaderHash == HashRequest)
{
@@ -162,38 +176,38 @@ HttpRequestParser::AppendCurrentHeader()
}
else if (HeaderHash == HashRange)
{
- m_RangeHeaderIndex = (int8_t)m_Headers.size();
+ m_RangeHeaderIndex = CurrentHeaderIndex;
}
-
- m_Headers.emplace_back(HeaderName, HeaderValue);
}
int
HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes)
{
- if (m_CurrentHeaderValueLength == 0)
- {
- m_CurrentHeaderValue = m_HeaderCursor;
- }
-
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size();
if (RemainingBufferSpace < Bytes)
{
- ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace);
+ ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
return 1;
}
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_CurrentHeaderValueLength += Bytes;
+ ZEN_ASSERT_SLOW(!m_HeaderEntries.empty());
+ HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back();
+ if (CurrentHeaderEntry.ValueRange.Length == 0)
+ {
+ CurrentHeaderEntry.ValueRange.Offset = (uint32_t)m_HeaderData.size();
+ }
+ m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]);
+ CurrentHeaderEntry.ValueRange.Length += (uint32_t)Bytes;
return 0;
}
static void
-NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl)
+NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl)
{
- bool LastCharWasSeparator = false;
+ bool LastCharWasSeparator = false;
+ const char* Url = InUrl.data();
+ const size_t UrlLength = InUrl.length();
for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex)
{
const char UrlChar = Url[UrlIndex];
@@ -226,9 +240,13 @@ HttpRequestParser::OnHeadersComplete()
{
try
{
- if (m_CurrentHeaderValueLength)
+ if (!m_HeaderEntries.empty())
{
- AppendCurrentHeader();
+ HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back();
+ if (CurrentHeaderEntry.NameRange.Length)
+ {
+ ParseCurrentHeader();
+ }
}
m_KeepAlive = !!http_should_keep_alive(&m_Parser);
@@ -268,21 +286,21 @@ HttpRequestParser::OnHeadersComplete()
break;
}
- std::string_view Url(m_Url, m_UrlLength);
+ std::string_view FullUrl(GetHeaderSubString(m_UrlRange));
- if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos)
+ if (auto QuerySplit = FullUrl.find_first_of('?'); QuerySplit != std::string_view::npos)
{
- m_UrlLength = QuerySplit;
- m_QueryString = m_Url + QuerySplit + 1;
- m_QueryLength = Url.size() - QuerySplit - 1;
+ m_UrlRange.Length = uint32_t(QuerySplit);
+ m_QueryStringRange = {.Offset = uint32_t(m_UrlRange.Offset + QuerySplit + 1),
+ .Length = uint32_t(FullUrl.size() - QuerySplit - 1)};
}
- NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl);
+ NormalizeUrlPath(FullUrl, m_NormalizedUrl);
- if (m_ContentLengthHeaderIndex >= 0)
+ std::string_view Value = GetHeaderValue(m_ContentLengthHeaderIndex);
+ if (!Value.empty())
{
- std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value;
- uint64_t ContentLength = 0;
+ uint64_t ContentLength = 0;
std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength);
if (ContentLength)
@@ -330,16 +348,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
void
HttpRequestParser::ResetState()
{
- m_HeaderCursor = m_HeaderBuffer;
- m_CurrentHeaderName = nullptr;
- m_CurrentHeaderNameLength = 0;
- m_CurrentHeaderValue = nullptr;
- m_CurrentHeaderValueLength = 0;
- m_CurrentHeaderName = nullptr;
- m_Url = nullptr;
- m_UrlLength = 0;
- m_QueryString = nullptr;
- m_QueryLength = 0;
+ m_UrlRange = {};
+ m_QueryStringRange = {};
+
+ m_HeaderEntries.clear();
+
m_ContentLengthHeaderIndex = -1;
m_AcceptHeaderIndex = -1;
m_ContentTypeHeaderIndex = -1;
@@ -347,7 +360,8 @@ HttpRequestParser::ResetState()
m_Expect100Continue = false;
m_BodyBuffer = {};
m_BodyPosition = 0;
- m_Headers.clear();
+
+ m_HeaderData.clear();
m_NormalizedUrl.clear();
}
@@ -366,7 +380,12 @@ HttpRequestParser::OnMessageComplete()
ResetState();
return 0;
}
- catch (std::system_error& SystemError)
+ catch (const AssertException& AssertEx)
+ {
+ ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription());
+ return 1;
+ }
+ catch (const std::system_error& SystemError)
{
if (IsOOM(SystemError.code()))
{
@@ -378,18 +397,18 @@ HttpRequestParser::OnMessageComplete()
}
else
{
- ZEN_ERROR("failed processing http request: '{}'", SystemError.what());
+ ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value());
}
ResetState();
return 1;
}
- catch (std::bad_alloc& BadAlloc)
+ catch (const std::bad_alloc& BadAlloc)
{
ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what());
ResetState();
return 1;
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
ZEN_ERROR("failed processing http request: '{}'", Ex.what());
ResetState();
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index bdbcab4d9..0d2664ec5 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -5,6 +5,8 @@
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <EASTL/fixed_vector.h>
+
ZEN_THIRD_PARTY_INCLUDES_START
#include <http_parser.h>
ZEN_THIRD_PARTY_INCLUDES_END
@@ -31,73 +33,68 @@ struct HttpRequestParser
HttpVerb RequestVerb() const { return m_RequestVerb; }
bool IsKeepAlive() const { return m_KeepAlive; }
- std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; }
- std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); }
+ std::string_view Url() const { return m_NormalizedUrl.empty() ? GetHeaderSubString(m_UrlRange) : m_NormalizedUrl; }
+ std::string_view QueryString() const { return GetHeaderSubString(m_QueryStringRange); }
IoBuffer Body() { return m_BodyBuffer; }
- inline HttpContentType ContentType()
- {
- if (m_ContentTypeHeaderIndex < 0)
- {
- return HttpContentType::kUnknownContentType;
- }
-
- return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value);
- }
+ inline HttpContentType ContentType() { return ParseContentType(GetHeaderValue(m_ContentTypeHeaderIndex)); }
- inline HttpContentType AcceptType()
- {
- if (m_AcceptHeaderIndex < 0)
- {
- return HttpContentType::kUnknownContentType;
- }
-
- return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value);
- }
+ inline HttpContentType AcceptType() { return ParseContentType(GetHeaderValue(m_AcceptHeaderIndex)); }
Oid SessionId() const { return m_SessionId; }
int RequestId() const { return m_RequestId; }
- std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); }
+ std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); }
private:
+ struct HeaderRange
+ {
+ uint32_t Offset = 0;
+ uint32_t Length = 0;
+ };
+
struct HeaderEntry
{
- HeaderEntry() = default;
+ HeaderRange NameRange;
+ HeaderRange ValueRange;
+ };
- HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {}
+ inline std::string_view GetHeaderValue(int8_t HeaderIndex) const
+ {
+ if (HeaderIndex == -1)
+ {
+ return {};
+ }
+ ZEN_ASSERT(size_t(HeaderIndex) < m_HeaderEntries.size());
+ return GetHeaderSubString(m_HeaderEntries[HeaderIndex].ValueRange);
+ }
- std::string_view Name;
- std::string_view Value;
- };
+ std::string_view GetHeaderSubString(const HeaderRange& Range) const
+ {
+ ZEN_ASSERT_SLOW(Range.Offset + Range.Length <= m_HeaderData.size());
+ return std::string_view(m_HeaderData.begin(), m_HeaderData.size()).substr(Range.Offset, Range.Length);
+ }
- HttpRequestParserCallbacks& m_Connection;
- char* m_HeaderCursor = m_HeaderBuffer;
- char* m_Url = nullptr;
- size_t m_UrlLength = 0;
- char* m_QueryString = nullptr;
- size_t m_QueryLength = 0;
- char* m_CurrentHeaderName = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderNameLength = 0;
- char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderValueLength = 0;
- std::vector<HeaderEntry> m_Headers;
- int8_t m_ContentLengthHeaderIndex;
- int8_t m_AcceptHeaderIndex;
- int8_t m_ContentTypeHeaderIndex;
- int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
- std::atomic_bool m_KeepAlive{false};
- bool m_Expect100Continue = false;
- int m_RequestId = -1;
- Oid m_SessionId{};
- IoBuffer m_BodyBuffer;
- uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
- char m_HeaderBuffer[1024];
- std::string m_NormalizedUrl;
-
- void AppendCurrentHeader();
+ HttpRequestParserCallbacks& m_Connection;
+ HeaderRange m_UrlRange;
+ HeaderRange m_QueryStringRange;
+ eastl::fixed_vector<HeaderEntry, 16> m_HeaderEntries;
+ int8_t m_ContentLengthHeaderIndex;
+ int8_t m_AcceptHeaderIndex;
+ int8_t m_ContentTypeHeaderIndex;
+ int8_t m_RangeHeaderIndex;
+ HttpVerb m_RequestVerb;
+ std::atomic_bool m_KeepAlive{false};
+ bool m_Expect100Continue = false;
+ int m_RequestId = -1;
+ Oid m_SessionId{};
+ IoBuffer m_BodyBuffer;
+ uint64_t m_BodyPosition = 0;
+ http_parser m_Parser;
+ eastl::fixed_vector<char, 512> m_HeaderData;
+ std::string m_NormalizedUrl;
+
+ void ParseCurrentHeader();
int OnMessageBegin();
int OnUrl(const char* Data, size_t Bytes);
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 3eed9db8f..d6ca7e1c5 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -2,6 +2,8 @@
#include <zenhttp/httpplugin.h>
+#include "httptracer.h"
+
#if ZEN_WITH_PLUGINS
# include "httpparser.h"
@@ -10,6 +12,7 @@
# include <zencore/filesystem.h>
# include <zencore/fmtutils.h>
# include <zencore/logging.h>
+# include <zencore/memory/llm.h>
# include <zencore/scopeguard.h>
# include <zencore/session.h>
# include <zencore/thread.h>
@@ -25,14 +28,6 @@
# include <conio.h>
# endif
-# define PLUGIN_VERBOSE_TRACE 1
-
-# if PLUGIN_VERBOSE_TRACE
-# define ZEN_TRACE_VERBOSE ZEN_TRACE
-# else
-# define ZEN_TRACE_VERBOSE(fmtstr, ...)
-# endif
-
namespace zen {
struct HttpPluginServerImpl;
@@ -40,6 +35,14 @@ struct HttpPluginResponse;
using namespace std::literals;
+const FLLMTag&
+GetHttppluginTag()
+{
+ static FLLMTag _("httpplugin");
+
+ return _;
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted
@@ -103,8 +106,6 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* RouteRequest(std::string_view Url);
- void WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload);
-
struct ServiceEntry
{
std::string ServiceUrlPath;
@@ -119,8 +120,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- std::filesystem::path m_DataDir; // Application data directory
- std::filesystem::path m_PayloadDir; // Request debugging payload directory
+
+ HttpServerTracer m_RequestTracer;
// TransportServer
@@ -191,6 +192,7 @@ private:
void
HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
ZEN_TRACE_CPU("http_plugin::InitializeForPayload");
m_ResponseCode = ResponseCode;
@@ -232,6 +234,8 @@ HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuff
std::string_view
HttpPluginResponse::GetHeaders()
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
if (m_Headers.Size() == 0)
{
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
@@ -266,6 +270,8 @@ HttpPluginConnectionHandler::~HttpPluginConnectionHandler()
void
HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server, uint32_t ConnectionId)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
m_TransportConnection = Transport;
m_Server = &Server;
m_ConnectionId = ConnectionId;
@@ -298,6 +304,8 @@ HttpPluginConnectionHandler::Release() const
void
HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
ZEN_ASSERT(m_Server);
ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} OnBytesRead: {}", m_ConnectionId, AvailableBytes);
@@ -325,6 +333,8 @@ HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableByt
void
HttpPluginConnectionHandler::HandleRequest()
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
ZEN_ASSERT(m_Server);
const uint32_t RequestNumber = m_RequestCounter.fetch_add(1);
@@ -376,8 +386,8 @@ HttpPluginConnectionHandler::HandleRequest()
ToString(Request.RequestContentType()),
ToString(Request.AcceptContentType()));
- m_Server->WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber),
- std::vector<IoBuffer>{Request.ReadPayload()});
+ m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber),
+ std::vector<IoBuffer>{Request.ReadPayload()});
}
if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
@@ -386,7 +396,15 @@ HttpPluginConnectionHandler::HandleRequest()
{
Service->HandleRequest(Request);
}
- catch (std::system_error& SystemError)
+ catch (const AssertException& AssertEx)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
+ }
+ catch (const std::system_error& SystemError)
{
// Drop any partially formatted response
Request.m_Response.reset();
@@ -397,23 +415,25 @@ HttpPluginConnectionHandler::HandleRequest()
}
else
{
- ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what());
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
}
}
- catch (std::bad_alloc& BadAlloc)
+ catch (const std::bad_alloc& BadAlloc)
{
// Drop any partially formatted response
Request.m_Response.reset();
Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
// Drop any partially formatted response
Request.m_Response.reset();
- ZEN_ERROR("Caught exception while handling request: {}", ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
}
}
@@ -442,7 +462,8 @@ HttpPluginConnectionHandler::HandleRequest()
if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
{
- m_Server->WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers);
+ m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
+ ResponseBuffers);
}
for (const IoBuffer& Buffer : ResponseBuffers)
@@ -523,6 +544,7 @@ HttpPluginConnectionHandler::HandleRequest()
void
HttpPluginConnectionHandler::TerminateConnection()
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
ZEN_ASSERT(m_TransportConnection);
m_TransportConnection->CloseConnection();
}
@@ -533,6 +555,8 @@ HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, Htt
: m_Request(Request)
, m_PayloadBuffer(std::move(PayloadBuffer))
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
const int PrefixLength = Service.UriPrefixLength();
std::string_view Uri = Request.Url();
@@ -613,6 +637,7 @@ void
HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode)
{
ZEN_ASSERT(!m_Response);
+ ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary));
std::array<IoBuffer, 0> Empty;
@@ -624,6 +649,7 @@ void
HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
{
ZEN_ASSERT(!m_Response);
+ ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(ContentType));
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
@@ -633,6 +659,8 @@ void
HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
{
ZEN_ASSERT(!m_Response);
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
m_Response.reset(new HttpPluginResponse(ContentType));
IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
@@ -645,6 +673,7 @@ void
HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
{
ZEN_ASSERT(!m_Response);
+ ZEN_MEMSCOPE(GetHttppluginTag());
// Not one bit async, innit
ContinuationHandler(*this);
@@ -669,6 +698,7 @@ HttpPluginServerImpl::~HttpPluginServerImpl()
TransportServerConnection*
HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()};
const uint32_t ConnectionId = m_ConnectionIdCounter.fetch_add(1);
Handler->Initialize(Connection, *this, ConnectionId);
@@ -678,10 +708,10 @@ HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection)
int
HttpPluginServerImpl::Initialize(int BasePort, std::filesystem::path DataDir)
{
- m_DataDir = DataDir;
- m_PayloadDir = DataDir / "debug" / GetSessionIdString();
+ ZEN_TRACE_CPU("HttpPluginServerImpl::Initialize");
- ZEN_INFO("any debug payloads will be written to '{}'", m_PayloadDir);
+ ZEN_MEMSCOPE(GetHttppluginTag());
+ m_RequestTracer.Initialize(DataDir);
try
{
@@ -693,13 +723,13 @@ HttpPluginServerImpl::Initialize(int BasePort, std::filesystem::path DataDir)
{
Plugin->Initialize(this);
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
ZEN_WARN("exception caught during plugin initialization: {}", Ex.what());
}
}
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
ZEN_WARN("Caught exception starting http plugin server: {}", ex.what());
}
@@ -715,6 +745,8 @@ HttpPluginServerImpl::Close()
if (!m_IsInitialized)
return;
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
try
{
RwLock::ExclusiveLockScope _(m_Lock);
@@ -725,7 +757,7 @@ HttpPluginServerImpl::Close()
{
Plugin->Shutdown();
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what());
}
@@ -735,7 +767,7 @@ HttpPluginServerImpl::Close()
m_Plugins.clear();
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what());
}
@@ -746,6 +778,8 @@ HttpPluginServerImpl::Close()
void
HttpPluginServerImpl::Run(bool IsInteractive)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
const bool TestMode = !IsInteractive;
int WaitTimeout = -1;
@@ -796,6 +830,8 @@ HttpPluginServerImpl::RequestExit()
void
HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
RwLock::ExclusiveLockScope _(m_Lock);
m_Plugins.emplace_back(std::move(Plugin));
}
@@ -803,6 +839,8 @@ HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin)
void
HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
RwLock::ExclusiveLockScope _(m_Lock);
auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin);
if (It != m_Plugins.end())
@@ -814,6 +852,8 @@ HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin)
void
HttpPluginServerImpl::RegisterService(HttpService& Service)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
std::string_view UrlPath(Service.BaseUri());
Service.SetUriPrefixLength(UrlPath.size());
@@ -829,6 +869,8 @@ HttpPluginServerImpl::RegisterService(HttpService& Service)
HttpService*
HttpPluginServerImpl::RouteRequest(std::string_view Url)
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
+
RwLock::SharedLockScope _(m_Lock);
HttpService* CandidateService = nullptr;
@@ -848,23 +890,6 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
-void
-HttpPluginServerImpl::WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload)
-{
- uint64_t PayloadSize = 0;
- std::vector<const IoBuffer*> Buffers;
- for (auto& Io : Payload)
- {
- Buffers.push_back(&Io);
- PayloadSize += Io.GetSize();
- }
-
- if (PayloadSize)
- {
- WriteFile(m_PayloadDir / Filename, Buffers.data(), Buffers.size());
- }
-}
-
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;
@@ -872,6 +897,7 @@ struct HttpPluginServerImpl;
Ref<HttpPluginServer>
CreateHttpPluginServer()
{
+ ZEN_MEMSCOPE(GetHttppluginTag());
return Ref<HttpPluginServer>(new HttpPluginServerImpl);
}
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 5cd273c40..95d83911d 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -9,11 +9,14 @@
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/memory/llm.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
#include <zencore/timer.h>
#include <zencore/trace.h>
-#include <zenutil/packageformat.h>
+#include <zenhttp/packageformat.h>
+
+#include <EASTL/fixed_vector.h>
#if ZEN_WITH_HTTPSYS
# define _WINSOCKAPI_
@@ -25,6 +28,14 @@
namespace zen {
+const FLLMTag&
+GetHttpsysTag()
+{
+ static FLLMTag HttpsysTag("httpsys");
+
+ return HttpsysTag;
+}
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -372,14 +383,14 @@ public:
void SuppressResponseBody(); // typically used for HEAD requests
private:
- std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks;
- uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
- uint16_t m_ResponseCode = 0;
- uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists
- uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
- bool m_IsInitialResponse = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- std::vector<IoBuffer> m_DataBuffers;
+ eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
+ uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
+ uint16_t m_ResponseCode = 0;
+ uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists
+ uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
+ bool m_IsInitialResponse = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
};
@@ -524,7 +535,14 @@ HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfB
if (IoResult != NO_ERROR)
{
- ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult));
+ ZEN_WARN("response '{}' ({}) aborted after transfering '{}', {} out of {} bytes, reason: {} ({})",
+ ReasonStringForHttpResultCode(m_ResponseCode),
+ m_ResponseCode,
+ ToString(m_ContentType),
+ NumberOfBytesTransferred,
+ m_TotalDataSize,
+ GetSystemErrorAsString(IoResult),
+ IoResult);
// if one transmit failed there's really no need to go on
return nullptr;
@@ -673,7 +691,7 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
);
}
- auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void {
+ auto EmitResponseDetails = [&](StringBuilderBase& ResponseDetails) -> void {
for (int i = 0; i < ThisRequestChunkCount; ++i)
{
const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i];
@@ -756,7 +774,7 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
// Emit diagnostics
ExtendableStringBuilder<256> ResponseDetails;
- EmitReponseDetails(ResponseDetails);
+ EmitResponseDetails(ResponseDetails);
ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}",
SendResult,
@@ -817,7 +835,7 @@ HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest");
ErrorCode.clear();
- Transaction().Server().WorkPool().ScheduleWork(m_WorkItem);
+ Transaction().Server().WorkPool().ScheduleWork(m_WorkItem, WorkerThreadPool::EMode::EnableBacklog);
}
HttpSysRequestHandler*
@@ -836,6 +854,8 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr
void
HttpAsyncWorkRequest::AsyncWorkItem::Execute()
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
ZEN_TRACE_CPU("httpsys::async_execute");
try
@@ -873,10 +893,15 @@ HttpAsyncWorkRequest::AsyncWorkItem::Execute()
new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
}
}
- catch (std::exception& Ex)
+ catch (const AssertException& AssertEx)
{
return (void)Tx.IssueNextRequest(
- new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what())));
+ new HttpMessageResponseRequest(Tx, 500, fmt::format("Assert thrown in async work: '{}", AssertEx.FullDescription())));
+ }
+ catch (const std::exception& Ex)
+ {
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: {}", Ex.what())));
}
}
@@ -896,6 +921,8 @@ HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig)
, m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled)
, m_InitialConfig(InConfig)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
// Initialize thread pool
int MinThreadCount;
@@ -971,6 +998,8 @@ HttpSysServer::Close()
int
HttpSysServer::InitializeServer(int BasePort)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
using namespace std::literals;
WideStringBuilder<64> WildcardUrlPath;
@@ -1215,6 +1244,8 @@ HttpSysServer::Cleanup()
WorkerThreadPool&
HttpSysServer::WorkPool()
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
if (!m_AsyncWorkPool)
{
RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock);
@@ -1299,6 +1330,8 @@ HttpSysServer::IssueNewRequestMaybe()
return;
}
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this);
std::error_code ErrorCode;
@@ -1322,6 +1355,8 @@ HttpSysServer::IssueNewRequestMaybe()
void
HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
if (UrlPath[0] == '/')
{
++UrlPath;
@@ -1483,11 +1518,15 @@ HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler
return true;
}
- ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message());
+ ZEN_WARN("IssueRequest() failed: {}", ErrorCode.message());
+ }
+ catch (const AssertException& AssertEx)
+ {
+ ZEN_ERROR("Assert thrown in IssueNextRequest(): {}", AssertEx.FullDescription());
}
- catch (std::exception& Ex)
+ catch (const std::exception& Ex)
{
- ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what());
+ ZEN_ERROR("exception caught in IssueNextRequest(): {}", Ex.what());
}
// something went wrong, no request is pending
@@ -1659,7 +1698,7 @@ HttpSysServerRequest::ParseSessionId() const
{
if (Header.RawValueLength == Oid::StringLength)
{
- return Oid::FromHexString({Header.pRawValue, Header.RawValueLength});
+ return Oid::TryFromHexString({Header.pRawValue, Header.RawValueLength});
}
}
}
@@ -1698,6 +1737,8 @@ HttpSysServerRequest::ReadPayload()
void
HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
ZEN_ASSERT(IsHandled() == false);
auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
@@ -1715,6 +1756,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
void
HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
ZEN_ASSERT(IsHandled() == false);
auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
@@ -1732,6 +1775,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
void
HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
ZEN_ASSERT(IsHandled() == false);
auto Response =
@@ -1750,6 +1795,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
void
HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
if (m_HttpTx.Server().IsAsyncResponseEnabled())
{
m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler));
@@ -1826,7 +1873,17 @@ InitialRequestHandler::IssueRequest(std::error_code& ErrorCode)
ErrorCode = MakeErrorCode(HttpApiResult);
- ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message());
+ if (IsInitialRequest())
+ {
+ ZEN_WARN("initial HttpReceiveHttpRequest failed, error: {}", ErrorCode.message());
+ }
+ else
+ {
+ ZEN_WARN("HttpReceiveHttpRequest (offset: {}, content-length: {}) failed, error: {}",
+ m_CurrentPayloadOffset,
+ m_PayloadBuffer.GetSize(),
+ ErrorCode.message());
+ }
return;
}
@@ -1837,6 +1894,8 @@ InitialRequestHandler::IssueRequest(std::error_code& ErrorCode)
HttpSysRequestHandler*
InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
auto _ = MakeGuard([&] { m_IsInitialRequest = false; });
switch (IoResult)
@@ -1985,23 +2044,28 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
// Unable to route
return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
}
- catch (std::system_error& SystemError)
+ catch (const AssertException& AssertEx)
+ {
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, AssertEx.FullDescription());
+ }
+ catch (const std::system_error& SystemError)
{
if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
{
return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what());
}
- ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what());
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})", SystemError.what(), SystemError.code().value());
return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what());
}
- catch (std::bad_alloc& BadAlloc)
+ catch (const std::bad_alloc& BadAlloc)
{
return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what());
}
- catch (std::exception& ex)
+ catch (const std::exception& ex)
{
- ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+ ZEN_WARN("Caught exception while handling request: '{}'", ex.what());
return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what());
}
}
@@ -2014,6 +2078,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
int
HttpSysServer::Initialize(int BasePort, std::filesystem::path DataDir)
{
+ ZEN_TRACE_CPU("HttpSysServer::Initialize");
+
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
@@ -2042,6 +2108,9 @@ HttpSysServer::RegisterService(HttpService& Service)
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{
+ ZEN_TRACE_CPU("CreateHttpSysServer");
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
return Ref<HttpServer>(new HttpSysServer(Config));
}
diff --git a/src/zenhttp/servers/httptracer.cpp b/src/zenhttp/servers/httptracer.cpp
new file mode 100644
index 000000000..483307fb1
--- /dev/null
+++ b/src/zenhttp/servers/httptracer.cpp
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httptracer.h"
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+
+namespace zen {
+
+void
+HttpServerTracer::Initialize(std::filesystem::path DataDir)
+{
+ m_DataDir = DataDir;
+ m_PayloadDir = DataDir / "debug" / GetSessionIdString();
+
+ ZEN_INFO("any debug payloads will be written to '{}'", m_PayloadDir);
+}
+
+void
+HttpServerTracer::WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload)
+{
+ uint64_t PayloadSize = 0;
+ std::vector<const IoBuffer*> Buffers;
+ for (auto& Io : Payload)
+ {
+ Buffers.push_back(&Io);
+ PayloadSize += Io.GetSize();
+ }
+
+ if (PayloadSize)
+ {
+ WriteFile(m_PayloadDir / Filename, Buffers.data(), Buffers.size());
+ }
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h
new file mode 100644
index 000000000..da72c79c9
--- /dev/null
+++ b/src/zenhttp/servers/httptracer.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpserver.h>
+
+#pragma once
+
+namespace zen {
+
+/** Helper class for HTTP server implementations
+
+ Provides some common functionality which can be used across all server
+ implementations. These could be in the root class but I think it's nicer
+ to hide the implementation details from client code
+ */
+class HttpServerTracer
+{
+public:
+ void Initialize(std::filesystem::path DataDir);
+ void WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload);
+
+private:
+ std::filesystem::path m_DataDir; // Application data directory
+ std::filesystem::path m_PayloadDir; // Request debugging payload directory
+};
+
+} // namespace zen
diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp
index a9a782821..96a15518c 100644
--- a/src/zenhttp/transports/asiotransport.cpp
+++ b/src/zenhttp/transports/asiotransport.cpp
@@ -426,7 +426,7 @@ AsioTransportPlugin::Initialize(TransportServer* ServerInterface)
{
m_IoService.run();
}
- catch (std::exception& e)
+ catch (const std::exception& e)
{
ZEN_ERROR("exception caught in asio event loop: {}", e.what());
}
diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp
index e09e62ec5..fb3dd23b5 100644
--- a/src/zenhttp/transports/dlltransport.cpp
+++ b/src/zenhttp/transports/dlltransport.cpp
@@ -21,18 +21,31 @@ namespace zen {
//////////////////////////////////////////////////////////////////////////
+class DllTransportLogger : public TransportLogger, public RefCounted
+{
+public:
+ DllTransportLogger(std::string_view PluginName);
+ virtual ~DllTransportLogger() = default;
+
+ void LogMessage(LogLevel Level, const char* Message) override;
+
+private:
+ std::string m_PluginName;
+};
+
struct LoadedDll
{
std::string Name;
std::filesystem::path LoadedFromPath;
+ DllTransportLogger* Logger = nullptr;
Ref<TransportPlugin> Plugin;
};
class DllTransportPluginImpl : public DllTransportPlugin, RefCounted
{
public:
- DllTransportPluginImpl();
- ~DllTransportPluginImpl();
+ DllTransportPluginImpl() = default;
+ ~DllTransportPluginImpl() = default;
virtual uint32_t AddRef() const override;
virtual uint32_t Release() const override;
@@ -42,7 +55,7 @@ public:
virtual const char* GetDebugName() override;
virtual bool IsAvailable() override;
- virtual void LoadDll(std::string_view Name) override;
+ virtual bool LoadDll(std::string_view Name) override;
virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) override;
private:
@@ -51,12 +64,27 @@ private:
std::vector<LoadedDll> m_Transports;
};
-DllTransportPluginImpl::DllTransportPluginImpl()
+DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginName(PluginName)
{
}
-DllTransportPluginImpl::~DllTransportPluginImpl()
+void
+DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message)
{
+ logging::level::LogLevel Level;
+ // clang-format off
+ switch (PluginLogLevel)
+ {
+ case LogLevel::Trace: Level = logging::level::Trace; break;
+ case LogLevel::Debug: Level = logging::level::Debug; break;
+ case LogLevel::Info: Level = logging::level::Info; break;
+ case LogLevel::Warn: Level = logging::level::Warn; break;
+ case LogLevel::Err: Level = logging::level::Err; break;
+ case LogLevel::Critical: Level = logging::level::Critical; break;
+ default: Level = logging::level::Off; break;
+ }
+ // clang-format on
+ ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message)
}
uint32_t
@@ -109,6 +137,7 @@ DllTransportPluginImpl::Shutdown()
try
{
Transport.Plugin->Shutdown();
+ Transport.Logger->Release();
}
catch (const std::exception&)
{
@@ -143,42 +172,73 @@ DllTransportPluginImpl::ConfigureDll(std::string_view Name, const char* OptionTa
}
}
-void
+bool
DllTransportPluginImpl::LoadDll(std::string_view Name)
{
RwLock::ExclusiveLockScope _(m_Lock);
- ExtendableStringBuilder<128> DllPath;
- DllPath << Name << ".dll";
+ ExtendableStringBuilder<1024> DllPath;
+ DllPath << Name;
+ if (!Name.ends_with(".dll"))
+ {
+ DllPath << ".dll";
+ }
+
+ std::string FileName = std::filesystem::path(DllPath.c_str()).filename().replace_extension().string();
+
HMODULE DllHandle = LoadLibraryA(DllPath.c_str());
if (!DllHandle)
{
- std::error_code Ec = MakeErrorCodeFromLastError();
-
- throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath));
+ ZEN_WARN("Failed to load transport DLL from '{}' due to '{}'", DllPath, GetLastErrorAsString())
+ return false;
}
- TransportPlugin* CreateTransportPlugin();
+ PfnGetTransportPluginVersion GetVersion = (PfnGetTransportPluginVersion)GetProcAddress(DllHandle, "GetTransportPluginVersion");
+ PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin");
+
+ uint32_t APIVersion = 0;
+ uint32_t PluginVersion = 0;
+
+ if (GetVersion)
+ {
+ GetVersion(&APIVersion, &PluginVersion);
+ }
- PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin");
+ const bool bValidApiVersion = APIVersion == kTransportApiVersion;
- if (!CreatePlugin)
+ if (!GetVersion || !CreatePlugin || !bValidApiVersion)
{
std::error_code Ec = MakeErrorCodeFromLastError();
FreeLibrary(DllHandle);
- throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath));
+ if (GetVersion && !bValidApiVersion)
+ {
+ ZEN_WARN("Failed to load transport DLL from '{}' due to invalid API version {}, supported API version is {}",
+ DllPath,
+ APIVersion,
+ kTransportApiVersion)
+ }
+ else
+ {
+ ZEN_WARN("Failed to load transport DLL from '{}' due to not finding GetTransportPluginVersion or CreateTransportPlugin",
+ DllPath)
+ }
+
+ return false;
}
LoadedDll NewDll;
NewDll.Name = Name;
NewDll.LoadedFromPath = DllPath.c_str();
- NewDll.Plugin = CreatePlugin();
+ NewDll.Logger = new DllTransportLogger(FileName);
+ NewDll.Logger->AddRef();
+ NewDll.Plugin = CreatePlugin(NewDll.Logger);
m_Transports.emplace_back(std::move(NewDll));
+ return true;
}
DllTransportPlugin*
diff --git a/src/zenhttp/transports/dlltransport.h b/src/zenhttp/transports/dlltransport.h
index 9346a10ce..c49f888da 100644
--- a/src/zenhttp/transports/dlltransport.h
+++ b/src/zenhttp/transports/dlltransport.h
@@ -15,7 +15,7 @@ namespace zen {
class DllTransportPlugin : public TransportPlugin
{
public:
- virtual void LoadDll(std::string_view Name) = 0;
+ virtual bool LoadDll(std::string_view Name) = 0;
virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) = 0;
};
diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp
index 7407c55dd..c06a50c95 100644
--- a/src/zenhttp/transports/winsocktransport.cpp
+++ b/src/zenhttp/transports/winsocktransport.cpp
@@ -304,18 +304,20 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface)
TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)};
Connection->Initialize(ConnectionInterface, ClientSocket);
- m_WorkerThreadpool->ScheduleWork([Connection] {
- try
- {
- Connection->HandleConnection();
- }
- catch (std::exception& Ex)
- {
- ZEN_WARN("exception caught in connection loop: {}", Ex.what());
- }
-
- delete Connection;
- });
+ m_WorkerThreadpool->ScheduleWork(
+ [Connection] {
+ try
+ {
+ Connection->HandleConnection();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_WARN("exception caught in connection loop: {}", Ex.what());
+ }
+
+ delete Connection;
+ },
+ WorkerThreadPool::EMode::EnableBacklog);
}
else
{
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 8393f399b..b6ffbe467 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -7,7 +7,7 @@ target('zenhttp')
add_files("**.cpp")
add_files("servers/httpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
- add_deps("zencore", "zenutil", "transport-sdk")
+ add_deps("zencore", "transport-sdk")
add_packages(
"vcpkg::asio",
"vcpkg::cpr",
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index 6b855c4db..a2679f92e 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -6,7 +6,7 @@
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
-# include <zenutil/packageformat.h>
+# include <zenhttp/packageformat.h>
namespace zen {
@@ -15,6 +15,7 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ forcelink_packageformat();
}
} // namespace zen