diff options
| author | Stefan Boberg <[email protected]> | 2025-10-06 22:33:00 +0200 |
|---|---|---|
| committer | Stefan Boberg <[email protected]> | 2025-10-06 22:33:00 +0200 |
| commit | 1383dbdc563d90c170ab30ba622ee44e2e37e723 (patch) | |
| tree | 59777db60000fe2ab2334f05776fb9ded4ca41fb /src/zenhttp | |
| parent | Merge branch 'main' into sb/rpc-analysis (diff) | |
| parent | 5.7.6 (diff) | |
| download | zen-1383dbdc563d90c170ab30ba622ee44e2e37e723.tar.xz zen-1383dbdc563d90c170ab30ba622ee44e2e37e723.zip | |
Merge remote-tracking branch 'origin/main' into sb/rpc-analysis
Diffstat (limited to 'src/zenhttp')
32 files changed, 4576 insertions, 1206 deletions
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index 18568a21d..209276621 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -2,12 +2,14 @@ #include "zenhttp/auth/authmgr.h" +#include <zencore/basicfile.h> #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinaryvalidation.h> +#include <zencore/compactbinaryutil.h> #include <zencore/crypto.h> #include <zencore/filesystem.h> #include <zencore/logging.h> +#include <zencore/trace.h> #include <zenhttp/auth/oidc.h> #include <condition_variable> @@ -28,6 +30,8 @@ namespace details { const AesIV128Bit& IV, std::optional<std::string>& Reason) { + ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile"); + FileContents Result = ReadFile(Path); if (Result.ErrorCode) @@ -61,6 +65,8 @@ namespace details { const AesIV128Bit& IV, std::optional<std::string>& Reason) { + ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile"); + if (FileData.GetSize() == 0) { return; @@ -76,7 +82,7 @@ namespace details { return; } - WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize())); + TemporaryFile::SafeWriteFile(Path, EncryptedView); } } // namespace details @@ -99,11 +105,7 @@ public: virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final { - if (OpenIdProviderExist(Params.Name)) - { - ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); - return; - } + ZEN_TRACE_CPU("AuthMgr::AddOpenIdProvider"); if (Params.Name.empty()) { @@ -111,8 +113,26 @@ public: return; } - std::unique_ptr<OidcClient> Client = - std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + { + std::unique_lock _(m_ProviderMutex); + if (auto It = m_OpenIdProviders.find(std::string(Params.Name)); It != m_OpenIdProviders.end()) + { + OpenIdProvider& ExistingProvider = *It->second; + if (ExistingProvider.ClientId == Params.ClientId && ExistingProvider.Url == Params.Url) + { + ZEN_DEBUG("OpenID provider '{}' already exists", Params.Name); + return; + } + else + { + m_OpenIdProviders.erase(It); + m_OpenIdTokens.erase(std::string(Params.Name)); + ZEN_DEBUG("OpenID provider '{}' removed to allow add of new with same name", Params.Name); + } + } + } + + RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -146,6 +166,8 @@ public: virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final { + ZEN_TRACE_CPU("AuthMgr::AddOpenIdToken"); + if (Params.ProviderName.empty()) { ZEN_WARN("trying add OpenID token with invalid provider name"); @@ -208,29 +230,45 @@ public: } private: + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + RefPtr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + bool OpenIdProviderExist(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return m_OpenIdProviders.contains(std::string(ProviderName)); } - OidcClient& GetOpenIdClient(std::string_view ProviderName) + OpenIdProvider GetOpenIdProvider(std::string_view ProviderName) { std::unique_lock _(m_ProviderMutex); - return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + return *m_OpenIdProviders[std::string(ProviderName)]; } OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) { - if (OpenIdProviderExist(ProviderName) == false) + ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken"); + + RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; + if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; } - OidcClient& Client = GetOpenIdClient(ProviderName); - - return Client.RefreshToken(RefreshToken); + return Client->RefreshToken(RefreshToken); } void Shutdown() @@ -241,6 +279,7 @@ private: void LoadState() { + ZEN_TRACE_CPU("AuthMgrImpl::LoadState"); try { std::optional<std::string> Reason; @@ -258,15 +297,14 @@ private: return; } - const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); - - if (ValidationError != CbValidateError::None) + CbValidateError ValidationError; + if (CbObject AuthState = ValidateAndReadCompactBinaryObject(std::move(Buffer), ValidationError); + ValidationError != CbValidateError::None) { - ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + ZEN_WARN("load serialized state FAILED, reason '{}'", ToString(ValidationError)); return; } - - if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + else { for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) { @@ -295,7 +333,7 @@ private: } } } - catch (std::exception& Err) + catch (const std::exception& Err) { ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); @@ -313,6 +351,8 @@ private: void SaveState() { + ZEN_TRACE_CPU("AuthMgr::SaveState"); + try { CbObjectWriter AuthState; @@ -352,7 +392,7 @@ private: AuthState.EndArray(); } - std::filesystem::create_directories(m_Config.RootDirectory); + CreateDirectories(m_Config.RootDirectory); std::optional<std::string> Reason; @@ -367,7 +407,7 @@ private: ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); } } - catch (std::exception& Err) + catch (const std::exception& Err) { ZEN_WARN("serialize state FAILED, reason '{}'", Err.what()); } @@ -474,22 +514,6 @@ private: } }; - struct OpenIdProvider - { - std::string Name; - std::string Url; - std::string ClientId; - std::unique_ptr<OidcClient> HttpClient; - }; - - struct OpenIdToken - { - std::string IdentityToken; - std::string RefreshToken; - std::string AccessToken; - TimePoint ExpireTime{}; - }; - using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp index 6ed587770..f89ca91da 100644 --- a/src/zenhttp/auth/authservice.cpp +++ b/src/zenhttp/auth/authservice.cpp @@ -56,7 +56,7 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) if (Ok) { - ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest); + ServerRequest.WriteResponse(HttpResponseCode::OK); } else { diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 318110c7d..38e7586ad 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "zenhttp/auth/oidc.h" +#include <zenhttp/httpclient.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/cpr.h> #include <fmt/format.h> #include <json11.hpp> ZEN_THIRD_PARTY_INCLUDES_END @@ -41,27 +41,21 @@ OidcClient::OidcClient(const OidcClient::Options& Options) OidcClient::InitResult OidcClient::Initialize() { - ExtendableStringBuilder<256> Uri; - Uri << m_BaseUrl << "/.well-known/openid-configuration"sv; + HttpClient Http{m_BaseUrl}; + HttpClient::Response Response = Http.Get("/.well-known/openid-configuration"sv); - cpr::Session Session; - - Session.SetOption(cpr::Url{Uri.c_str()}); - - cpr::Response Response = Session.Get(); - - if (Response.error) + if (!Response) { - return {.Reason = std::move(Response.error.message)}; + return {.Reason = Response.ErrorMessage("")}; } - if (Response.status_code != 200) + if (Response.StatusCode != HttpResponseCode::OK) { - return {.Reason = std::move(Response.reason)}; + return {.Reason = std::string{ToString(Response.StatusCode)}}; } std::string JsonError; - json11::Json Json = json11::Json::parse(Response.text, JsonError); + json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError); if (JsonError.empty() == false) { @@ -89,26 +83,24 @@ OidcClient::RefreshToken(std::string_view RefreshToken) { const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); - cpr::Session Session; + HttpClient Http{m_Config.TokenEndpoint}; - Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()}); - Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}); - Session.SetBody(cpr::Body{Body.data(), Body.size()}); + HttpClient::KeyValueMap Headers{{"Content-Type", "application/x-www-form-urlencoded"}}; - cpr::Response Response = Session.Post(); + HttpClient::Response Response = Http.Post("", IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}), Headers); - if (Response.error) + if (!Response) { - return {.Reason = std::move(Response.error.message)}; + return {.Reason = std::string{Response.ErrorMessage("")}}; } - if (Response.status_code != 200) + if (Response.StatusCode != HttpResponseCode::OK) { - return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)}; + return {.Reason = fmt::format("{} ({})", ToString(Response.StatusCode), Response.AsText())}; } std::string JsonError; - json11::Json Json = json11::Json::parse(Response.text, JsonError); + json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError); if (JsonError.empty() == false) { diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp new file mode 100644 index 000000000..8e5136dff --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -0,0 +1,474 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcommon.h" + +#include <fmt/format.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory/memory.h> +#include <zencore/windows.h> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/basicfile.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <fcntl.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +namespace zen { + +using namespace std::literals; + +namespace detail { + + static std::atomic_uint32_t TempFileBaseIndex; + + TempPayloadFile::TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {} + TempPayloadFile::~TempPayloadFile() + { + ZEN_TRACE_CPU("TempPayloadFile::Close"); + try + { + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); + BOOL Success = CloseHandle(m_FileHandle); +#else + std::error_code Ec; + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); + if (Ec) + { + ZEN_WARN("Error reported on get file path from handle {} for temp payload unlink operation, reason '{}'", + m_FileHandle, + Ec.message()); + } + else + { + unlink(FilePath.c_str()); + } + int Fd = int(uintptr_t(m_FileHandle)); + bool Success = (close(Fd) == 0); +#endif + if (!Success) + { + ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); + } + + m_FileHandle = nullptr; + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what()); + } + } + + std::error_code TempPayloadFile::Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize) + { + ZEN_TRACE_CPU("TempPayloadFile::Open"); + ZEN_ASSERT(m_FileHandle == nullptr); + + std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) | + detail::TempFileBaseIndex.fetch_add(1); + + std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex); +#if ZEN_PLATFORM_WINDOWS + LPCWSTR lpFileName = FileName.c_str(); + const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE); + const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr; + const DWORD dwCreationDisposition = CREATE_ALWAYS; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + const HANDLE hTemplateFile = nullptr; + const HANDLE FileHandle = CreateFile(lpFileName, + dwDesiredAccess, + dwShareMode, + lpSecurityAttributes, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + return MakeErrorCodeFromLastError(); + } +#else // ZEN_PLATFORM_WINDOWS + int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC; + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + return MakeErrorCodeFromLastError(); + } + fchmod(Fd, 0666); + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif // ZEN_PLATFORM_WINDOWS + m_FileHandle = FileHandle; + + PrepareFileForScatteredWrite(m_FileHandle, FinalSize); + + return {}; + } + + std::error_code TempPayloadFile::Write(std::string_view DataString) + { + ZEN_TRACE_CPU("TempPayloadFile::Write"); + const uint8_t* DataPtr = (const uint8_t*)DataString.data(); + size_t DataSize = DataString.size(); + if (DataSize >= CacheBufferSize) + { + std::error_code Ec = Flush(); + if (Ec) + { + return Ec; + } + return AppendData(DataPtr, DataSize); + } + size_t CopySize = Min(DataSize, CacheBufferSize - m_CacheBufferOffset); + memcpy(&m_CacheBuffer[m_CacheBufferOffset], DataPtr, CopySize); + m_CacheBufferOffset += CopySize; + DataSize -= CopySize; + if (m_CacheBufferOffset == CacheBufferSize) + { + AppendData(m_CacheBuffer, CacheBufferSize); + if (DataSize > 0) + { + ZEN_ASSERT(DataSize < CacheBufferSize); + memcpy(m_CacheBuffer, DataPtr + CopySize, DataSize); + } + m_CacheBufferOffset = DataSize; + } + else + { + ZEN_ASSERT(DataSize == 0); + } + return {}; + } + + IoBuffer TempPayloadFile::DetachToIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::DetachToIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true); + Buffer.SetDeleteOnClose(true); + m_FileHandle = 0; + m_WriteOffset = 0; + return Buffer; + } + + IoBuffer TempPayloadFile::BorrowIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::BorrowIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, 0, m_WriteOffset); + return Buffer; + } + + void TempPayloadFile::ResetWritePos(uint64_t WriteOffset) + { + ZEN_TRACE_CPU("TempPayloadFile::ResetWritePos"); + Flush(); + m_WriteOffset = WriteOffset; + } + + std::error_code TempPayloadFile::Flush() + { + ZEN_TRACE_CPU("TempPayloadFile::Flush"); + if (m_CacheBufferOffset == 0) + { + return {}; + } + std::error_code Res = AppendData(m_CacheBuffer, m_CacheBufferOffset); + m_CacheBufferOffset = 0; + return Res; + } + + std::error_code TempPayloadFile::AppendData(const void* Data, uint64_t Size) + { + ZEN_TRACE_CPU("TempPayloadFile::AppendData"); + ZEN_ASSERT(m_FileHandle != nullptr); + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); + uint64_t NumberOfBytesWritten = 0; +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten); + } +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset); + bool Success = (BytesWritten > 0); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten); + } +#endif + + if (!Success) + { + return MakeErrorCodeFromLastError(); + } + + Size -= NumberOfBytesWritten; + m_WriteOffset += NumberOfBytesWritten; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten; + } + return {}; + } + + BufferedReadFileStream::BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize) + : m_FileHandle(FileHandle) + , m_FileSize(FileSize) + , m_FileEnd(FileOffset + FileSize) + , m_BufferSize(Min(BufferSize, FileSize)) + , m_FileOffset(FileOffset) + { + } + + BufferedReadFileStream::~BufferedReadFileStream() { Memory::Free(m_Buffer); } + void BufferedReadFileStream::Read(void* Data, uint64_t Size) + { + ZEN_ASSERT(Data != nullptr); + if (Size > m_BufferSize) + { + Read(Data, Size, m_FileOffset); + m_FileOffset += Size; + return; + } + uint8_t* WritePtr = ((uint8_t*)Data); + uint64_t Begin = m_FileOffset; + uint64_t End = m_FileOffset + Size; + ZEN_ASSERT(m_FileOffset >= m_BufferStart); + if (m_FileOffset < m_BufferEnd) + { + ZEN_ASSERT(m_Buffer != nullptr); + uint64_t Count = Min(m_BufferEnd, End) - m_FileOffset; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer + Begin - m_BufferStart, Count); + Begin += Count; + if (Begin == End) + { + m_FileOffset = End; + return; + } + } + if (End == m_FileEnd) + { + Read(WritePtr + Begin - m_FileOffset, End - Begin, Begin); + } + else + { + if (!m_Buffer) + { + m_BufferSize = Min(m_FileEnd - m_FileOffset, m_BufferSize); + m_Buffer = (uint8_t*)Memory::Alloc(gsl::narrow<size_t>(m_BufferSize)); + } + m_BufferStart = Begin; + m_BufferEnd = Min(Begin + m_BufferSize, m_FileEnd); + Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); + uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer, Count); + ZEN_ASSERT(Begin + Count == End); + } + m_FileOffset = End; + } + + void BufferedReadFileStream::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) + { + const uint64_t MaxChunkSize = 1u * 1024 * 1024; + std::error_code Ec; + ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec); + + if (Ec) + { + std::error_code DummyEc; + throw std::system_error( + Ec, + fmt::format("HttpClient::BufferedReadFileStream ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + FileOffset, + BytesToRead, + PathFromHandle(m_FileHandle, DummyEc).generic_string(), + m_FileSize)); + } + } + + CompositeBufferReadStream::CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize) + : m_Data(Data) + , m_BufferSize(BufferSize) + , m_SegmentIndex(0) + , m_BytesLeftInSegment(0) + { + } + uint64_t CompositeBufferReadStream::Read(void* Data, uint64_t Size) + { + uint64_t Result = 0; + uint8_t* WritePtr = (uint8_t*)Data; + while ((Size > 0) && (m_SegmentIndex < m_Data.GetSegments().size())) + { + if (m_BytesLeftInSegment == 0) + { + const SharedBuffer& Segment = m_Data.GetSegments()[m_SegmentIndex]; + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Segment.AsIoBuffer().GetFileReference(FileRef)) + { + m_SegmentDiskBuffer = std::make_unique<BufferedReadFileStream>(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + m_BufferSize); + } + else + { + m_SegmentMemoryBuffer = Segment.GetView(); + } + m_BytesLeftInSegment = Segment.GetSize(); + } + uint64_t BytesToRead = Min(m_BytesLeftInSegment, Size); + if (m_SegmentDiskBuffer) + { + m_SegmentDiskBuffer->Read(WritePtr, BytesToRead); + } + else + { + ZEN_ASSERT_SLOW(m_SegmentMemoryBuffer.GetSize() >= BytesToRead); + memcpy(WritePtr, m_SegmentMemoryBuffer.GetData(), BytesToRead); + m_SegmentMemoryBuffer.MidInline(BytesToRead); + } + WritePtr += BytesToRead; + Size -= BytesToRead; + Result += BytesToRead; + + m_BytesLeftInSegment -= BytesToRead; + if (m_BytesLeftInSegment == 0) + { + m_SegmentDiskBuffer.reset(); + m_SegmentMemoryBuffer.Reset(); + m_SegmentIndex++; + } + } + return Result; + } + +} // namespace detail + +} // namespace zen + +#if ZEN_WITH_TESTS +namespace zen { + +namespace testutil { + IoHash HashComposite(const CompositeBuffer& Payload) + { + IoHashStream Hasher; + const uint64_t PayloadSize = Payload.GetSize(); + std::vector<uint8_t> Buffer(64u * 1024u); + detail::CompositeBufferReadStream Stream(Payload, 137u * 1024u); + for (uint64_t Offset = 0; Offset < PayloadSize;) + { + uint64_t Count = Min(64u * 1024u, PayloadSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + }; + + IoHash HashFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize) + { + IoHashStream Hasher; + std::vector<uint8_t> Buffer(64u * 1024u); + detail::BufferedReadFileStream Stream(FileHandle, FileOffset, FileSize, 137u * 1024u); + for (uint64_t Offset = 0; Offset < FileSize;) + { + uint64_t Count = Min(64u * 1024u, FileSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + } + +} // namespace testutil + +TEST_CASE("BufferedReadFileStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer DiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer1"); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + CHECK(DiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer Partial(DiskBuffer, 37 * 1024, 512 * 1024); + CHECK(Partial.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(Partial), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer SmallDiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(63 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK(SmallDiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(SmallDiskBuffer), + testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); +} + +TEST_CASE("CompositeBufferReadStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer MemoryBuffer1 = CreateRandomBlob(64); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer1)))); + + IoBuffer MemoryBuffer2 = CreateRandomBlob(561 * 1024); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer2)))); + + IoBuffer DiskBuffer1 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(267 * 3 * 1024)), TmpDir.Path() / "diskbuffer1"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer1)))); + + IoBuffer DiskBuffer2 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(3 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer2)))); + + IoBuffer DiskBuffer3 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer3"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer3), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer3)))); + + CompositeBuffer Data(SharedBuffer(std::move(MemoryBuffer1)), + SharedBuffer(std::move(DiskBuffer1)), + SharedBuffer(std::move(DiskBuffer2)), + SharedBuffer(std::move(MemoryBuffer2)), + SharedBuffer(std::move(DiskBuffer3))); + CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h new file mode 100644 index 000000000..9060cde48 --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.h @@ -0,0 +1,147 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compositebuffer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpclient.h> + +namespace zen { + +using namespace std::literals; + +class HttpClientBase +{ +public: + HttpClientBase(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + virtual ~HttpClientBase() = 0; + + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}) = 0; + + LoggerRef Log() { return m_Log; } + std::string_view GetBaseUri() const { return m_BaseUri; } + std::string_view GetSessionId() const { return m_SessionId; } + + bool Authenticate(); + +protected: + LoggerRef m_Log; + std::string m_BaseUri; + std::string m_SessionId; + const HttpClientSettings m_ConnectionSettings; + + const std::optional<HttpClientAccessToken> GetAccessToken(); + + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +namespace detail { + + class TempPayloadFile + { + public: + TempPayloadFile(const TempPayloadFile&) = delete; + TempPayloadFile& operator=(const TempPayloadFile&) = delete; + + TempPayloadFile(); + ~TempPayloadFile(); + + std::error_code Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize); + std::error_code Write(std::string_view DataString); + IoBuffer DetachToIoBuffer(); + IoBuffer BorrowIoBuffer(); + inline uint64_t GetSize() const { return m_WriteOffset; } + void ResetWritePos(uint64_t WriteOffset); + + private: + std::error_code Flush(); + std::error_code AppendData(const void* Data, uint64_t Size); + + void* m_FileHandle; + std::uint64_t m_WriteOffset; + static constexpr uint64_t CacheBufferSize = 512u * 1024u; + uint8_t m_CacheBuffer[CacheBufferSize]; + std::uint64_t m_CacheBufferOffset = 0; + }; + + class BufferedReadFileStream + { + public: + BufferedReadFileStream(const BufferedReadFileStream&) = delete; + BufferedReadFileStream& operator=(const BufferedReadFileStream&) = delete; + + BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize); + ~BufferedReadFileStream(); + + void Read(void* Data, uint64_t Size); + + private: + void Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset); + + void* m_FileHandle = nullptr; + const uint64_t m_FileSize = 0; + const uint64_t m_FileEnd = 0; + uint64_t m_BufferSize = 0; + uint8_t* m_Buffer = nullptr; + uint64_t m_BufferStart = 0; + uint64_t m_BufferEnd = 0; + uint64_t m_FileOffset = 0; + }; + + class CompositeBufferReadStream + { + public: + CompositeBufferReadStream(const CompositeBufferReadStream&) = delete; + CompositeBufferReadStream& operator=(const CompositeBufferReadStream&) = delete; + + CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize); + uint64_t Read(void* Data, uint64_t Size); + + private: + const CompositeBuffer& m_Data; + const uint64_t m_BufferSize; + size_t m_SegmentIndex; + std::unique_ptr<BufferedReadFileStream> m_SegmentDiskBuffer; + MemoryView m_SegmentMemoryBuffer; + uint64_t m_BytesLeftInSegment; + }; + +} // namespace detail + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp new file mode 100644 index 000000000..568106887 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -0,0 +1,1035 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcpr.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/packageformat.h> + +namespace zen { + +HttpClientBase* +CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +{ + return new CprHttpClient(BaseUri, ConnectionSettings); +} + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +// If we want to support different HTTP client implementations then we'll need to make this more abstract + +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) + { + switch ((cpr::ErrorCode)m_Error) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + return ResponseClass::kHttpCantConnectError; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return ResponseClass::kHttpNoHost; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return ResponseClass::kHttpTimeout; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_CACERT_ERROR: + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// CPR helpers + +static cpr::Body +AsCprBody(const CbObject& Obj) +{ + return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); +} + +static cpr::Body +AsCprBody(const IoBuffer& Obj) +{ + return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); +} + +////////////////////////////////////////////////////////////////////////// + +static HttpClient::Response +ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) +{ + // This ends up doing a memcpy, would be good to get rid of it by streaming results + // into buffer directly + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); + + if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) + { + const HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; +} + +static HttpClient::Response +CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); + if (HttpResponse.error) + { + if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT && + HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED) + { + ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse); + } + + // Client side failure code + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed, + .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), + .ErrorMessage = HttpResponse.error.message}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; + } + else + { + return ResponseWithPayload( + SessionId, + HttpResponse, + WorkResponseCode, + Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size())); + } +} + +static bool +ShouldRetry(const cpr::Response& Response) +{ + switch (Response.error.code) + { + case cpr::ErrorCode::OK: + break; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch ((HttpResponseCode)Response.status_code) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +}; + +static bool +ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + + if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second); + if (!ExpectedContentSize.has_value()) + { + Response.error = + cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second)); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Response.error = cpr::Error( + /*CURLE_READ_ERROR*/ 26, + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second)); + return false; + } + } + + if (Response.status_code == (long)HttpResponseCode::PartialContent) + { + return true; + } + + if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, + fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString())); + return false; + } + } + } + + if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end()) + { + if (ContentType->second == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize)) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation"); + return false; + } + } + if (ContentType->second == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error))); + return false; + } + } + } + + return true; +} + +static cpr::Response +DoWithRetry( + std::string_view SessionId, + std::function<cpr::Response()>&& Func, + uint8_t RetryCount, + std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static cpr::Response +DoWithRetry(std::string_view SessionId, + std::function<cpr::Response()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile, + uint8_t RetryCount) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) +: HttpClientBase(BaseUri, Connectionsettings) +{ +} + +CprHttpClient::~CprHttpClient() +{ + ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto CprSession : m_Sessions) + { + delete CprSession; + } + m_Sessions.clear(); + }); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::Session +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_TRACE_CPU("CprHttpClient::AllocSession"); + cpr::Session* CprSession = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + CprSession = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (CprSession == nullptr) + { + CprSession = new cpr::Session(); + CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); + CprSession->SetTimeout(ConnectionSettings.Timeout); + if (ConnectionSettings.AssumeHttp2) + { + CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); + } + } + + if (!AdditionalHeader->empty()) + { + CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); + } + if (!SessionId.empty()) + { + CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); + } + if (AccessToken) + { + CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + } + if (!Parameters->empty()) + { + cpr::Parameters Tmp; + for (auto It = Parameters->begin(); It != Parameters->end(); It++) + { + Tmp.Add({It->first, It->second}); + } + CprSession->SetParameters(Tmp); + } + else + { + CprSession->SetParameters({}); + } + + ExtendableStringBuilder<128> UrlBuffer; + UrlBuffer << BaseUrl << ResourcePath; + CprSession->SetUrl(UrlBuffer.c_str()); + + return Session(this, CprSession); +} + +void +CprHttpClient::ReleaseSession(cpr::Session* CprSession) +{ + ZEN_TRACE_CPU("CprHttpClient::ReleaseSession"); + CprSession->SetUrl({}); + CprSession->SetHeader({}); + CprSession->SetBody({}); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); +} + +CprHttpClient::Response +CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::TransactPackage"); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, + Url, + m_ConnectionSettings, + {{"Content-Length", "0"}}, + Parameters, + m_SessionId, + GetAccessToken()); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Get(); + }, + m_ConnectionSettings.RetryCount, + [](cpr::Response& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + })); +} + +CprHttpClient::Response +CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Head(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Delete(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { + size_t DelimiterPos = header.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string Key = header.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string Value = header.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + + return std::make_pair(Key, Value); + } + return std::make_pair(header, ""); + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + size_t RangeStartPos = RangeIt->second.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = + ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength = RequestedRangeEnd.value() - 1; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > 1024 * 1024) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile, + m_ConnectionSettings.RetryCount); + + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h new file mode 100644 index 000000000..ed9d10c27 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/cprutils.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/body.h> +#include <cpr/session.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CprHttpClient : public HttpClientBase +{ +public: + CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + ~CprHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct Session + { + Session(CprHttpClient* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {} + ~Session() { Outer->ReleaseSession(CprSession); } + + inline cpr::Session* operator->() const { return CprSession; } + inline cpr::Response Get() + { + ZEN_TRACE_CPU("HttpClient::Impl::Get"); + cpr::Response Result = CprSession->Get(); + ZEN_TRACE("GET {}", Result); + return Result; + } + inline cpr::Response Download(cpr::WriteCallback&& Write, std::optional<cpr::HeaderCallback>&& Header = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Download"); + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } + cpr::Response Result = CprSession->Download(Write); + ZEN_TRACE("GET {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); + return Result; + } + inline cpr::Response Head() + { + ZEN_TRACE_CPU("HttpClient::Impl::Head"); + cpr::Response Result = CprSession->Head(); + ZEN_TRACE("HEAD {}", Result); + return Result; + } + inline cpr::Response Put(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Put"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Put(); + ZEN_TRACE("PUT {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Post"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Post(); + ZEN_TRACE("POST {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Delete() + { + ZEN_TRACE_CPU("HttpClient::Impl::Delete"); + cpr::Response Result = CprSession->Delete(); + ZEN_TRACE("DELETE {}", Result); + return Result; + } + + LoggerRef Log() { return Outer->Log(); } + + private: + CprHttpClient* Outer; + cpr::Session* CprSession; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken); + + RwLock m_SessionLock; + std::vector<cpr::Session*> m_Sessions; + + void ReleaseSession(cpr::Session*); +}; + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index a29a08a3c..c5c808c23 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -1,849 +1,383 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include <zenhttp/formatters.h> #include <zenhttp/httpclient.h> #include <zenhttp/httpserver.h> +#include <zenhttp/packageformat.h> -#include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> #include <zencore/compositebuffer.h> #include <zencore/except.h> #include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> +#include <zencore/memory/memory.h> #include <zencore/session.h> #include <zencore/sharedbuffer.h> #include <zencore/stream.h> -#include <zencore/testing.h> +#include <zencore/string.h> #include <zencore/trace.h> -#include <zenhttp/formatters.h> -#include <zenutil/packageformat.h> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/cpr.h> -ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -# include <fcntl.h> -# include <sys/stat.h> -# include <unistd.h> -#endif +#include "clients/httpclientcommon.h" -static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS namespace zen { +extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings); + using namespace std::literals; ////////////////////////////////////////////////////////////////////////// -// -// CPR helpers -cpr::Body -AsCprBody(const CbObject& Obj) -{ - return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); -} - -cpr::Body -AsCprBody(const IoBuffer& Obj) +HttpClientBase::HttpClientBase(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +: m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) +, m_BaseUri(BaseUri) +, m_ConnectionSettings(ConnectionSettings) { - return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); + if (ConnectionSettings.SessionId == Oid::Zero) + { + m_SessionId = GetSessionIdString(); + } + else + { + m_SessionId = ConnectionSettings.SessionId.ToString(); + } } -cpr::Body -AsCprBody(const CompositeBuffer& Buffers) +HttpClientBase::~HttpClientBase() { - SharedBuffer Buffer = Buffers.Flatten(); - - // This is super inefficient, should be fixed - std::string String{(const char*)Buffer.GetData(), Buffer.GetSize()}; - return cpr::Body{std::move(String)}; } -////////////////////////////////////////////////////////////////////////// - -HttpClient::Response -ResponseWithPayload(cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) +bool +HttpClientBase::Authenticate() { - // This ends up doing a memcpy, would be good to get rid of it by streaming results - // into buffer directly - IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); - - if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) - { - const HttpContentType ContentType = ParseContentType(It->second); - - ResponseBuffer.SetContentType(ContentType); - } - - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + ZEN_TRACE_CPU("HttpClientBase::Authenticate"); + std::optional<HttpClientAccessToken> Token = GetAccessToken(); + if (!Token) { - ZEN_WARN("HttpClient request failed: {}", HttpResponse); + return false; } - - return HttpClient::Response{.StatusCode = WorkResponseCode, - .ResponsePayload = std::move(ResponseBuffer), - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + return Token->IsValid(); } -HttpClient::Response -CommonResponse(cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) +const std::optional<HttpClientAccessToken> +HttpClientBase::GetAccessToken() { - const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); - if (HttpResponse.error) + ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); + if (!m_ConnectionSettings.AccessTokenProvider.has_value()) { - ZEN_WARN("HttpClient client error: {}", HttpResponse); - - // Client side failure code - return HttpClient::Response{ - .StatusCode = WorkResponseCode, - .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; + return {}; } - - if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) { - return HttpClient::Response{.StatusCode = WorkResponseCode, - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + RwLock::SharedLockScope _(m_AccessTokenLock); + if (m_CachedAccessToken.IsValid()) + { + return m_CachedAccessToken; + } } - else + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (m_CachedAccessToken.IsValid()) { - return ResponseWithPayload(HttpResponse, WorkResponseCode, std::move(Payload)); + return m_CachedAccessToken; } + m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + return m_CachedAccessToken; } ////////////////////////////////////////////////////////////////////////// -struct HttpClient::Impl : public RefCounted +CbObject +HttpClient::Response::AsObject() const { - Impl(LoggerRef Log); - ~Impl(); - - // Session allocation - - struct Session + if (ResponsePayload) { - Session(Impl* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {} - ~Session() { Outer->ReleaseSession(CprSession); } - - inline cpr::Session* operator->() const { return CprSession; } - inline cpr::Response Get() - { - cpr::Response Result = CprSession->Get(); - ZEN_TRACE("GET {}", Result); - return Result; - } - inline cpr::Response Download(cpr::WriteCallback&& write) - { - cpr::Response Result = CprSession->Download(write); - ZEN_TRACE("GET {}", Result); - return Result; - } - inline cpr::Response Head() - { - cpr::Response Result = CprSession->Head(); - ZEN_TRACE("HEAD {}", Result); - return Result; - } - inline cpr::Response Put() + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(IoBuffer(ResponsePayload), ValidationError); + ValidationError == CbValidateError::None) { - cpr::Response Result = CprSession->Put(); - ZEN_TRACE("PUT {}", Result); - return Result; + return ResponseObject; } - inline cpr::Response Post() - { - cpr::Response Result = CprSession->Post(); - ZEN_TRACE("POST {}", Result); - return Result; - } - inline cpr::Response Delete() - { - cpr::Response Result = CprSession->Delete(); - ZEN_TRACE("DELETE {}", Result); - return Result; - } - - LoggerRef Logger() { return Outer->Logger(); } - - private: - Impl* Outer; - cpr::Session* CprSession; - - Session(Session&&) = delete; - Session& operator=(Session&&) = delete; - }; - - Session AllocSession(const std::string_view BaseUrl, - const std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters); - - LoggerRef Logger() { return m_Log; } - -private: - LoggerRef m_Log; - RwLock m_SessionLock; - std::vector<cpr::Session*> m_Sessions; - - void ReleaseSession(cpr::Session*); -}; + } -HttpClient::Impl::Impl(LoggerRef Log) : m_Log(Log) -{ + return {}; } -HttpClient::Impl::~Impl() +CbPackage +HttpClient::Response::AsPackage() const { - m_SessionLock.WithExclusiveLock([&] { - for (auto CprSession : m_Sessions) - { - delete CprSession; - } - m_Sessions.clear(); - }); + // TODO: sanity checks and error handling + if (ResponsePayload) + { + return ParsePackageMessage(ResponsePayload); + } + + return {}; } -HttpClient::Impl::Session -HttpClient::Impl::AllocSession(const std::string_view BaseUrl, - const std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters) +std::string_view +HttpClient::Response::AsText() const { - bool IsNew = false; - cpr::Session* CprSession = nullptr; - m_SessionLock.WithExclusiveLock([&] { - if (m_Sessions.empty()) - { - CprSession = new cpr::Session(); - IsNew = true; - } - else - { - CprSession = m_Sessions.back(); - m_Sessions.pop_back(); - } - }); - - if (IsNew) + if (ResponsePayload) { - CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); - CprSession->SetTimeout(ConnectionSettings.Timeout); - if (ConnectionSettings.AssumeHttp2) - { - CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); - } + return std::string_view(reinterpret_cast<const char*>(ResponsePayload.GetData()), ResponsePayload.GetSize()); } - if (!AdditionalHeader->empty()) - { - CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); - } - else - { - CprSession->SetHeader({}); - } - if (!Parameters->empty()) - { - cpr::Parameters Tmp; - for (auto It = Parameters->begin(); It != Parameters->end(); It++) - { - Tmp.Add({It->first, It->second}); - } - CprSession->SetParameters(Tmp); - } - else + return {}; +} + +std::string +HttpClient::Response::ToText() const +{ + if (!ResponsePayload) + return {}; + + switch (ResponsePayload.GetContentType()) { - CprSession->SetParameters({}); - } + case ZenContentType::kCbObject: + { + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CbObject Object{SharedBuffer(ResponsePayload)}; + zen::CompactBinaryToJson(Object, ObjStr); + return ObjStr.ToString(); + } + break; - ExtendableStringBuilder<128> UrlBuffer; - UrlBuffer << BaseUrl << ResourcePath; - CprSession->SetUrl(UrlBuffer.c_str()); + case ZenContentType::kCSS: + case ZenContentType::kHTML: + case ZenContentType::kJavaScript: + case ZenContentType::kJSON: + case ZenContentType::kText: + case ZenContentType::kYAML: + return std::string{AsText()}; - return Session(this, CprSession); + default: + return "<unhandled content format>"; + } } -void -HttpClient::Impl::ReleaseSession(cpr::Session* CprSession) +bool +HttpClient::Response::IsSuccess() const noexcept { - CprSession->SetUrl({}); - CprSession->SetHeader({}); - CprSession->SetBody({}); - m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); + return !Error && IsHttpSuccessCode(StatusCode); } -namespace detail { - - static std::atomic_uint32_t TempFileBaseIndex; - -} // namespace detail - -class TempPayloadFile +std::string +HttpClient::Response::ErrorMessage(std::string_view Prefix) const { -public: - TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {} - ~TempPayloadFile() + if (Error.has_value()) { - try - { - if (m_FileHandle) - { -#if ZEN_PLATFORM_WINDOWS - // Mark file for deletion when final handle is closed - FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; - - SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); - BOOL Success = CloseHandle(m_FileHandle); -#else - std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle); - unlink(FilePath.c_str()); - int Fd = int(uintptr_t(m_FileHandle)); - bool Success = (close(Fd) == 0); -#endif - if (!Success) - { - ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); - } - - m_FileHandle = nullptr; - } - } - catch (std::exception& Ex) - { - ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what()); - } + return fmt::format("{}: {}", Prefix, Error->ErrorMessage); } - - std::error_code Open(const std::filesystem::path& TempFolderPath) + else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) { - ZEN_ASSERT(m_FileHandle == nullptr); - - std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) | - detail::TempFileBaseIndex.fetch_add(1); - - std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex); -#if ZEN_PLATFORM_WINDOWS - LPCWSTR lpFileName = FileName.c_str(); - const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE); - const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr; - const DWORD dwCreationDisposition = CREATE_ALWAYS; - const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; - const HANDLE hTemplateFile = nullptr; - const HANDLE FileHandle = CreateFile(lpFileName, - dwDesiredAccess, - dwShareMode, - lpSecurityAttributes, - dwCreationDisposition, - dwFlagsAndAttributes, - hTemplateFile); - - if (FileHandle == INVALID_HANDLE_VALUE) - { - return MakeErrorCodeFromLastError(); - } -#else // ZEN_PLATFORM_WINDOWS - int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC; - int Fd = open(FileName.c_str(), OpenFlags, 0666); - if (Fd < 0) - { - return MakeErrorCodeFromLastError(); - } - fchmod(Fd, 0666); - - void* FileHandle = (void*)(uintptr_t(Fd)); -#endif // ZEN_PLATFORM_WINDOWS - m_FileHandle = FileHandle; - - return {}; + std::string TextResponse = ToText(); + return fmt::format("{}{}HTTP error {} {}{}", + Prefix, + Prefix.empty() ? ""sv : ": "sv, + (int)StatusCode, + zen::ToString(StatusCode), + TextResponse.empty() ? ""sv : fmt::format(" ({})", TextResponse)); } - - std::error_code Write(std::string_view DataString) + else { - ZEN_ASSERT(m_FileHandle != nullptr); - const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; - const void* Data = DataString.data(); - std::size_t Size = DataString.size(); - - while (Size) - { - const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); - uint64_t NumberOfBytesWritten = 0; -#if ZEN_PLATFORM_WINDOWS - OVERLAPPED Ovl{}; - - Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu); - Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32); - - DWORD dwNumberOfBytesWritten = 0; - - BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); - if (Success) - { - NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten); - } -#else - static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); - int Fd = int(uintptr_t(m_FileHandle)); - int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset); - bool Success = (BytesWritten > 0); - if (Success) - { - NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten); - } -#endif - - if (!Success) - { - return MakeErrorCodeFromLastError(); - } - - Size -= NumberOfBytesWritten; - m_WriteOffset += NumberOfBytesWritten; - Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten; - } - return {}; + return fmt::format("{}{}unknown error", Prefix, Prefix.empty() ? ""sv : ": "sv); } +} - IoBuffer DetachToIoBuffer() +void +HttpClient::Response::ThrowError(std::string_view ErrorPrefix) +{ + if (!IsSuccess()) { - ZEN_ASSERT(m_FileHandle != nullptr); - void* FileHandle = m_FileHandle; - IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true); - Buffer.SetDeleteOnClose(true); - m_FileHandle = 0; - m_WriteOffset = 0; - return Buffer; + throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode); } - -private: - void* m_FileHandle; - std::uint64_t m_WriteOffset; -}; +} ////////////////////////////////////////////////////////////////////////// -HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) -: m_Log(zen::logging::Get(Connectionsettings.LogCategory)) -, m_BaseUri(BaseUri) -, m_ConnectionSettings(Connectionsettings) -, m_Impl(new Impl(m_Log)) +HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +: m_BaseUri(BaseUri) +, m_ConnectionSettings(ConnectionSettings) { - StringBuilder<32> SessionId; - GetSessionId().ToString(SessionId); - m_SessionId = SessionId; + m_SessionId = GetSessionIdString(); + + m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings); } HttpClient::~HttpClient() { + delete m_Inner; } -HttpClient::Response -HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +void +HttpClient::SetSessionId(const Oid& SessionId) { - ZEN_TRACE_CPU("HttpClient::TransactPackage"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - // First, list of offered chunks for filtering on the server end - - std::vector<IoHash> AttachmentsToSend; - std::span<const CbAttachment> Attachments = Package.GetAttachments(); - - const uint32_t RequestId = ++HttpClientRequestIdCounter; - auto RequestIdString = fmt::to_string(RequestId); - - if (Attachments.empty() == false) - { - CbObjectWriter Writer; - Writer.BeginArray("offer"); - - for (const CbAttachment& Attachment : Attachments) - { - Writer.AddHash(Attachment.GetHash()); - } - - Writer.EndArray(); - - BinaryWriter MemWriter; - Writer.Save(MemWriter); - - Sess->UpdateHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); - Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (FilterResponse.status_code == 200) - { - IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); - CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer); - - for (CbFieldView& Entry : ResponseObject["need"]) - { - ZEN_ASSERT(Entry.IsHash()); - AttachmentsToSend.push_back(Entry.AsHash()); - } - } - } - - // Prepare package for send - - CbPackage SendPackage; - SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); - - for (const IoHash& AttachmentCid : AttachmentsToSend) - { - const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); - - if (Attachment) - { - SendPackage.AddAttachment(*Attachment); - } - else - { - // This should be an error -- server asked to have something we can't find - } - } - - // Transmit package payload - - CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); - SharedBuffer FlatMessage = Message.Flatten(); - - Sess->UpdateHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); - Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (!IsHttpSuccessCode(FilterResponse.status_code)) + if (SessionId == Oid::Zero) { - return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; + m_SessionId = GetSessionIdString(); } - - IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); - - if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + else { - HttpContentType ContentType = ParseContentType(It->second); - - ResponseBuffer.SetContentType(ContentType); + m_SessionId = SessionId.ToString(); } - - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; } -////////////////////////////////////////////////////////////////////////// -// -// Standard HTTP verbs -// - HttpClient::Response -HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::Put"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}}); - - return CommonResponse(Sess.Put()); + return m_Inner->Put(Url, Payload, AdditionalHeader); } HttpClient::Response -HttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +HttpClient::Put(std::string_view Url, const HttpClient::KeyValueMap& Parameters) { - ZEN_TRACE_CPU("HttpClient::Get"); - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters); - - return CommonResponse(Sess.Get()); + return m_Inner->Put(Url, Parameters); } HttpClient::Response -HttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +HttpClient::Get(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) { - ZEN_TRACE_CPU("HttpClient::Head"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - return CommonResponse(Sess.Head()); + return m_Inner->Get(Url, AdditionalHeader, Parameters); } HttpClient::Response -HttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +HttpClient::Head(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::Delete"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - return CommonResponse(Sess.Delete()); + return m_Inner->Head(Url, AdditionalHeader); } HttpClient::Response -HttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +HttpClient::Delete(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::PostNoPayload"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters); - - return CommonResponse(Sess.Post()); + return m_Inner->Delete(Url, AdditionalHeader); } HttpClient::Response -HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) { - ZEN_TRACE_CPU("HttpClient::PostWithPayload"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}}); - - return CommonResponse(Sess.Post()); + return m_Inner->Post(Url, AdditionalHeader, Parameters); } HttpClient::Response -HttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::PostObjectPayload"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ZenContentType::kCbObject))}}); - - return CommonResponse(Sess.Post()); + return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response -HttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::PostPackage"); - - CompositeBuffer Message = zen::FormatPackageMessageBuffer(Pkg); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - Sess->SetBody(AsCprBody(Message)); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ZenContentType::kCbPackage))}}); - - return CommonResponse(Sess.Post()); + return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); } HttpClient::Response -HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::Upload"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(Payload.GetContentType()))}}); - - uint64_t Offset = 0; - if (Payload.IsWholeFile()) - { - auto ReadCallback = [&Payload, &Offset](char* buffer, size_t& size, intptr_t) { - size = Min<size_t>(size, Payload.GetSize() - Offset); - IoBuffer PayloadRange = IoBuffer(Payload, Offset, size); - MutableMemoryView Data(buffer, size); - Data.CopyFrom(PayloadRange.GetView()); - Offset += size; - return true; - }; - Sess->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - } - else - { - Sess->SetBody(AsCprBody(Payload)); - } - return CommonResponse(Sess.Put()); + return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response -HttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, CbPackage Payload, const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::Upload"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - Sess->UpdateHeader(cpr::Header{{"Content-Type", std::string(MapContentTypeToString(ContentType))}}); - - uint64_t SizeLeft = Payload.GetSize(); - CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); - auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { - size = Min<size_t>(size, SizeLeft); - MutableMemoryView Data(buffer, size); - Payload.CopyTo(Data, BufferIt); - SizeLeft -= size; - return true; - }; - Sess->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - - return CommonResponse(Sess.Put()); + return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response -HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const HttpClient::KeyValueMap& AdditionalHeader) { - ZEN_TRACE_CPU("HttpClient::Download"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}); - - std::string PayloadString; - std::unique_ptr<TempPayloadFile> PayloadFile; - - cpr::Response Response = Sess.Download(cpr::WriteCallback{[&](std::string data, intptr_t) { - if (!PayloadFile && (PayloadString.length() + data.length()) > (1024 * 1024)) - { - PayloadFile = std::make_unique<TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", TempFolderPath.string(), Ec.message()); - return false; - } - PayloadFile->Write(PayloadString); - PayloadString.clear(); - } - if (PayloadFile) - { - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }}); - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - - return CommonResponse(std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); + return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); } -////////////////////////////////////////////////////////////////////////// - -CbObject -HttpClient::Response::AsObject() const +HttpClient::Response +HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { - // TODO: sanity check the payload format etc - - if (ResponsePayload) - { - return LoadCompactBinaryObject(ResponsePayload); - } - - return {}; + return m_Inner->Upload(Url, Payload, AdditionalHeader); } -CbPackage -HttpClient::Response::AsPackage() const +HttpClient::Response +HttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const HttpClient::KeyValueMap& AdditionalHeader) { - // TODO: sanity checks and error handling - if (ResponsePayload) - { - return ParsePackageMessage(ResponsePayload); - } - - return {}; + return m_Inner->Upload(Url, Payload, ContentType, AdditionalHeader); } -std::string_view -HttpClient::Response::AsText() const +HttpClient::Response +HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const HttpClient::KeyValueMap& AdditionalHeader) { - if (ResponsePayload) - { - return std::string_view(reinterpret_cast<const char*>(ResponsePayload.GetData()), ResponsePayload.GetSize()); - } - - return {}; + return m_Inner->Download(Url, TempFolderPath, AdditionalHeader); } -std::string -HttpClient::Response::ToText() const +HttpClient::Response +HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const HttpClient::KeyValueMap& AdditionalHeader) { - if (!ResponsePayload) - return {}; - - switch (ResponsePayload.GetContentType()) - { - case ZenContentType::kCbObject: - { - zen::ExtendableStringBuilder<1024> ObjStr; - zen::CbObject Object{SharedBuffer(ResponsePayload)}; - zen::CompactBinaryToJson(Object, ObjStr); - return ObjStr.ToString(); - } - break; - - case ZenContentType::kCSS: - case ZenContentType::kHTML: - case ZenContentType::kJavaScript: - case ZenContentType::kJSON: - case ZenContentType::kText: - case ZenContentType::kYAML: - return std::string{AsText()}; - - default: - return "<unhandled content format>"; - } + return m_Inner->TransactPackage(Url, Package, AdditionalHeader); } bool -HttpClient::Response::IsSuccess() const noexcept +HttpClient::Authenticate() { - return !Error && IsHttpSuccessCode(StatusCode); + return m_Inner->Authenticate(); } -std::string -HttpClient::Response::ErrorMessage(std::string_view Prefix) const +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("responseformat") { - if (Error.has_value()) - { - return fmt::format("{}: {}", Prefix, Error->ErrorMessage); - } - else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) + using namespace std::literals; + + SUBCASE("identity") { - return fmt::format("{}: HTTP error {} {} ({})", Prefix, (int)StatusCode, zen::ToString(StatusCode), AsText()); + BodyLogFormatter _{"abcd"}; + CHECK_EQ(_.GetText(), "abcd"sv); } - else + + SUBCASE("very long") { - return fmt::format("{}: {}", Prefix, "unknown error"); + std::string_view LongView = + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"; + + BodyLogFormatter _{LongView}; + + CHECK(_.GetText().size() < LongView.size()); + CHECK(_.GetText().starts_with("[truncated"sv)); } -} -void -HttpClient::Response::ThrowError(std::string_view ErrorPrefix) -{ - if (!IsSuccess()) + SUBCASE("invalid text") { - throw std::runtime_error(ErrorMessage(ErrorPrefix)); - } -} + std::string_view BadText = "totobaba\xff\xfe"; -////////////////////////////////////////////////////////////////////////// + BodyLogFormatter _{BadText}; -#if ZEN_WITH_TESTS + CHECK_EQ(_.GetText(), "totobaba"); + } +} TEST_CASE("httpclient") { diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp new file mode 100644 index 000000000..8754c57d6 --- /dev/null +++ b/src/zenhttp/httpclientauth.cpp @@ -0,0 +1,212 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclientauth.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> +#include <zencore/uid.h> +#include <zenhttp/auth/authmgr.h> +#include <zenhttp/httpclient.h> + +#include <ctime> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# define timegm _mkgmtime +#endif // ZEN_PLATFORM_WINDOWS + +namespace zen { namespace httpclientauth { + + using namespace std::literals; + + std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token) + { + return [Token]() { return Token; }; + } + + std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token) + { + return CreateFromStaticToken( + HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = HttpClientAccessToken::TimePoint::max()}); + } + + std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) + { + OAuthClientCredentialsParams OAuthParams(Params); + return [OAuthParams]() { + using namespace std::chrono; + + std::string Body = fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}"sv, + OAuthParams.ClientId, + OAuthParams.ClientSecret); + + HttpClient Http{OAuthParams.Url}; + + IoBuffer Payload{IoBuffer::Wrap, Body.data(), Body.size()}; + + // TODO: ensure this gets the right Content-Type passed along + + HttpClient::Response Response = Http.Post("", Payload, {{"Content-Type", "application/x-www-form-urlencoded"}}); + + if (!Response || Response.StatusCode != HttpResponseCode::OK) + { + ZEN_WARN("Failed fetching OAuth access token {}. Reason: '{}'", OAuthParams.Url, Response.ErrorMessage("")); + return HttpClientAccessToken{}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(std::string{Response.AsText()}, JsonError); + + if (JsonError.empty() == false) + { + ZEN_WARN("Unable to parse OAuth json response from {}. Reason: '{}'", OAuthParams.Url, JsonError); + return HttpClientAccessToken{}; + } + + std::string Token = Json["access_token"].string_value(); + int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); + HttpClientAccessToken::TimePoint ExpireTime = HttpClientAccessToken::Clock::now() + seconds(ExpiresInSeconds); + + return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + }; + } + + std::function<HttpClientAccessToken()> CreateFromOpenIdProvider(AuthMgr& AuthManager, std::string_view OpenIdProvider) + { + return [&AuthManager = AuthManager, OpenIdProvider = std::string(OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider); + return HttpClientAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }; + } + + std::function<HttpClientAccessToken()> CreateFromDefaultOpenIdProvider(AuthMgr& AuthManager) + { + return CreateFromOpenIdProvider(AuthManager, "Default"sv); + } + + static HttpClientAccessToken GetOidcTokenFromExe(const std::filesystem::path& OidcExecutablePath, + std::string_view CloudHost, + bool Unattended, + bool Quiet, + bool Hidden) + { + Stopwatch Timer; + + CreateProcOptions ProcOptions; + if (Quiet) + { + ProcOptions.StdoutFile = std::filesystem::temp_directory_path() / fmt::format(".zen-auth-output-{}", Oid::NewOid()); + } + if (Hidden) + { + ProcOptions.Flags |= CreateProcOptions::Flag_NoConsole; + } + + const std::filesystem::path AuthTokenPath(std::filesystem::temp_directory_path() / fmt::format(".zen-auth-{}", Oid::NewOid())); + auto _ = MakeGuard([AuthTokenPath, &ProcOptions]() { + RemoveFile(AuthTokenPath); + if (!ProcOptions.StdoutFile.empty()) + { + RemoveFile(ProcOptions.StdoutFile); + } + }); + + const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + OidcExecutablePath, + CloudHost, + AuthTokenPath, + Unattended ? "true"sv : "false"sv); + ZEN_DEBUG("Running: {}", ProcArgs); + ProcessHandle Proc; + Proc.Initialize(CreateProc(OidcExecutablePath, ProcArgs, ProcOptions)); + if (!Proc.IsValid()) + { + throw std::runtime_error(fmt::format("failed to launch '{}'", OidcExecutablePath)); + } + + int ExitCode = Proc.WaitExitCode(); + + auto EndTime = std::chrono::system_clock::now(); + + if (ExitCode == 0) + { + IoBuffer Body = IoBufferBuilder::MakeFromFile(AuthTokenPath); + std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + if (JsonError.empty() == false) + { + ZEN_WARN("Unable to parse Oidcs json response from {}. Reason: '{}'", AuthTokenPath, JsonError); + return HttpClientAccessToken{}; + } + std::string Token = Json["Token"].string_value(); + std::string ExpiresAtUTCString = Json["ExpiresAtUtc"].string_value(); + ZEN_ASSERT(!ExpiresAtUTCString.empty()); + + int Year = 0; + int Month = 0; + int Day = 0; + int Hour = 0; + int Minute = 0; + int Second = 0; + int Millisecond = 0; + sscanf(ExpiresAtUTCString.c_str(), "%d-%d-%dT%d:%d:%d.%dZ", &Year, &Month, &Day, &Hour, &Minute, &Second, &Millisecond); + + std::tm Time = { + Second, + Minute, + Hour, + Day, + Month - 1, + Year - 1900, + }; + + time_t UTCTime = timegm(&Time); + HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); + ExpireTime += std::chrono::microseconds(Millisecond); + + return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; + } + else + { + ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + } + return HttpClientAccessToken{}; + } + + std::optional<std::function<HttpClientAccessToken()>> CreateFromOidcTokenExecutable(const std::filesystem::path& OidcExecutablePath, + std::string_view CloudHost, + bool Quiet, + bool Unattended, + bool Hidden) + { + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + if (InitialToken.IsValid()) + { + return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), + CloudHost = std::string(CloudHost), + Quiet, + Hidden, + InitialToken]() mutable { + if (InitialToken.IsValid()) + { + HttpClientAccessToken Result = InitialToken; + InitialToken = {}; + return Result; + } + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, /* Unattended */ true, Quiet, Hidden); + }; + } + return {}; + } + +}} // namespace zen::httpclientauth diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 3270855ad..2c063d646 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -18,19 +18,22 @@ #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> #include <zencore/iobuffer.h> #include <zencore/logging.h> #include <zencore/stream.h> #include <zencore/string.h> #include <zencore/testing.h> #include <zencore/thread.h> -#include <zenutil/packageformat.h> +#include <zenhttp/packageformat.h> #include <charconv> #include <mutex> #include <span> #include <string_view> +#include <EASTL/fixed_vector.h> + namespace zen { using namespace std::literals; @@ -94,6 +97,7 @@ MapContentTypeToString(HttpContentType ContentType) static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); static constinit uint32_t HashJson = HashStringDjb2("json"sv); static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv); +static constinit uint32_t HashApplicationProblemJson = HashStringDjb2("application/problem+json"sv); static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv); static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv); static constinit uint32_t HashText = HashStringDjb2("text/plain"sv); @@ -132,6 +136,7 @@ struct HashedTypeEntry {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer}, {HashJson, HttpContentType::kJSON}, {HashApplicationJson, HttpContentType::kJSON}, + {HashApplicationProblemJson, HttpContentType::kJSON}, {HashYaml, HttpContentType::kYAML}, {HashTextYaml, HttpContentType::kYAML}, {HashText, HttpContentType::kText}, @@ -156,7 +161,14 @@ ParseContentTypeImpl(const std::string_view& ContentTypeString) { if (!ContentTypeString.empty()) { - const uint32_t CtHash = HashStringDjb2(ContentTypeString); + size_t ContentEnd = ContentTypeString.find(';'); + if (ContentEnd == std::string_view::npos) + { + ContentEnd = ContentTypeString.length(); + } + std::string_view ContentString(ContentTypeString.substr(0, ContentEnd)); + + const uint32_t CtHash = HashStringDjb2(ContentString); if (auto It = std::lower_bound(std::begin(TypeHashTable), std::end(TypeHashTable), @@ -468,6 +480,11 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView()); } + else if (m_AcceptType == HttpContentType::kYAML) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kYAML, Data.ToYaml(Sb).ToView()); + } else { SharedBuffer Buf = Data.GetBuffer(); @@ -484,6 +501,11 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array) ExtendableStringBuilder<1024> Sb; WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView()); } + else if (m_AcceptType == HttpContentType::kYAML) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kYAML, Array.ToYaml(Sb).ToView()); + } else { SharedBuffer Buf = Array.GetBuffer(); @@ -510,7 +532,7 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType { std::span<const SharedBuffer> Segments = Payload.GetSegments(); - std::vector<IoBuffer> Buffers; + eastl::fixed_vector<IoBuffer, 64> Buffers; Buffers.reserve(Segments.size()); for (auto& Segment : Segments) @@ -518,7 +540,41 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType Buffers.push_back(Segment.AsIoBuffer()); } - WriteResponse(ResponseCode, ContentType, Buffers); + WriteResponse(ResponseCode, ContentType, std::span<IoBuffer>(begin(Buffers), end(Buffers))); +} + +std::string +HttpServerRequest::Decode(std::string_view PercentEncodedString) +{ + size_t Length = PercentEncodedString.length(); + std::string Decoded; + Decoded.reserve(Length); + size_t Offset = 0; + while (Offset < Length) + { + char C = PercentEncodedString[Offset]; + if (C == '%' && (Offset <= (Length - 3))) + { + std::string_view CharHash(&PercentEncodedString[Offset + 1], 2); + uint8_t DecodedChar = 0; + if (ParseHexBytes(CharHash, &DecodedChar)) + { + Decoded.push_back((char)DecodedChar); + Offset += 3; + } + else + { + Decoded.push_back(C); + Offset++; + } + } + else + { + Decoded.push_back(C); + Offset++; + } + } + return Decoded; } HttpServerRequest::QueryParams @@ -610,9 +666,13 @@ HttpServerRequest::ReadPayloadObject() } return CbObject(); } - return LoadCompactBinaryObject(std::move(Payload)); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(Payload), ValidationError); + ValidationError == CbValidateError::None) + { + return ResponseObject; + } } - return {}; } @@ -732,120 +792,131 @@ HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcA ////////////////////////////////////////////////////////////////////////// -enum class HttpServerClass -{ - kHttpAsio, - kHttpSys, - kHttpPlugin, - kHttpMulti, - kHttpNull -}; - Ref<HttpServer> -CreateHttpServerClass(HttpServerClass Class, const HttpServerConfig& Config) +CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig& Config) { - switch (Class) + if (ServerClass == "asio"sv) { - default: - case HttpServerClass::kHttpAsio: - ZEN_INFO("using asio HTTP server implementation"); - return CreateHttpAsioServer(Config.ForceLoopback, Config.ThreadCount); - - case HttpServerClass::kHttpMulti: - { - ZEN_INFO("using multi HTTP server implementation"); - Ref<HttpMultiServer> Server{new HttpMultiServer()}; - - // This is hardcoded for now, but should be configurable in the future - Server->AddServer(CreateHttpServerClass(HttpServerClass::kHttpSys, Config)); - Server->AddServer(CreateHttpServerClass(HttpServerClass::kHttpPlugin, Config)); + ZEN_INFO("using asio HTTP server implementation") + return CreateHttpAsioServer(Config.ForceLoopback, Config.ThreadCount); + } +#if ZEN_WITH_HTTPSYS + else if (ServerClass == "httpsys"sv) + { + ZEN_INFO("using http.sys server implementation") + return Ref<HttpServer>(CreateHttpSysServer({.ThreadCount = Config.ThreadCount, + .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount, + .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, + .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled, + .IsDedicatedServer = Config.IsDedicatedServer, + .ForceLoopback = Config.ForceLoopback})); + } +#endif + else if (ServerClass == "null"sv) + { + ZEN_INFO("using null HTTP server implementation") + return Ref<HttpServer>(new HttpNullServer); + } + else + { + ZEN_WARN("unknown HTTP server implementation '{}', falling back to default", ServerClass) - return Server; - } +#if ZEN_WITH_HTTPSYS + return CreateHttpServerClass("httpsys"sv, Config); +#else + return CreateHttpServerClass("asio"sv, Config); +#endif + } +} #if ZEN_WITH_PLUGINS - case HttpServerClass::kHttpPlugin: - { - ZEN_INFO("using plugin HTTP server implementation"); - Ref<HttpPluginServer> Server{CreateHttpPluginServer()}; +Ref<HttpServer> +CreateHttpServerPlugin(const HttpServerPluginConfig& PluginConfig) +{ + const std::string& PluginName = PluginConfig.PluginName; - // This is hardcoded for now, but should be configurable in the future + ZEN_INFO("using '{}' plugin HTTP server implementation", PluginName) + if (PluginName.starts_with("builtin:"sv)) + { # if 0 - Ref<TransportPlugin> WinsockPlugin{CreateSocketTransportPlugin()}; - WinsockPlugin->Configure("port", "8558"); - Server->AddPlugin(WinsockPlugin); -# endif + Ref<TransportPlugin> Plugin = {}; + if (PluginName == "builtin:winsock"sv) + { + Plugin = CreateSocketTransportPlugin(); + } + else if (PluginName == "builtin:asio"sv) + { + Plugin = CreateAsioTransportPlugin(); + } + else + { + ZEN_WARN("Unknown builtin plugin '{}'", PluginName) + return {}; + } -# if 0 - Ref<TransportPlugin> AsioPlugin{CreateAsioTransportPlugin()}; - AsioPlugin->Configure("port", "8558"); - Server->AddPlugin(AsioPlugin); -# endif + ZEN_ASSERT(!Plugin.IsNull()); -# if 1 - Ref<DllTransportPlugin> DllPlugin{CreateDllTransportPlugin()}; - DllPlugin->LoadDll("winsock"); - DllPlugin->ConfigureDll("winsock", "port", "8558"); - Server->AddPlugin(DllPlugin); -# endif + for (const std::pair<std::string, std::string>& Option : PluginConfig.PluginOptions) + { + Plugin->Configure(Option.first.c_str(), Option.second.c_str()); + } - return Server; - } -#endif + Ref<HttpPluginServer> Server{CreateHttpPluginServer()}; + Server->AddPlugin(Plugin); + return Server; +# else + ZEN_WARN("Builtin plugin '{}' is not supported", PluginName) + return {}; +# endif + } -#if ZEN_WITH_HTTPSYS - case HttpServerClass::kHttpSys: - ZEN_INFO("using http.sys server implementation"); - return Ref<HttpServer>(CreateHttpSysServer({.ThreadCount = Config.ThreadCount, - .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount, - .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, - .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled, - .IsDedicatedServer = Config.IsDedicatedServer, - .ForceLoopback = Config.ForceLoopback})); -#endif + Ref<DllTransportPlugin> DllPlugin{CreateDllTransportPlugin()}; + if (!DllPlugin->LoadDll(PluginName)) + { + return {}; + } - case HttpServerClass::kHttpNull: - ZEN_INFO("using null HTTP server implementation"); - return Ref<HttpServer>(new HttpNullServer); + for (const std::pair<std::string, std::string>& Option : PluginConfig.PluginOptions) + { + DllPlugin->ConfigureDll(PluginName, Option.first.c_str(), Option.second.c_str()); } + + Ref<HttpPluginServer> Server{CreateHttpPluginServer()}; + Server->AddPlugin(DllPlugin); + return Server; } +#endif Ref<HttpServer> CreateHttpServer(const HttpServerConfig& Config) { using namespace std::literals; - HttpServerClass Class = HttpServerClass::kHttpNull; - -#if ZEN_WITH_HTTPSYS - Class = HttpServerClass::kHttpSys; -#else - Class = HttpServerClass::kHttpAsio; -#endif - - if (Config.ServerClass == "asio"sv) - { - Class = HttpServerClass::kHttpAsio; - } - else if (Config.ServerClass == "httpsys"sv) - { - Class = HttpServerClass::kHttpSys; - } - else if (Config.ServerClass == "plugin"sv) - { - Class = HttpServerClass::kHttpPlugin; - } - else if (Config.ServerClass == "null"sv) +#if ZEN_WITH_PLUGINS + if (Config.PluginConfigs.empty()) { - Class = HttpServerClass::kHttpNull; + return CreateHttpServerClass(Config.ServerClass, Config); } - else if (Config.ServerClass == "multi"sv) + else { - Class = HttpServerClass::kHttpMulti; - } + Ref<HttpMultiServer> Server{new HttpMultiServer()}; + Server->AddServer(CreateHttpServerClass(Config.ServerClass, Config)); - return CreateHttpServerClass(Class, Config); + for (const HttpServerPluginConfig& PluginConfig : Config.PluginConfigs) + { + Ref<HttpServer> PluginServer = CreateHttpServerPlugin(PluginConfig); + if (!PluginServer.IsNull()) + { + Server->AddServer(PluginServer); + } + } + + return Server; + } +#else + return CreateHttpServerClass(Config.ServerClass, Config); +#endif } ////////////////////////////////////////////////////////////////////////// @@ -865,42 +936,51 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP if (PackageHandlerRef) { - CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload()); - - std::vector<IoHash> OfferCids; - - for (auto& CidEntry : OfferMessage["offer"]) + CbValidateError ValidationError = CbValidateError::None; + if (CbObject OfferMessage = ValidateAndReadCompactBinaryObject(IoBuffer(Request.ReadPayload()), ValidationError); + ValidationError == CbValidateError::None) { - if (!CidEntry.IsHash()) + std::vector<IoHash> OfferCids; + + for (auto& CidEntry : OfferMessage["offer"]) { - // Should yield bad request response? + if (!CidEntry.IsHash()) + { + // Should yield bad request response? + + ZEN_WARN("found invalid entry in offer"); - ZEN_WARN("found invalid entry in offer"); + continue; + } - continue; + OfferCids.push_back(CidEntry.AsHash()); } - OfferCids.push_back(CidEntry.AsHash()); - } + ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); - ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); + PackageHandlerRef->FilterOffer(OfferCids); - PackageHandlerRef->FilterOffer(OfferCids); + ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); - ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); - CbObjectWriter ResponseWriter; - ResponseWriter.BeginArray("need"); + for (const IoHash& Cid : OfferCids) + { + ResponseWriter.AddHash(Cid); + } + + ResponseWriter.EndArray(); - for (const IoHash& Cid : OfferCids) + // Emit filter response + Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + else { - ResponseWriter.AddHash(Cid); + Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Invalid request payload: '{}'", ToString(ValidationError))); } - - ResponseWriter.EndArray(); - - // Emit filter response - Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); return true; } } diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h index f43ae3cd7..6f9c3198e 100644 --- a/src/zenhttp/include/zenhttp/auth/oidc.h +++ b/src/zenhttp/include/zenhttp/auth/oidc.h @@ -2,13 +2,14 @@ #pragma once +#include <zenbase/refcount.h> #include <zencore/string.h> #include <vector> namespace zen { -class OidcClient +class OidcClient : public RefCounted { public: struct Options diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h new file mode 100644 index 000000000..a3b870c0f --- /dev/null +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -0,0 +1,86 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> +#include <zenhttp/formatters.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/response.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +template<> +struct fmt::formatter<cpr::Response> +{ + constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); } + + template<typename FormatContext> + auto format(const cpr::Response& Response, FormatContext& Ctx) const -> decltype(Ctx.out()) + { + using namespace std::literals; + + zen::NiceTimeSpanMs NiceResponseTime(uint64_t(Response.elapsed * 1000)); + + if (zen::IsHttpSuccessCode(Response.status_code)) + { + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Error: '{}' ({}), Bytes: {}/{} (Up/Down), Elapsed: {}", + Response.url.str(), + Response.status_code, + Response.error.message, + int(Response.error.code), + Response.uploaded_bytes, + Response.downloaded_bytes, + NiceResponseTime.c_str()); + } + else + { + const auto It = Response.header.find("Content-Type"); + const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv; + + if (ContentType == "application/x-ue-cb"sv) + { + zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + zen::CbObjectView Obj(Body.Data()); + zen::ExtendableStringBuilder<256> Sb; + std::string_view Json = Obj.ToJson(Sb).ToView(); + + return fmt::format_to( + Ctx.out(), + "Url: {}, Status: {}, Error: '{}' ({}). Bytes: {}/{} (Up/Down), Elapsed: {}, Response: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.error.message, + int(Response.error.code), + Response.uploaded_bytes, + Response.downloaded_bytes, + NiceResponseTime.c_str(), + Json, + Response.reason); + } + else + { + zen::BodyLogFormatter Body(Response.text); + + return fmt::format_to( + Ctx.out(), + "Url: {}, Status: {}, Error: '{}' ({}), Bytes: {}/{} (Up/Down), Elapsed: {}, Response: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.error.message, + int(Response.error.code), + Response.uploaded_bytes, + Response.downloaded_bytes, + NiceResponseTime.c_str(), + Body.GetText(), + Response.reason); + } + } + } +}; diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index d45f5fbb2..0af31fa30 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -7,77 +7,63 @@ #include <zencore/iobuffer.h> #include <zencore/string.h> #include <zenhttp/httpclient.h> +#include <zenhttp/httpcommon.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/cpr.h> #include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END -template<> -struct fmt::formatter<cpr::Response> +namespace zen { + +struct BodyLogFormatter { - constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); } +private: + std::string_view ResponseText; + zen::ExtendableStringBuilder<128> ModifiedResponse; - template<typename FormatContext> - auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out()) +public: + explicit BodyLogFormatter(std::string_view InResponseText) : ResponseText(InResponseText) { using namespace std::literals; - if (Response.status_code == 200 || Response.status_code == 201) + const int TextSizeLimit = 1024; + + // Trim invalid UTF8 + + auto InvalidIt = zen::FindFirstInvalidUtf8Byte(ResponseText); + + if (InvalidIt != end(ResponseText)) { - return fmt::format_to(Ctx.out(), - "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", - Response.url.str(), - Response.status_code, - Response.uploaded_bytes, - Response.downloaded_bytes, - Response.elapsed); + ResponseText = ResponseText.substr(0, InvalidIt - begin(ResponseText)); } - else + + if (ResponseText.empty()) + { + ResponseText = "<suppressed non-text response>"sv; + } + + if (ResponseText.size() > TextSizeLimit) { - const auto It = Response.header.find("Content-Type"); - const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv; - - if (ContentType == "application/x-ue-cb"sv) - { - zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); - zen::CbObjectView Obj(Body.Data()); - zen::ExtendableStringBuilder<256> Sb; - std::string_view Json = Obj.ToJson(Sb).ToView(); - - return fmt::format_to(Ctx.out(), - "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'", - Response.url.str(), - Response.status_code, - Response.uploaded_bytes, - Response.downloaded_bytes, - Response.elapsed, - Json, - Response.reason); - } - else - { - return fmt::format_to(Ctx.out(), - "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'", - Response.url.str(), - Response.status_code, - Response.uploaded_bytes, - Response.downloaded_bytes, - Response.elapsed, - Response.text, - Response.reason); - } + const auto TruncatedString = "[truncated response] "sv; + ModifiedResponse.Append(TruncatedString); + ModifiedResponse.Append(ResponseText.data(), TextSizeLimit - TruncatedString.size()); + + ResponseText = ModifiedResponse; } } + + inline std::string_view GetText() const { return ResponseText; } }; +} // namespace zen + template<> struct fmt::formatter<zen::HttpClient::Response> { constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); } template<typename FormatContext> - auto format(const zen::HttpClient::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out()) + auto format(const zen::HttpClient::Response& Response, FormatContext& Ctx) const -> decltype(Ctx.out()) { using namespace std::literals; diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 9de5c7cce..c1fc1efa6 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -6,9 +6,11 @@ #include <zencore/iobuffer.h> #include <zencore/logbase.h> +#include <zencore/thread.h> #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <functional> #include <optional> #include <unordered_map> @@ -27,20 +29,91 @@ class CompositeBuffer; */ +struct HttpClientAccessToken +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + + static constexpr int64_t ExpireMarginInSeconds = 60 * 5; + + std::string Value; + TimePoint ExpireTime; + + bool IsValid() const + { + return Value.empty() == false && + ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + } +}; + struct HttpClientSettings { - std::string LogCategory = "httpclient"; - std::chrono::milliseconds ConnectTimeout{3000}; - std::chrono::milliseconds Timeout{}; - bool AssumeHttp2 = false; + std::string LogCategory = "httpclient"; + std::chrono::milliseconds ConnectTimeout{3000}; + std::chrono::milliseconds Timeout{}; + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + bool AssumeHttp2 = false; + bool AllowResume = false; + uint8_t RetryCount = 0; + Oid SessionId = Oid::Zero; }; -class HttpClient +class HttpClientError : public std::runtime_error { public: - struct Settings + using _Mybase = runtime_error; + + HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode) + : _Mybase(Message) + , m_Error(Error) + , m_ResponseCode(ResponseCode) { + } + + HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode) + : _Mybase(Message) + , m_Error(Error) + , m_ResponseCode(ResponseCode) + { + } + + inline int GetInternalErrorCode() const { return m_Error; } + inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; } + + enum class ResponseClass : std::int8_t + { + kSuccess = 0, + + kHttpOtherClientError = 80, + kHttpCantConnectError = 81, // CONNECTION_FAILURE + kHttpNotFound = 66, // NotFound(404) + kHttpUnauthorized = 77, // Unauthorized(401), + kHttpSLLError = + 82, // SSL_CONNECT_ERROR, SSL_LOCAL_CERTIFICATE_ERROR, SSL_REMOTE_CERTIFICATE_ERROR, SSL_CACERT_ERROR, GENERIC_SSL_ERROR + kHttpForbidden = 83, // Forbidden(403) + kHttpTimeout = 84, // NETWORK_RECEIVE_ERROR, NETWORK_SEND_FAILURE, OPERATION_TIMEDOUT, RequestTimeout(408) + kHttpConflict = 85, // Conflict(409) + kHttpNoHost = 86, // HOST_RESOLUTION_FAILURE, PROXY_RESOLUTION_FAILURE + + kHttpOtherServerError = 90, + kHttpInternalServerError = 91, // InternalServerError(500) + kHttpServiceUnavailable = 69, // ServiceUnavailable(503) + kHttpBadGateway = 92, // BadGateway(502) + kHttpGatewayTimeout = 93, // GatewayTimeout(504) }; + + ResponseClass GetResponseClass() const; + +private: + const int m_Error = 0; + const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; +}; + +class HttpClientBase; + +class HttpClient +{ +public: HttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); ~HttpClient(); @@ -78,7 +151,7 @@ public: HttpResponseCode StatusCode = HttpResponseCode::ImATeapot; IoBuffer ResponsePayload; // Note: this also includes the content type - // Contains the reponse headers + // Contains the response headers KeyValueMap Header; // The number of bytes sent as part of the request @@ -121,41 +194,54 @@ public: std::string ErrorMessage(std::string_view Prefix) const; }; + static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType) + { + return std::make_pair("Accept", MapContentTypeToString(ContentType)); + } + [[nodiscard]] Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Put(std::string_view Url, const KeyValueMap& Parameters = {}); [[nodiscard]] Response Get(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, const KeyValueMap& Parameters = {}); [[nodiscard]] Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Post(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, const KeyValueMap& Parameters = {}); [[nodiscard]] Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}); - static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType) - { - return std::make_pair("Accept", MapContentTypeToString(ContentType)); - } - - LoggerRef Logger() { return m_Log; } + LoggerRef Log() { return m_Log; } std::string_view GetBaseUri() const { return m_BaseUri; } + std::string_view GetSessionId() const { return m_SessionId; } + void SetSessionId(const Oid& SessionId); + + bool Authenticate(); private: - struct Impl; + HttpClientBase* m_Inner; LoggerRef m_Log; std::string m_BaseUri; std::string m_SessionId; const HttpClientSettings m_ConnectionSettings; - Ref<Impl> m_Impl; }; void httpclient_forcelink(); // internal diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h new file mode 100644 index 000000000..26f31ed2a --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpclient.h> +#include <optional> + +namespace zen { + +class AuthMgr; + +namespace httpclientauth { + std::function<HttpClientAccessToken()> CreateFromStaticToken(HttpClientAccessToken Token); + + std::function<HttpClientAccessToken()> CreateFromStaticToken(std::string_view Token); + + struct OAuthClientCredentialsParams + { + std::string_view Url; + std::string_view ClientId; + std::string_view ClientSecret; + }; + + std::function<HttpClientAccessToken()> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params); + + std::function<HttpClientAccessToken()> CreateFromOpenIdProvider(AuthMgr& AuthManager, std::string_view OpenIdProvider); + std::function<HttpClientAccessToken()> CreateFromDefaultOpenIdProvider(AuthMgr& AuthManager); + + std::optional<std::function<HttpClientAccessToken()>> CreateFromOidcTokenExecutable(const std::filesystem::path& OidcExecutablePath, + std::string_view CloudHost, + bool Quiet, + bool Unattended, + bool Hidden); +} // namespace httpclientauth + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 1089dd221..03e547bf3 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -62,6 +62,8 @@ public: } }; + static std::string Decode(std::string_view PercentEncodedString); + virtual bool TryGetRanges(HttpRanges&) { return false; } QueryParams GetQueryParams(); @@ -182,12 +184,19 @@ public: virtual void Close() = 0; }; +struct HttpServerPluginConfig +{ + std::string PluginName; + std::vector<std::pair<std::string, std::string>> PluginOptions; +}; + struct HttpServerConfig { - bool IsDedicatedServer = false; // Should be set to true for shared servers - std::string ServerClass; // Choice of HTTP server implementation - bool ForceLoopback = false; - unsigned int ThreadCount = 0; + bool IsDedicatedServer = false; // Should be set to true for shared servers + std::string ServerClass; // Choice of HTTP server implementation + std::vector<HttpServerPluginConfig> PluginConfigs; + bool ForceLoopback = false; + unsigned int ThreadCount = 0; struct { @@ -206,7 +215,7 @@ class HttpRouterRequest public: HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} - ZENCORE_API std::string GetCapture(uint32_t Index) const; + std::string_view GetCapture(uint32_t Index) const; inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: @@ -218,12 +227,14 @@ private: friend class HttpRequestRouter; }; -inline std::string +inline std::string_view HttpRouterRequest::GetCapture(uint32_t Index) const { ZEN_ASSERT(Index < m_Match.size()); - return m_Match[Index]; + const auto& Match = m_Match[Index]; + + return std::string_view(&*Match.first, Match.second - Match.first); } /** HTTP request router helper diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h new file mode 100644 index 000000000..c90b840da --- /dev/null +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -0,0 +1,164 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +class IoBuffer; +class CbPackage; +class CompositeBuffer; + +/** _____ _ _____ _ + / ____| | | __ \ | | + | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___ + | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \ + | |____| |_) | | | (_| | (__| < (_| | (_| | __/ + \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___| + __/ | + |___/ + + Structures and code related to handling CbPackage transactions + + CbPackage instances are marshaled across the wire using a distinct message + format. We don't use the CbPackage serialization format provided by the + CbPackage implementation itself since that does not provide much flexibility + in how the attachment payloads are transmitted. The scheme below separates + metadata cleanly from payloads and this enables us to more efficiently + transmit them either via sendfile/TransmitFile like mechanisms, or by + reference/memory mapping in the local case. + */ + +struct CbPackageHeader +{ + uint32_t HeaderMagic; + uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document? + uint32_t Reserved1; + uint32_t Reserved2; +}; + +static_assert(sizeof(CbPackageHeader) == 16); + +enum : uint32_t +{ + kCbPkgMagic = 0xaa77aacc +}; + +struct CbAttachmentEntry +{ + uint64_t PayloadSize; // Size of the associated payload data in the message + uint32_t Flags; // See flags below + IoHash AttachmentHash; // Content Id for the attachment + + enum + { + kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format + kIsObject = (1u << 1), // Is compact binary object + kIsError = (1u << 2), // Is error (compact binary formatted) object + kIsLocalRef = (1u << 3), // Is "local reference" + }; +}; + +struct CbAttachmentReferenceHeader +{ + uint64_t PayloadByteOffset = 0; + uint64_t PayloadByteSize = ~0u; + uint16_t AbsolutePathLength = 0; + + // This header will be followed by UTF8 encoded absolute path to backing file +}; + +static_assert(sizeof(CbAttachmentEntry) == 32); + +enum class FormatFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), + kDenyPartialLocalReferences = (1u << 1) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags); + +enum class RpcAcceptOptions : uint16_t +{ + kNone = 0, + kAllowLocalReferences = (1u << 0), + kAllowPartialLocalReferences = (1u << 1), + kAllowPartialCacheChunks = (1u << 2) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + return IoBuffer{Size}; + }); +bool IsPackageMessage(IoBuffer Payload); + +bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, void* TargetProcessHandle = nullptr); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* TargetProcessHandle = nullptr); + +/** Streaming reader for compact binary packages + + The goal is to ultimately support zero-copy I/O, but for now there'll be some + copying involved on some platforms at least. + + This approach to deserializing CbPackage data is more efficient than + `ParsePackageMessage` since it does not require the entire message to + be resident in a memory buffer + + */ +class CbPackageReader +{ +public: + CbPackageReader(); + ~CbPackageReader(); + + void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + + /** Process compact binary package data stream + + The data stream must be in the serialization format produced by FormatPackageMessage + + \return How many bytes must be fed to this function in the next call + */ + uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes); + + void Finalize(); + const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; } + CbObject GetRootObject() { return m_RootObject; } + std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; } + +private: + enum class State + { + kInitialState, + kReadingHeader, + kReadingAttachmentEntries, + kReadingBuffers + } m_CurrentState = State::kInitialState; + + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; + std::vector<IoBuffer> m_PayloadBuffers; + std::vector<CbAttachmentEntry> m_AttachmentEntries; + std::vector<CbAttachment> m_Attachments; + CbObject m_RootObject; + CbPackageHeader m_PackageHeader; + + IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer); +}; + +void forcelink_packageformat(); + +} // namespace zen diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp new file mode 100644 index 000000000..708238224 --- /dev/null +++ b/src/zenhttp/packageformat.cpp @@ -0,0 +1,936 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/packageformat.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compositebuffer.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zencore/trace.h> + +#include <span> +#include <vector> + +#include <EASTL/fixed_vector.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +const std::string_view HandlePrefix(":?#:"); + +typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t; + +IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle); + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, void* TargetProcessHandle) +{ + return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessHandle); +} +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, void* TargetProcessHandle) +{ + return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessHandle); +} + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle) +{ + auto Vec = FormatPackageMessageInternal(Data, Flags, TargetProcessHandle); + return std::vector<IoBuffer>(begin(Vec), end(Vec)); +} + +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle) +{ + auto Vec = FormatPackageMessageInternal(Data, Flags, TargetProcessHandle); + return CompositeBuffer(std::span{begin(Vec), end(Vec)}); +} + +static void +MarshalLocal(CbAttachmentEntry*& AttachmentInfo, + const std::string& Path8, + CbAttachmentReferenceHeader& LocalRef, + const IoHash& AttachmentHash, + bool IsCompressed, + IoBufferVec_t& ResponseBuffers) +{ + IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size()); + + CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>(); + *RefHdr++ = LocalRef; + memcpy(RefHdr, Path8.data(), Path8.size()); + + *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(), + .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef, + .AttachmentHash = AttachmentHash}; + + ResponseBuffers.emplace_back(std::move(RefBuffer)); +}; + +static bool +IsLocalRef(tsl::robin_map<void*, std::string>& FileNameMap, + std::vector<void*>& DuplicatedHandles, + const CompositeBuffer& AttachmentBinary, + bool DenyPartialLocalReferences, + void* TargetProcessHandle, + CbAttachmentReferenceHeader& LocalRef, + std::string& Path8) +{ + const SharedBuffer& Segment = AttachmentBinary.GetSegments().front(); + IoBufferFileReference Ref; + const IoBuffer& SegmentBuffer = Segment.AsIoBuffer(); + + if (!SegmentBuffer.GetFileReference(Ref)) + { + return false; + } + + if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile()) + { + return false; + } + + if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end()) + { + Path8 = It->second; + } + else + { + bool UseFilePath = true; +#if ZEN_PLATFORM_WINDOWS + if (TargetProcessHandle != nullptr) + { + HANDLE TargetHandle = INVALID_HANDLE_VALUE; + BOOL OK = ::DuplicateHandle(GetCurrentProcess(), + Ref.FileHandle, + (HANDLE)TargetProcessHandle, + &TargetHandle, + FILE_GENERIC_READ, + FALSE, + 0); + if (OK) + { + DuplicatedHandles.push_back((void*)TargetHandle); + Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle)); + UseFilePath = false; + } + } +#else // ZEN_PLATFORM_WINDOWS + ZEN_UNUSED(TargetProcessHandle); + ZEN_UNUSED(DuplicatedHandles); + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to + // deal with access rights etc. +#endif // ZEN_PLATFORM_WINDOWS + if (UseFilePath) + { + ExtendablePathBuilder<256> LocalRefFile; + std::error_code Ec; + std::filesystem::path FilePath = PathFromHandle(Ref.FileHandle, Ec); + if (Ec) + { + ZEN_WARN("Failed to get path for file handle {} in IsLocalRef check, reason '{}'", Ref.FileHandle, Ec.message()); + return false; + } + LocalRefFile.Append(std::filesystem::absolute(FilePath)); + Path8 = LocalRefFile.ToUtf8(); + } + FileNameMap.insert_or_assign(Ref.FileHandle, Path8); + } + + LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size()); + LocalRef.PayloadByteOffset = Ref.FileChunkOffset; + LocalRef.PayloadByteSize = Ref.FileChunkSize; + + return true; +}; + +IoBufferVec_t +FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle) +{ + ZEN_TRACE_CPU("FormatPackageMessage"); + + std::vector<void*> DuplicatedHandles; +#if ZEN_PLATFORM_WINDOWS + auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() { + if (TargetProcessHandle == nullptr) + { + return; + } + + for (void* DuplicatedHandle : DuplicatedHandles) + { + HANDLE ClosingHandle; + if (::DuplicateHandle((HANDLE)TargetProcessHandle, + (HANDLE)DuplicatedHandle, + GetCurrentProcess(), + &ClosingHandle, + 0, + FALSE, + DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE) + { + ::CloseHandle(ClosingHandle); + } + } + }); +#endif // ZEN_PLATFORM_WINDOWS + + const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); + IoBufferVec_t ResponseBuffers; + + ResponseBuffers.reserve(2 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each + // attachment is likely to consist of several buffers + + IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; + MutableMemoryView HeaderView = AttachmentMetadataBuffer.GetMutableView(); + // Fixed size header + + CbPackageHeader* Hdr = (CbPackageHeader*)HeaderView.GetData(); + *Hdr = {.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; + HeaderView.MidInline(sizeof(CbPackageHeader)); + + // Attachment metadata array + CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(HeaderView.GetData()); + ResponseBuffers.emplace_back(std::move(AttachmentMetadataBuffer)); // Attachment metadata + + // Root object + + IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); + ZEN_ASSERT(RootIoBuffer.GetSize() > 0); + *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()}; + ResponseBuffers.emplace_back(std::move(RootIoBuffer)); // Root object + + // Attachment payloads + tsl::robin_map<void*, std::string> FileNameMap; + + for (const CbAttachment& Attachment : Attachments) + { + if (Attachment.IsNull()) + { + ZEN_NOT_IMPLEMENTED("Null attachments are not supported"); + } + else if (const CompressedBuffer& AttachmentBuffer = Attachment.AsCompressedBinary()) + { + const CompositeBuffer& Compressed = AttachmentBuffer.GetCompressed(); + IoHash AttachmentHash = Attachment.GetHash(); + + // If the data is either not backed by a file, or there are multiple + // fragments then we cannot marshal it by local reference. We might + // want/need to extend this in the future to allow multiple chunk + // segments to be marshaled at once + + bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(FileNameMap, + DuplicatedHandles, + Compressed, + DenyPartialLocalReferences, + TargetProcessHandle, + LocalRef, + Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = true; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(AttachmentInfo, Path8, LocalRef, AttachmentHash, IsCompressed, ResponseBuffers); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(), + .Flags = CbAttachmentEntry::kIsCompressed, + .AttachmentHash = AttachmentHash}; + + std::span<const SharedBuffer> Segments = Compressed.GetSegments(); + ResponseBuffers.reserve(ResponseBuffers.size() + Segments.size() - 1); + for (const SharedBuffer& Segment : Segments) + { + ZEN_ASSERT(Segment.GetSize() > 0); + ResponseBuffers.emplace_back(Segment.AsIoBuffer()); + } + } + } + else if (CbObject AttachmentObject = Attachment.AsObject()) + { + IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer(); + ZEN_ASSERT(ObjIoBuffer.GetSize() > 0); + *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(), + .Flags = CbAttachmentEntry::kIsObject, + .AttachmentHash = Attachment.GetHash()}; + ResponseBuffers.emplace_back(std::move(ObjIoBuffer)); + } + else if (const CompositeBuffer& AttachmentBinary = Attachment.AsCompositeBinary()) + { + IoHash AttachmentHash = Attachment.GetHash(); + bool MarshalByLocalRef = + EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(FileNameMap, + DuplicatedHandles, + AttachmentBinary, + DenyPartialLocalReferences, + TargetProcessHandle, + LocalRef, + Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = false; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(AttachmentInfo, Path8, LocalRef, AttachmentHash, IsCompressed, ResponseBuffers); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()}; + + std::span<const SharedBuffer> Segments = AttachmentBinary.GetSegments(); + ResponseBuffers.reserve(ResponseBuffers.size() + Segments.size() - 1); + for (const SharedBuffer& Segment : Segments) + { + ZEN_ASSERT(Segment.GetSize() > 0); + ResponseBuffers.emplace_back(Segment.AsIoBuffer()); + } + } + } + else + { + ZEN_NOT_IMPLEMENTED("Unknown attachment kind"); + } + } + FileNameMap.clear(); +#if ZEN_PLATFORM_WINDOWS + DuplicatedHandles.clear(); +#endif // ZEN_PLATFORM_WINDOWS + + return ResponseBuffers; +} + +bool +IsPackageMessage(IoBuffer Payload) +{ + if (Payload.GetSize() < sizeof(CbPackageHeader)) + { + return false; + } + + BinaryReader Reader(Payload); + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + if (Hdr->HeaderMagic != kCbPkgMagic) + { + return false; + } + + return true; +} + +CbPackage +ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +{ + ZEN_TRACE_CPU("ParsePackageMessage"); + + if (Payload.GetSize() < sizeof(CbPackageHeader)) + { + throw std::invalid_argument(fmt::format("invalid CbPackage, missing complete header (size {})", Payload.GetSize())); + } + + BinaryReader Reader(Payload); + + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + if (Hdr->HeaderMagic != kCbPkgMagic) + { + throw std::invalid_argument( + fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic)); + } + Reader.Skip(sizeof(CbPackageHeader)); + + const uint32_t ChunkCount = Hdr->AttachmentCount + 1; + + if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount) + { + throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)", + sizeof(CbAttachmentEntry) * ChunkCount, + Reader.Remaining())); + } + const CbAttachmentEntry* AttachmentEntries = + reinterpret_cast<const CbAttachmentEntry*>(Reader.GetView(sizeof(CbAttachmentEntry) * ChunkCount).GetData()); + Reader.Skip(sizeof(CbAttachmentEntry) * ChunkCount); + + CbPackage Package; + + std::vector<CbAttachment> Attachments; + Attachments.reserve(ChunkCount); // Guessing here... + + tsl::robin_map<std::string, IoBuffer> PartialFileBuffers; + + std::vector<std::pair<uint32_t, std::string>> MalformedAttachments; + + for (uint32_t i = 0; i < ChunkCount; ++i) + { + const CbAttachmentEntry& Entry = AttachmentEntries[i]; + const uint64_t AttachmentSize = Entry.PayloadSize; + + if (Reader.Remaining() < AttachmentSize) + { + throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment data (need {} bytes, have {} bytes)", + AttachmentSize, + Reader.Remaining())); + } + const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize); + Reader.Skip(AttachmentSize); + + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + + IoBuffer FullFileBuffer; + + std::filesystem::path Path(Utf8ToWide(PathView)); + if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end()) + { + FullFileBuffer = It->second; + } + else + { + if (PathView.starts_with(HandlePrefix)) + { +#if ZEN_PLATFORM_WINDOWS + std::string_view HandleString(PathView.substr(HandlePrefix.length())); + std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString)); + if (HandleNumber.has_value()) + { + HANDLE FileHandle = HANDLE(HandleNumber.value()); + ULARGE_INTEGER liFileSize; + liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart); + if (liFileSize.LowPart != INVALID_FILE_SIZE) + { + FullFileBuffer = + IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart), /*IsWholeFile*/ true); + PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer); + } + } +#else // ZEN_PLATFORM_WINDOWS + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes + // and to deal with acceess rights etc. + ZEN_ASSERT(false); +#endif // ZEN_PLATFORM_WINDOWS + } + else + { + FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; + } + } + + if (FullFileBuffer) + { + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + ? FullFileBuffer + : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference))); + if (CompBuf) + { + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + else + { + MalformedAttachments.push_back(std::make_pair(i, + fmt::format("Invalid format in '{}' (offset {}, size {}) for {}", + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + Entry.AttachmentHash))); + } + } + else + { + MalformedAttachments.push_back(std::make_pair(i, + fmt::format("Unable to resolve chunk at '{}' (offset {}, size {}) for {}", + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + Entry.AttachmentHash))); + } + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + CbObject AttachmentObject; + + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (!CompBuf) + { + // First payload is always a compact binary object + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Invalid format, expected compressed buffer for CbObject (size {}) for {}", + AttachmentBuffer.GetSize(), + Entry.AttachmentHash))); + } + else + { + CbValidateError ValidationError = CbValidateError::None; + AttachmentObject = ValidateAndReadCompactBinaryObject(std::move(CompBuf), ValidationError); + if (ValidationError != CbValidateError::None) + { + MalformedAttachments.push_back(std::make_pair( + i, + fmt::format("Invalid format, CbObject for {}. Reason '{}'", Entry.AttachmentHash, ToString(ValidationError)))); + } + } + + if (i == 0) + { + // First payload is always a compact binary object + Package.SetObject(AttachmentObject); + } + else + { + Attachments.emplace_back(CbAttachment(AttachmentObject, Entry.AttachmentHash)); + } + } + else + { + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (CompBuf) + { + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + else + { + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Invalid format, expected compressed buffer for attachment (size {}) for {}", + AttachmentBuffer.GetSize(), + Entry.AttachmentHash))); + } + } + } + else /* not compressed */ + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + CbValidateError ValidationError = CbValidateError::None; + CbObject AttachmentObject = ValidateAndReadCompactBinaryObject(std::move(AttachmentBuffer), ValidationError); + if (ValidationError != CbValidateError::None) + { + MalformedAttachments.push_back(std::make_pair( + i, + fmt::format("Invalid format, CbObject for {}. Reason '{}'", Entry.AttachmentHash, ToString(ValidationError)))); + } + + if (i == 0) + { + Package.SetObject(AttachmentObject); + } + else + { + Attachments.emplace_back(CbAttachment(AttachmentObject, Entry.AttachmentHash)); + } + } + else if (AttachmentSize > 0) + { + // Make a copy of the buffer so the attachments don't reference the entire payload + IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + ZEN_ASSERT(AttachmentBufferCopy); + ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); + AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); + + Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + } + else + { + MalformedAttachments.push_back( + std::make_pair(i, fmt::format("Invalid format, attachment of size zero detected for {}", Entry.AttachmentHash))); + } + } + } + PartialFileBuffers.clear(); + + Package.AddAttachments(Attachments); + + using namespace std::literals; + + if (!MalformedAttachments.empty()) + { + ExtendableStringBuilder<1024> SB; + SB << (uint64_t)MalformedAttachments.size() << " malformed attachments in package message:\n"; + for (const auto& It : MalformedAttachments) + { + SB << " #"sv << It.first << ": " << It.second << "\n"; + } + ZEN_WARN("{}", SB.ToView()); + throw std::invalid_argument(SB.ToString()); + } + + return Package; +} + +bool +ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage) +{ + if (IsPackageMessage(Response)) + { + OutPackage = ParsePackageMessage(Response); + return true; + } + return OutPackage.TryLoad(Response); +} + +CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +{ +} + +CbPackageReader::~CbPackageReader() +{ +} + +void +CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer) +{ + m_CreateBuffer = CreateBuffer; +} + +uint64_t +CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) +{ + ZEN_ASSERT(m_CurrentState != State::kReadingBuffers); + + switch (m_CurrentState) + { + case State::kInitialState: + ZEN_ASSERT(Data == nullptr); + m_CurrentState = State::kReadingHeader; + return sizeof m_PackageHeader; + + case State::kReadingHeader: + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); + return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + + case State::kReadingAttachmentEntries: + ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); + memcpy(m_AttachmentEntries.data(), Data, DataBytes); + + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + // This preallocates memory for payloads but note that for the local references + // the caller will need to handle the payload differently (i.e it's a + // CbAttachmentReferenceHeader not the actual payload) + + m_PayloadBuffers.emplace_back(IoBuffer{Entry.PayloadSize}); + } + + m_CurrentState = State::kReadingBuffers; + return 0; + + default: + ZEN_ASSERT(false); + return 0; + } +} + +IoBuffer +CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) +{ + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + + std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + + std::filesystem::path Path{PathView}; + + IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + if (!ChunkReference) + { + // Unable to open chunk reference + + throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + + return ChunkReference; +}; + +void +CbPackageReader::Finalize() +{ + if (m_AttachmentEntries.empty()) + { + return; + } + + m_Attachments.reserve(m_AttachmentEntries.size() - 1); + + int CurrentAttachmentIndex = 0; + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + + if (CurrentAttachmentIndex == 0) + { + // Root object + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + CbValidateError ValidateError = CbValidateError::None; + m_RootObject = ValidateAndReadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer), ValidateError); + if (ValidateError != CbValidateError::None) + { + throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError))); + } + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + CbValidateError ValidateError = CbValidateError::None; + m_RootObject = ValidateAndReadCompactBinaryObject(std::move(Compressed), ValidateError); + if (ValidateError != CbValidateError::None) + { + throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError))); + } + } + } + else + { + CbValidateError ValidateError = CbValidateError::None; + m_RootObject = ValidateAndReadCompactBinaryObject(std::move(AttachmentBuffer), ValidateError); + if (ValidateError != CbValidateError::None) + { + throw std::runtime_error(fmt::format("Root object format is invalid, reason: '{}'", ToString(ValidateError))); + } + } + } + else + { + throw std::runtime_error("missing or invalid root object"); + } + } + else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer); + + if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + m_Attachments.emplace_back(CbAttachment(Compressed, Entry.AttachmentHash)); + } + } + else + { + CompressedBuffer Compressed = + CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None); + m_Attachments.emplace_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash())); + } + } + + ++CurrentAttachmentIndex; + } +} + +/** + ______________________ _____________________________ + \__ ___/\_ _____// _____/\__ ___/ _____/ + | | | __)_ \_____ \ | | \_____ \ + | | | \/ \ | | / \ + |____| /_______ /_______ / |____| /_______ / + \/ \/ \/ + */ + +#if ZEN_WITH_TESTS + +TEST_CASE("CbPackage.Serialization") +{ + // Make a test package + + CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))}; + CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.EmptyObject") +{ + CbPackage Pkg; + Pkg.SetObject({}); + std::vector<IoBuffer> Result = FormatPackageMessage(Pkg, nullptr); +} + +TEST_CASE("CbPackage.LocalRef") +{ + ScopedTemporaryDirectory TempDir; + + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + { + IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd")); + IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh")); + + WriteFile(Path1, Buffer1); + WriteFile(Path2, Buffer2); + } + + // Make a test package + + IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1); + IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2); + + CbAttachment Attach1{SharedBuffer(FileBuffer1)}; + CbAttachment Attach2{SharedBuffer(FileBuffer2)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +void +forcelink_packageformat() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 9fca314b3..2023b6d98 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -1,9 +1,11 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpasio.h" +#include "httptracer.h" #include <zencore/except.h> #include <zencore/logging.h> +#include <zencore/memory/llm.h> #include <zencore/thread.h> #include <zencore/trace.h> #include <zenhttp/httpserver.h> @@ -31,6 +33,18 @@ ZEN_THIRD_PARTY_INCLUDES_END # define ZEN_TRACE_VERBOSE(fmtstr, ...) #endif +namespace zen { + +const FLLMTag& +GetHttpasioTag() +{ + static FLLMTag _("httpasio"); + + return _; +} + +} // namespace zen + namespace zen::asio_http { using namespace std::literals; @@ -62,6 +76,7 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); + void Initialize(std::filesystem::path DataDir); int Start(uint16_t Port, bool ForceLooopback, int ThreadCount); void Stop(); void RegisterService(const char* UrlPath, HttpService& Service); @@ -72,6 +87,9 @@ public: std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; std::vector<std::thread> m_ThreadPool; + LoggerRef m_RequestLog; + HttpServerTracer m_RequestTracer; + struct ServiceEntry { std::string ServiceUrlPath; @@ -120,6 +138,8 @@ public: void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_TRACE_CPU("asio::InitializeForPayload"); m_ResponseCode = ResponseCode; @@ -168,8 +188,8 @@ public: } m_ContentLength = LocalDataSize; - auto Headers = GetHeaders(); - m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + std::string_view Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); } uint16_t ResponseCode() const { return m_ResponseCode; } @@ -179,6 +199,8 @@ public: std::string_view GetHeaders() { + ZEN_MEMSCOPE(GetHttpasioTag()); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; @@ -293,7 +315,9 @@ HttpServerConnection::TerminateConnection() void HttpServerConnection::EnqueueRead() { - if (m_RequestState == RequestState::kInitialRead) + ZEN_MEMSCOPE(GetHttpasioTag()); + + if ((m_RequestState == RequestState::kInitialRead) || (m_RequestState == RequestState::kReadingMore)) { m_RequestState = RequestState::kReadingMore; } @@ -313,17 +337,21 @@ HttpServerConnection::EnqueueRead() void HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { + ZEN_MEMSCOPE(GetHttpasioTag()); + if (Ec) { - if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + switch (m_RequestState) { - ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); - return; - } - else - { - ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); - return TerminateConnection(); + case RequestState::kDone: + case RequestState::kInitialRead: + case RequestState::kTerminated: + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); + return; + + default: + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + return TerminateConnection(); } } @@ -362,6 +390,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] void HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) { + ZEN_MEMSCOPE(GetHttpasioTag()); + if (Ec) { ZEN_WARN("on data sent ERROR, connection: {}, reason: '{}'", m_ConnectionId, Ec.message()); @@ -395,6 +425,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unu void HttpServerConnection::CloseConnection() { + ZEN_MEMSCOPE(GetHttpasioTag()); + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) { return; @@ -418,6 +450,8 @@ HttpServerConnection::CloseConnection() void HttpServerConnection::HandleRequest() { + ZEN_MEMSCOPE(GetHttpasioTag()); + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; @@ -439,9 +473,29 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); + const uint32_t RequestNumber = m_RequestCounter.load(std::memory_order_relaxed); + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); - ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); + + const HttpVerb RequestVerb = Request.RequestVerb(); + const std::string_view Uri = Request.RelativeUri(); + + if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace)) + { + ZEN_LOG_TRACE(m_Server.m_RequestLog, + "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", + m_ConnectionId, + ToString(RequestVerb), + Uri, + Request.ContentLength(), + ToString(Request.RequestContentType()), + ToString(Request.AcceptContentType())); + + m_Server.m_RequestTracer.WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber), + std::vector<IoBuffer>{Request.ReadPayload()}); + } if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { @@ -449,7 +503,15 @@ HttpServerConnection::HandleRequest() { Service->HandleRequest(Request); } - catch (std::system_error& SystemError) + catch (const AssertException& AssertEx) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); + } + catch (const std::system_error& SystemError) { // Drop any partially formatted response Request.m_Response.reset(); @@ -460,23 +522,25 @@ HttpServerConnection::HandleRequest() } else { - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); } } - catch (std::bad_alloc& BadAlloc) + catch (const std::bad_alloc& BadAlloc) { // Drop any partially formatted response Request.m_Response.reset(); Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); } - catch (std::exception& ex) + catch (const std::exception& ex) { // Drop any partially formatted response Request.m_Response.reset(); - ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); } } @@ -490,11 +554,11 @@ HttpServerConnection::HandleRequest() Response->SuppressPayload(); } - auto ResponseBuffers = Response->AsioBuffers(); + const std::vector<asio::const_buffer>& ResponseBuffers = Response->AsioBuffers(); uint64_t ResponseLength = 0; - for (auto& Buffer : ResponseBuffers) + for (const asio::const_buffer& Buffer : ResponseBuffers) { ResponseLength += Buffer.size(); } @@ -573,29 +637,49 @@ struct HttpAcceptor : m_Server(Server) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + , m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4()) { m_Acceptor.set_option(asio::ip::v6_only(false)); #if ZEN_PLATFORM_WINDOWS // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms - typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; - m_Acceptor.set_option(excluse_address(true)); + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address; + m_Acceptor.set_option(exclusive_address(true)); + m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); #else // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::socket_base::reuse_address(false)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false)); #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + asio::ip::address_v6 BindAddress = ForceLoopback ? asio::ip::address_v6::loopback() : asio::ip::address_v6::any(); uint16_t EffectivePort = BasePort; + if (BindAddress.is_loopback()) + { + m_Acceptor.set_option(asio::ip::v6_only(true)); + } + asio::error_code BindErrorCode; m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback()) { // Access denied for a public port - lets try fall back to local port only BindAddress = asio::ip::address_v6::loopback(); + m_Acceptor.set_option(asio::ip::v6_only(true)); + m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::address_in_use) + { + // Do a retry after a short sleep on same port just to be sure + ZEN_INFO("Desired port {} is in use, retrying", BasePort); + Sleep(100); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } // Sharing violation implies the port is being used by another process @@ -613,9 +697,20 @@ struct HttpAcceptor { ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); } - else if (BindAddress.is_loopback()) + else { - ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", "[::1]", EffectivePort); + if (EffectivePort != BasePort) + { + ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort); + } + if (BindAddress.is_loopback()) + { + m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode); + m_UseAlternateProtocolAcceptor = true; + ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", + "localhost", + EffectivePort); + } } #if ZEN_PLATFORM_WINDOWS @@ -635,31 +730,66 @@ struct HttpAcceptor &OptionNumberOfBytesReturned, 0, 0); + if (m_UseAlternateProtocolAcceptor) + { + NativeSocket = m_AlternateProtocolAcceptor.native_handle(); + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); + } #endif m_Acceptor.listen(); + if (m_UseAlternateProtocolAcceptor) + { + m_AlternateProtocolAcceptor.listen(); + } ZEN_INFO("Started asio server at 'http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort); } + ~HttpAcceptor() + { + m_Acceptor.close(); + if (m_UseAlternateProtocolAcceptor) + { + m_AlternateProtocolAcceptor.close(); + } + } + void Start() { - m_Acceptor.listen(); - InitAccept(); + ZEN_MEMSCOPE(GetHttpasioTag()); + + ZEN_ASSERT(!m_IsStopped); + InitAcceptInternal(m_Acceptor); + if (m_UseAlternateProtocolAcceptor) + { + InitAcceptInternal(m_AlternateProtocolAcceptor); + } } - void Stop() { m_IsStopped = true; } + void StopAccepting() { m_IsStopped = true; } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } - void InitAccept() +private: + void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor) { auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); - m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - m_Acceptor.local_endpoint().address().to_string(), - m_Acceptor.local_endpoint().port(), + Acceptor.local_endpoint().address().to_string(), + Acceptor.local_endpoint().port(), Ec.message()); } else @@ -679,12 +809,12 @@ struct HttpAcceptor if (!m_IsStopped.load()) { - InitAccept(); + InitAcceptInternal(Acceptor); } else { std::error_code CloseEc; - m_Acceptor.close(CloseEc); + Acceptor.close(CloseEc); if (CloseEc) { ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); @@ -693,12 +823,11 @@ struct HttpAcceptor }); } - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } - -private: HttpAsioServerImpl& m_Server; asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; + asio::ip::tcp::acceptor m_AlternateProtocolAcceptor; + bool m_UseAlternateProtocolAcceptor{false}; std::atomic<bool> m_IsStopped{false}; }; @@ -787,6 +916,8 @@ HttpAsioServerRequest::ReadPayload() void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(HttpContentType::kBinary)); @@ -798,6 +929,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType)); @@ -807,6 +940,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType)); @@ -819,6 +954,8 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT void HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_ASSERT(!m_Response); // Not one bit async, innit @@ -833,7 +970,7 @@ HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) ////////////////////////////////////////////////////////////////////////// -HttpAsioServerImpl::HttpAsioServerImpl() +HttpAsioServerImpl::HttpAsioServerImpl() : m_RequestLog(logging::Get("http_requests")) { } @@ -841,9 +978,17 @@ HttpAsioServerImpl::~HttpAsioServerImpl() { } +void +HttpAsioServerImpl::Initialize(std::filesystem::path DataDir) +{ + m_RequestTracer.Initialize(DataDir); +} + int HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount) { + ZEN_MEMSCOPE(GetHttpasioTag()); + ZEN_ASSERT(ThreadCount > 0); ZEN_INFO("starting asio http with {} service threads", ThreadCount); @@ -863,13 +1008,19 @@ HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount) for (int i = 0; i < ThreadCount; ++i) { m_ThreadPool.emplace_back([this, Index = i + 1] { + ZEN_MEMSCOPE(GetHttpasioTag()); + SetCurrentThreadName(fmt::format("asio_io_{}", Index)); try { m_IoService.run(); } - catch (std::exception& e) + catch (const AssertException& AssertEx) + { + ZEN_ERROR("Assert caught in asio event loop: {}", AssertEx.FullDescription()); + } + catch (const std::exception& e) { ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); } @@ -884,17 +1035,29 @@ HttpAsioServerImpl::Start(uint16_t Port, bool ForceLooopback, int ThreadCount) void HttpAsioServerImpl::Stop() { - m_Acceptor->Stop(); + ZEN_MEMSCOPE(GetHttpasioTag()); + + if (m_Acceptor) + { + m_Acceptor->StopAccepting(); + } m_IoService.stop(); for (auto& Thread : m_ThreadPool) { - Thread.join(); + if (Thread.joinable()) + { + Thread.join(); + } } + m_ThreadPool.clear(); + m_Acceptor.reset(); } void HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) { + ZEN_MEMSCOPE(GetHttpasioTag()); + std::string_view UrlPath(InUrlPath); Service.SetUriPrefixLength(UrlPath.size()); if (!UrlPath.empty() && UrlPath.back() == '/') @@ -909,6 +1072,8 @@ HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) HttpService* HttpAsioServerImpl::RouteRequest(std::string_view Url) { + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::SharedLockScope _(m_Lock); HttpService* CandidateService = nullptr; @@ -978,7 +1143,7 @@ HttpAsioServer::Close() { m_Impl->Stop(); } - catch (std::exception& ex) + catch (const std::exception& ex) { ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); } @@ -994,8 +1159,11 @@ HttpAsioServer::RegisterService(HttpService& Service) int HttpAsioServer::Initialize(int BasePort, std::filesystem::path DataDir) { - ZEN_UNUSED(DataDir); + ZEN_TRACE_CPU("HttpAsioServer::Initialize"); + m_Impl->Initialize(DataDir); + m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), m_ForceLoopback, m_ThreadCount); + return m_BasePort; } @@ -1052,6 +1220,9 @@ HttpAsioServer::RequestExit() Ref<HttpServer> CreateHttpAsioServer(bool ForceLoopback, unsigned int ThreadCount) { + ZEN_TRACE_CPU("CreateHttpAsioServer"); + ZEN_MEMSCOPE(GetHttpasioTag()); + return Ref<HttpServer>{new HttpAsioServer(ForceLoopback, ThreadCount)}; } diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 2a6a90d2e..b8b7931a9 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -3,6 +3,7 @@ #include "httpmulti.h" #include <zencore/logging.h> +#include <zencore/trace.h> #if ZEN_PLATFORM_WINDOWS # include <conio.h> @@ -30,6 +31,8 @@ HttpMultiServer::RegisterService(HttpService& Service) int HttpMultiServer::Initialize(int BasePort, std::filesystem::path DataDir) { + ZEN_TRACE_CPU("HttpMultiServer::Initialize"); + ZEN_UNUSED(DataDir); ZEN_ASSERT(!m_IsInitialized); @@ -103,6 +106,10 @@ HttpMultiServer::RequestExit() void HttpMultiServer::Close() { + for (auto& Server : m_Servers) + { + Server->Close(); + } } void diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index c64134c95..93094e21b 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -6,6 +6,8 @@ #include <zencore/logging.h> #include <zencore/string.h> +#include <limits> + namespace zen { using namespace std::literals; @@ -69,23 +71,21 @@ HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) int HttpRequestParser::OnUrl(const char* Data, size_t Bytes) { - if (!m_Url) + const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size(); + if (RemainingBufferSpace < Bytes) { - ZEN_ASSERT_SLOW(m_UrlLength == 0); - m_Url = m_HeaderCursor; + ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; } - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - - if (RemainingBufferSpace < Bytes) + if (m_UrlRange.Length == 0) { - ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + ZEN_ASSERT_SLOW(m_UrlRange.Offset == 0); + m_UrlRange.Offset = (uint32_t)m_HeaderData.size(); } - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_UrlLength += Bytes; + m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); + m_UrlRange.Length += (uint32_t)Bytes; return 0; } @@ -93,56 +93,70 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes) int HttpRequestParser::OnHeader(const char* Data, size_t Bytes) { - if (m_CurrentHeaderValueLength) + const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size(); + if (RemainingBufferSpace < Bytes) { - AppendCurrentHeader(); + ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = m_HeaderCursor; + if (m_HeaderEntries.empty()) + { + m_HeaderEntries.resize(1); } - else if (m_CurrentHeaderName == nullptr) + HeaderEntry* CurrentHeaderEntry = &m_HeaderEntries.back(); + if (CurrentHeaderEntry->ValueRange.Length) { - m_CurrentHeaderName = m_HeaderCursor; + ParseCurrentHeader(); + m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}}); + CurrentHeaderEntry = &m_HeaderEntries.back(); } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - if (RemainingBufferSpace < Bytes) + else if (CurrentHeaderEntry->NameRange.Length == 0) { - ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + m_HeaderEntries.emplace_back(HeaderEntry{.NameRange = {.Offset = (uint32_t)m_HeaderData.size()}}); + CurrentHeaderEntry = &m_HeaderEntries.back(); } - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderNameLength += Bytes; + m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); + CurrentHeaderEntry->NameRange.Length += (uint32_t)Bytes; return 0; } void -HttpRequestParser::AppendCurrentHeader() +HttpRequestParser::ParseCurrentHeader() { - std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); - std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); + const HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); + const size_t CurrentHeaderCount = m_HeaderEntries.size(); + const std::string_view HeaderName(GetHeaderSubString(CurrentHeaderEntry.NameRange)); + if (CurrentHeaderCount > std::numeric_limits<int8_t>::max()) + { + ZEN_WARN("HttpRequestParser parser only supports up to {} headers, can't store header '{}'. Dropping it.", + std::numeric_limits<int8_t>::max(), + HeaderName); + return; + } + const std::string_view HeaderValue(GetHeaderSubString(CurrentHeaderEntry.ValueRange)); - const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); if (HeaderHash == HashContentLength) { - m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + m_ContentLengthHeaderIndex = CurrentHeaderIndex; } else if (HeaderHash == HashAccept) { - m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + m_AcceptHeaderIndex = CurrentHeaderIndex; } else if (HeaderHash == HashContentType) { - m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + m_ContentTypeHeaderIndex = CurrentHeaderIndex; } else if (HeaderHash == HashSession) { - m_SessionId = Oid::FromHexString(HeaderValue); + m_SessionId = Oid::TryFromHexString(HeaderValue); } else if (HeaderHash == HashRequest) { @@ -162,38 +176,38 @@ HttpRequestParser::AppendCurrentHeader() } else if (HeaderHash == HashRange) { - m_RangeHeaderIndex = (int8_t)m_Headers.size(); + m_RangeHeaderIndex = CurrentHeaderIndex; } - - m_Headers.emplace_back(HeaderName, HeaderValue); } int HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) { - if (m_CurrentHeaderValueLength == 0) - { - m_CurrentHeaderValue = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + const size_t RemainingBufferSpace = std::numeric_limits<std::uint32_t>::max() - m_HeaderData.size(); if (RemainingBufferSpace < Bytes) { - ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); + ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); return 1; } - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderValueLength += Bytes; + ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); + HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); + if (CurrentHeaderEntry.ValueRange.Length == 0) + { + CurrentHeaderEntry.ValueRange.Offset = (uint32_t)m_HeaderData.size(); + } + m_HeaderData.insert(m_HeaderData.end(), Data, &Data[Bytes]); + CurrentHeaderEntry.ValueRange.Length += (uint32_t)Bytes; return 0; } static void -NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) +NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) { - bool LastCharWasSeparator = false; + bool LastCharWasSeparator = false; + const char* Url = InUrl.data(); + const size_t UrlLength = InUrl.length(); for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) { const char UrlChar = Url[UrlIndex]; @@ -226,9 +240,13 @@ HttpRequestParser::OnHeadersComplete() { try { - if (m_CurrentHeaderValueLength) + if (!m_HeaderEntries.empty()) { - AppendCurrentHeader(); + HeaderEntry& CurrentHeaderEntry = m_HeaderEntries.back(); + if (CurrentHeaderEntry.NameRange.Length) + { + ParseCurrentHeader(); + } } m_KeepAlive = !!http_should_keep_alive(&m_Parser); @@ -268,21 +286,21 @@ HttpRequestParser::OnHeadersComplete() break; } - std::string_view Url(m_Url, m_UrlLength); + std::string_view FullUrl(GetHeaderSubString(m_UrlRange)); - if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) + if (auto QuerySplit = FullUrl.find_first_of('?'); QuerySplit != std::string_view::npos) { - m_UrlLength = QuerySplit; - m_QueryString = m_Url + QuerySplit + 1; - m_QueryLength = Url.size() - QuerySplit - 1; + m_UrlRange.Length = uint32_t(QuerySplit); + m_QueryStringRange = {.Offset = uint32_t(m_UrlRange.Offset + QuerySplit + 1), + .Length = uint32_t(FullUrl.size() - QuerySplit - 1)}; } - NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); + NormalizeUrlPath(FullUrl, m_NormalizedUrl); - if (m_ContentLengthHeaderIndex >= 0) + std::string_view Value = GetHeaderValue(m_ContentLengthHeaderIndex); + if (!Value.empty()) { - std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; - uint64_t ContentLength = 0; + uint64_t ContentLength = 0; std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); if (ContentLength) @@ -330,16 +348,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) void HttpRequestParser::ResetState() { - m_HeaderCursor = m_HeaderBuffer; - m_CurrentHeaderName = nullptr; - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValue = nullptr; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = nullptr; - m_Url = nullptr; - m_UrlLength = 0; - m_QueryString = nullptr; - m_QueryLength = 0; + m_UrlRange = {}; + m_QueryStringRange = {}; + + m_HeaderEntries.clear(); + m_ContentLengthHeaderIndex = -1; m_AcceptHeaderIndex = -1; m_ContentTypeHeaderIndex = -1; @@ -347,7 +360,8 @@ HttpRequestParser::ResetState() m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; - m_Headers.clear(); + + m_HeaderData.clear(); m_NormalizedUrl.clear(); } @@ -366,7 +380,12 @@ HttpRequestParser::OnMessageComplete() ResetState(); return 0; } - catch (std::system_error& SystemError) + catch (const AssertException& AssertEx) + { + ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); + return 1; + } + catch (const std::system_error& SystemError) { if (IsOOM(SystemError.code())) { @@ -378,18 +397,18 @@ HttpRequestParser::OnMessageComplete() } else { - ZEN_ERROR("failed processing http request: '{}'", SystemError.what()); + ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); return 1; } - catch (std::bad_alloc& BadAlloc) + catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); return 1; } - catch (std::exception& Ex) + catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index bdbcab4d9..0d2664ec5 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -5,6 +5,8 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <EASTL/fixed_vector.h> + ZEN_THIRD_PARTY_INCLUDES_START #include <http_parser.h> ZEN_THIRD_PARTY_INCLUDES_END @@ -31,73 +33,68 @@ struct HttpRequestParser HttpVerb RequestVerb() const { return m_RequestVerb; } bool IsKeepAlive() const { return m_KeepAlive; } - std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } - std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } + std::string_view Url() const { return m_NormalizedUrl.empty() ? GetHeaderSubString(m_UrlRange) : m_NormalizedUrl; } + std::string_view QueryString() const { return GetHeaderSubString(m_QueryStringRange); } IoBuffer Body() { return m_BodyBuffer; } - inline HttpContentType ContentType() - { - if (m_ContentTypeHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); - } + inline HttpContentType ContentType() { return ParseContentType(GetHeaderValue(m_ContentTypeHeaderIndex)); } - inline HttpContentType AcceptType() - { - if (m_AcceptHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); - } + inline HttpContentType AcceptType() { return ParseContentType(GetHeaderValue(m_AcceptHeaderIndex)); } Oid SessionId() const { return m_SessionId; } int RequestId() const { return m_RequestId; } - std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } + std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); } private: + struct HeaderRange + { + uint32_t Offset = 0; + uint32_t Length = 0; + }; + struct HeaderEntry { - HeaderEntry() = default; + HeaderRange NameRange; + HeaderRange ValueRange; + }; - HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} + inline std::string_view GetHeaderValue(int8_t HeaderIndex) const + { + if (HeaderIndex == -1) + { + return {}; + } + ZEN_ASSERT(size_t(HeaderIndex) < m_HeaderEntries.size()); + return GetHeaderSubString(m_HeaderEntries[HeaderIndex].ValueRange); + } - std::string_view Name; - std::string_view Value; - }; + std::string_view GetHeaderSubString(const HeaderRange& Range) const + { + ZEN_ASSERT_SLOW(Range.Offset + Range.Length <= m_HeaderData.size()); + return std::string_view(m_HeaderData.begin(), m_HeaderData.size()).substr(Range.Offset, Range.Length); + } - HttpRequestParserCallbacks& m_Connection; - char* m_HeaderCursor = m_HeaderBuffer; - char* m_Url = nullptr; - size_t m_UrlLength = 0; - char* m_QueryString = nullptr; - size_t m_QueryLength = 0; - char* m_CurrentHeaderName = nullptr; // Used while parsing headers - size_t m_CurrentHeaderNameLength = 0; - char* m_CurrentHeaderValue = nullptr; // Used while parsing headers - size_t m_CurrentHeaderValueLength = 0; - std::vector<HeaderEntry> m_Headers; - int8_t m_ContentLengthHeaderIndex; - int8_t m_AcceptHeaderIndex; - int8_t m_ContentTypeHeaderIndex; - int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; - std::atomic_bool m_KeepAlive{false}; - bool m_Expect100Continue = false; - int m_RequestId = -1; - Oid m_SessionId{}; - IoBuffer m_BodyBuffer; - uint64_t m_BodyPosition = 0; - http_parser m_Parser; - char m_HeaderBuffer[1024]; - std::string m_NormalizedUrl; - - void AppendCurrentHeader(); + HttpRequestParserCallbacks& m_Connection; + HeaderRange m_UrlRange; + HeaderRange m_QueryStringRange; + eastl::fixed_vector<HeaderEntry, 16> m_HeaderEntries; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + std::atomic_bool m_KeepAlive{false}; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + eastl::fixed_vector<char, 512> m_HeaderData; + std::string m_NormalizedUrl; + + void ParseCurrentHeader(); int OnMessageBegin(); int OnUrl(const char* Data, size_t Bytes); diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 3eed9db8f..d6ca7e1c5 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -2,6 +2,8 @@ #include <zenhttp/httpplugin.h> +#include "httptracer.h" + #if ZEN_WITH_PLUGINS # include "httpparser.h" @@ -10,6 +12,7 @@ # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/logging.h> +# include <zencore/memory/llm.h> # include <zencore/scopeguard.h> # include <zencore/session.h> # include <zencore/thread.h> @@ -25,14 +28,6 @@ # include <conio.h> # endif -# define PLUGIN_VERBOSE_TRACE 1 - -# if PLUGIN_VERBOSE_TRACE -# define ZEN_TRACE_VERBOSE ZEN_TRACE -# else -# define ZEN_TRACE_VERBOSE(fmtstr, ...) -# endif - namespace zen { struct HttpPluginServerImpl; @@ -40,6 +35,14 @@ struct HttpPluginResponse; using namespace std::literals; +const FLLMTag& +GetHttppluginTag() +{ + static FLLMTag _("httpplugin"); + + return _; +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted @@ -103,8 +106,6 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* RouteRequest(std::string_view Url); - void WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload); - struct ServiceEntry { std::string ServiceUrlPath; @@ -119,8 +120,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer bool m_IsRequestLoggingEnabled = false; LoggerRef m_RequestLog; std::atomic_uint32_t m_ConnectionIdCounter{0}; - std::filesystem::path m_DataDir; // Application data directory - std::filesystem::path m_PayloadDir; // Request debugging payload directory + + HttpServerTracer m_RequestTracer; // TransportServer @@ -191,6 +192,7 @@ private: void HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) { + ZEN_MEMSCOPE(GetHttppluginTag()); ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); m_ResponseCode = ResponseCode; @@ -232,6 +234,8 @@ HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuff std::string_view HttpPluginResponse::GetHeaders() { + ZEN_MEMSCOPE(GetHttppluginTag()); + if (m_Headers.Size() == 0) { m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" @@ -266,6 +270,8 @@ HttpPluginConnectionHandler::~HttpPluginConnectionHandler() void HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server, uint32_t ConnectionId) { + ZEN_MEMSCOPE(GetHttppluginTag()); + m_TransportConnection = Transport; m_Server = &Server; m_ConnectionId = ConnectionId; @@ -298,6 +304,8 @@ HttpPluginConnectionHandler::Release() const void HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) { + ZEN_MEMSCOPE(GetHttppluginTag()); + ZEN_ASSERT(m_Server); ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} OnBytesRead: {}", m_ConnectionId, AvailableBytes); @@ -325,6 +333,8 @@ HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableByt void HttpPluginConnectionHandler::HandleRequest() { + ZEN_MEMSCOPE(GetHttppluginTag()); + ZEN_ASSERT(m_Server); const uint32_t RequestNumber = m_RequestCounter.fetch_add(1); @@ -376,8 +386,8 @@ HttpPluginConnectionHandler::HandleRequest() ToString(Request.RequestContentType()), ToString(Request.AcceptContentType())); - m_Server->WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber), - std::vector<IoBuffer>{Request.ReadPayload()}); + m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("request_{}_{}.bin", m_ConnectionId, RequestNumber), + std::vector<IoBuffer>{Request.ReadPayload()}); } if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) @@ -386,7 +396,15 @@ HttpPluginConnectionHandler::HandleRequest() { Service->HandleRequest(Request); } - catch (std::system_error& SystemError) + catch (const AssertException& AssertEx) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); + } + catch (const std::system_error& SystemError) { // Drop any partially formatted response Request.m_Response.reset(); @@ -397,23 +415,25 @@ HttpPluginConnectionHandler::HandleRequest() } else { - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); } } - catch (std::bad_alloc& BadAlloc) + catch (const std::bad_alloc& BadAlloc) { // Drop any partially formatted response Request.m_Response.reset(); Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); } - catch (std::exception& ex) + catch (const std::exception& ex) { // Drop any partially formatted response Request.m_Response.reset(); - ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); } } @@ -442,7 +462,8 @@ HttpPluginConnectionHandler::HandleRequest() if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) { - m_Server->WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers); + m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), + ResponseBuffers); } for (const IoBuffer& Buffer : ResponseBuffers) @@ -523,6 +544,7 @@ HttpPluginConnectionHandler::HandleRequest() void HttpPluginConnectionHandler::TerminateConnection() { + ZEN_MEMSCOPE(GetHttppluginTag()); ZEN_ASSERT(m_TransportConnection); m_TransportConnection->CloseConnection(); } @@ -533,6 +555,8 @@ HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, Htt : m_Request(Request) , m_PayloadBuffer(std::move(PayloadBuffer)) { + ZEN_MEMSCOPE(GetHttppluginTag()); + const int PrefixLength = Service.UriPrefixLength(); std::string_view Uri = Request.Url(); @@ -613,6 +637,7 @@ void HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) { ZEN_ASSERT(!m_Response); + ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); std::array<IoBuffer, 0> Empty; @@ -624,6 +649,7 @@ void HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) { ZEN_ASSERT(!m_Response); + ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(ContentType)); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); @@ -633,6 +659,8 @@ void HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { ZEN_ASSERT(!m_Response); + ZEN_MEMSCOPE(GetHttppluginTag()); + m_Response.reset(new HttpPluginResponse(ContentType)); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); @@ -645,6 +673,7 @@ void HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) { ZEN_ASSERT(!m_Response); + ZEN_MEMSCOPE(GetHttppluginTag()); // Not one bit async, innit ContinuationHandler(*this); @@ -669,6 +698,7 @@ HttpPluginServerImpl::~HttpPluginServerImpl() TransportServerConnection* HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) { + ZEN_MEMSCOPE(GetHttppluginTag()); HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; const uint32_t ConnectionId = m_ConnectionIdCounter.fetch_add(1); Handler->Initialize(Connection, *this, ConnectionId); @@ -678,10 +708,10 @@ HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) int HttpPluginServerImpl::Initialize(int BasePort, std::filesystem::path DataDir) { - m_DataDir = DataDir; - m_PayloadDir = DataDir / "debug" / GetSessionIdString(); + ZEN_TRACE_CPU("HttpPluginServerImpl::Initialize"); - ZEN_INFO("any debug payloads will be written to '{}'", m_PayloadDir); + ZEN_MEMSCOPE(GetHttppluginTag()); + m_RequestTracer.Initialize(DataDir); try { @@ -693,13 +723,13 @@ HttpPluginServerImpl::Initialize(int BasePort, std::filesystem::path DataDir) { Plugin->Initialize(this); } - catch (std::exception& Ex) + catch (const std::exception& Ex) { ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); } } } - catch (std::exception& ex) + catch (const std::exception& ex) { ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); } @@ -715,6 +745,8 @@ HttpPluginServerImpl::Close() if (!m_IsInitialized) return; + ZEN_MEMSCOPE(GetHttppluginTag()); + try { RwLock::ExclusiveLockScope _(m_Lock); @@ -725,7 +757,7 @@ HttpPluginServerImpl::Close() { Plugin->Shutdown(); } - catch (std::exception& Ex) + catch (const std::exception& Ex) { ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); } @@ -735,7 +767,7 @@ HttpPluginServerImpl::Close() m_Plugins.clear(); } - catch (std::exception& ex) + catch (const std::exception& ex) { ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); } @@ -746,6 +778,8 @@ HttpPluginServerImpl::Close() void HttpPluginServerImpl::Run(bool IsInteractive) { + ZEN_MEMSCOPE(GetHttppluginTag()); + const bool TestMode = !IsInteractive; int WaitTimeout = -1; @@ -796,6 +830,8 @@ HttpPluginServerImpl::RequestExit() void HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin) { + ZEN_MEMSCOPE(GetHttppluginTag()); + RwLock::ExclusiveLockScope _(m_Lock); m_Plugins.emplace_back(std::move(Plugin)); } @@ -803,6 +839,8 @@ HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin) void HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin) { + ZEN_MEMSCOPE(GetHttppluginTag()); + RwLock::ExclusiveLockScope _(m_Lock); auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); if (It != m_Plugins.end()) @@ -814,6 +852,8 @@ HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin) void HttpPluginServerImpl::RegisterService(HttpService& Service) { + ZEN_MEMSCOPE(GetHttppluginTag()); + std::string_view UrlPath(Service.BaseUri()); Service.SetUriPrefixLength(UrlPath.size()); @@ -829,6 +869,8 @@ HttpPluginServerImpl::RegisterService(HttpService& Service) HttpService* HttpPluginServerImpl::RouteRequest(std::string_view Url) { + ZEN_MEMSCOPE(GetHttppluginTag()); + RwLock::SharedLockScope _(m_Lock); HttpService* CandidateService = nullptr; @@ -848,23 +890,6 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } -void -HttpPluginServerImpl::WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload) -{ - uint64_t PayloadSize = 0; - std::vector<const IoBuffer*> Buffers; - for (auto& Io : Payload) - { - Buffers.push_back(&Io); - PayloadSize += Io.GetSize(); - } - - if (PayloadSize) - { - WriteFile(m_PayloadDir / Filename, Buffers.data(), Buffers.size()); - } -} - ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; @@ -872,6 +897,7 @@ struct HttpPluginServerImpl; Ref<HttpPluginServer> CreateHttpPluginServer() { + ZEN_MEMSCOPE(GetHttppluginTag()); return Ref<HttpPluginServer>(new HttpPluginServerImpl); } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 5cd273c40..95d83911d 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -9,11 +9,14 @@ #include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> #include <zencore/timer.h> #include <zencore/trace.h> -#include <zenutil/packageformat.h> +#include <zenhttp/packageformat.h> + +#include <EASTL/fixed_vector.h> #if ZEN_WITH_HTTPSYS # define _WINSOCKAPI_ @@ -25,6 +28,14 @@ namespace zen { +const FLLMTag& +GetHttpsysTag() +{ + static FLLMTag HttpsysTag("httpsys"); + + return HttpsysTag; +} + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -372,14 +383,14 @@ public: void SuppressResponseBody(); // typically used for HEAD requests private: - std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; - uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes - uint16_t m_ResponseCode = 0; - uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists - uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends - bool m_IsInitialResponse = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - std::vector<IoBuffer> m_DataBuffers; + eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -524,7 +535,14 @@ HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfB if (IoResult != NO_ERROR) { - ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + ZEN_WARN("response '{}' ({}) aborted after transfering '{}', {} out of {} bytes, reason: {} ({})", + ReasonStringForHttpResultCode(m_ResponseCode), + m_ResponseCode, + ToString(m_ContentType), + NumberOfBytesTransferred, + m_TotalDataSize, + GetSystemErrorAsString(IoResult), + IoResult); // if one transmit failed there's really no need to go on return nullptr; @@ -673,7 +691,7 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ); } - auto EmitReponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { + auto EmitResponseDetails = [&](StringBuilderBase& ResponseDetails) -> void { for (int i = 0; i < ThisRequestChunkCount; ++i) { const HTTP_DATA_CHUNK Chunk = m_HttpDataChunks[ThisRequestChunkOffset + i]; @@ -756,7 +774,7 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) // Emit diagnostics ExtendableStringBuilder<256> ResponseDetails; - EmitReponseDetails(ResponseDetails); + EmitResponseDetails(ResponseDetails); ZEN_WARN("failed to send HTTP response (error {}: '{}'), request URL: '{}', ({}.{}) response: {}", SendResult, @@ -817,7 +835,7 @@ HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) ZEN_TRACE_CPU("httpsys::AsyncWork::IssueRequest"); ErrorCode.clear(); - Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem, WorkerThreadPool::EMode::EnableBacklog); } HttpSysRequestHandler* @@ -836,6 +854,8 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr void HttpAsyncWorkRequest::AsyncWorkItem::Execute() { + ZEN_MEMSCOPE(GetHttpsysTag()); + ZEN_TRACE_CPU("httpsys::async_execute"); try @@ -873,10 +893,15 @@ HttpAsyncWorkRequest::AsyncWorkItem::Execute() new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); } } - catch (std::exception& Ex) + catch (const AssertException& AssertEx) { return (void)Tx.IssueNextRequest( - new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + new HttpMessageResponseRequest(Tx, 500, fmt::format("Assert thrown in async work: '{}", AssertEx.FullDescription()))); + } + catch (const std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: {}", Ex.what()))); } } @@ -896,6 +921,8 @@ HttpSysServer::HttpSysServer(const HttpSysConfig& InConfig) , m_IsAsyncResponseEnabled(InConfig.IsAsyncResponseEnabled) , m_InitialConfig(InConfig) { + ZEN_MEMSCOPE(GetHttpsysTag()); + // Initialize thread pool int MinThreadCount; @@ -971,6 +998,8 @@ HttpSysServer::Close() int HttpSysServer::InitializeServer(int BasePort) { + ZEN_MEMSCOPE(GetHttpsysTag()); + using namespace std::literals; WideStringBuilder<64> WildcardUrlPath; @@ -1215,6 +1244,8 @@ HttpSysServer::Cleanup() WorkerThreadPool& HttpSysServer::WorkPool() { + ZEN_MEMSCOPE(GetHttpsysTag()); + if (!m_AsyncWorkPool) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); @@ -1299,6 +1330,8 @@ HttpSysServer::IssueNewRequestMaybe() return; } + ZEN_MEMSCOPE(GetHttpsysTag()); + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); std::error_code ErrorCode; @@ -1322,6 +1355,8 @@ HttpSysServer::IssueNewRequestMaybe() void HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) { + ZEN_MEMSCOPE(GetHttpsysTag()); + if (UrlPath[0] == '/') { ++UrlPath; @@ -1483,11 +1518,15 @@ HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler return true; } - ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + ZEN_WARN("IssueRequest() failed: {}", ErrorCode.message()); + } + catch (const AssertException& AssertEx) + { + ZEN_ERROR("Assert thrown in IssueNextRequest(): {}", AssertEx.FullDescription()); } - catch (std::exception& Ex) + catch (const std::exception& Ex) { - ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + ZEN_ERROR("exception caught in IssueNextRequest(): {}", Ex.what()); } // something went wrong, no request is pending @@ -1659,7 +1698,7 @@ HttpSysServerRequest::ParseSessionId() const { if (Header.RawValueLength == Oid::StringLength) { - return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + return Oid::TryFromHexString({Header.pRawValue, Header.RawValueLength}); } } } @@ -1698,6 +1737,8 @@ HttpSysServerRequest::ReadPayload() void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) { + ZEN_MEMSCOPE(GetHttpsysTag()); + ZEN_ASSERT(IsHandled() == false); auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); @@ -1715,6 +1756,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) { + ZEN_MEMSCOPE(GetHttpsysTag()); + ZEN_ASSERT(IsHandled() == false); auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); @@ -1732,6 +1775,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy void HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { + ZEN_MEMSCOPE(GetHttpsysTag()); + ZEN_ASSERT(IsHandled() == false); auto Response = @@ -1750,6 +1795,8 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy void HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) { + ZEN_MEMSCOPE(GetHttpsysTag()); + if (m_HttpTx.Server().IsAsyncResponseEnabled()) { m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); @@ -1826,7 +1873,17 @@ InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) ErrorCode = MakeErrorCode(HttpApiResult); - ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + if (IsInitialRequest()) + { + ZEN_WARN("initial HttpReceiveHttpRequest failed, error: {}", ErrorCode.message()); + } + else + { + ZEN_WARN("HttpReceiveHttpRequest (offset: {}, content-length: {}) failed, error: {}", + m_CurrentPayloadOffset, + m_PayloadBuffer.GetSize(), + ErrorCode.message()); + } return; } @@ -1837,6 +1894,8 @@ InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) HttpSysRequestHandler* InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { + ZEN_MEMSCOPE(GetHttpsysTag()); + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); switch (IoResult) @@ -1985,23 +2044,28 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); } - catch (std::system_error& SystemError) + catch (const AssertException& AssertEx) + { + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, AssertEx.FullDescription()); + } + catch (const std::system_error& SystemError) { if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) { return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, SystemError.what()); } - ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + ZEN_WARN("Caught system error exception while handling request: {}. ({})", SystemError.what(), SystemError.code().value()); return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, SystemError.what()); } - catch (std::bad_alloc& BadAlloc) + catch (const std::bad_alloc& BadAlloc) { return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InsufficientStorage, BadAlloc.what()); } - catch (std::exception& ex) + catch (const std::exception& ex) { - ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + ZEN_WARN("Caught exception while handling request: '{}'", ex.what()); return new HttpMessageResponseRequest(Transaction(), (uint16_t)HttpResponseCode::InternalServerError, ex.what()); } } @@ -2014,6 +2078,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT int HttpSysServer::Initialize(int BasePort, std::filesystem::path DataDir) { + ZEN_TRACE_CPU("HttpSysServer::Initialize"); + ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { @@ -2042,6 +2108,9 @@ HttpSysServer::RegisterService(HttpService& Service) Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { + ZEN_TRACE_CPU("CreateHttpSysServer"); + ZEN_MEMSCOPE(GetHttpsysTag()); + return Ref<HttpServer>(new HttpSysServer(Config)); } diff --git a/src/zenhttp/servers/httptracer.cpp b/src/zenhttp/servers/httptracer.cpp new file mode 100644 index 000000000..483307fb1 --- /dev/null +++ b/src/zenhttp/servers/httptracer.cpp @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptracer.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/session.h> + +namespace zen { + +void +HttpServerTracer::Initialize(std::filesystem::path DataDir) +{ + m_DataDir = DataDir; + m_PayloadDir = DataDir / "debug" / GetSessionIdString(); + + ZEN_INFO("any debug payloads will be written to '{}'", m_PayloadDir); +} + +void +HttpServerTracer::WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload) +{ + uint64_t PayloadSize = 0; + std::vector<const IoBuffer*> Buffers; + for (auto& Io : Payload) + { + Buffers.push_back(&Io); + PayloadSize += Io.GetSize(); + } + + if (PayloadSize) + { + WriteFile(m_PayloadDir / Filename, Buffers.data(), Buffers.size()); + } +} + +} // namespace zen diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h new file mode 100644 index 000000000..da72c79c9 --- /dev/null +++ b/src/zenhttp/servers/httptracer.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpserver.h> + +#pragma once + +namespace zen { + +/** Helper class for HTTP server implementations + + Provides some common functionality which can be used across all server + implementations. These could be in the root class but I think it's nicer + to hide the implementation details from client code + */ +class HttpServerTracer +{ +public: + void Initialize(std::filesystem::path DataDir); + void WriteDebugPayload(std::string_view Filename, const std::span<const IoBuffer> Payload); + +private: + std::filesystem::path m_DataDir; // Application data directory + std::filesystem::path m_PayloadDir; // Request debugging payload directory +}; + +} // namespace zen diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp index a9a782821..96a15518c 100644 --- a/src/zenhttp/transports/asiotransport.cpp +++ b/src/zenhttp/transports/asiotransport.cpp @@ -426,7 +426,7 @@ AsioTransportPlugin::Initialize(TransportServer* ServerInterface) { m_IoService.run(); } - catch (std::exception& e) + catch (const std::exception& e) { ZEN_ERROR("exception caught in asio event loop: {}", e.what()); } diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp index e09e62ec5..fb3dd23b5 100644 --- a/src/zenhttp/transports/dlltransport.cpp +++ b/src/zenhttp/transports/dlltransport.cpp @@ -21,18 +21,31 @@ namespace zen { ////////////////////////////////////////////////////////////////////////// +class DllTransportLogger : public TransportLogger, public RefCounted +{ +public: + DllTransportLogger(std::string_view PluginName); + virtual ~DllTransportLogger() = default; + + void LogMessage(LogLevel Level, const char* Message) override; + +private: + std::string m_PluginName; +}; + struct LoadedDll { std::string Name; std::filesystem::path LoadedFromPath; + DllTransportLogger* Logger = nullptr; Ref<TransportPlugin> Plugin; }; class DllTransportPluginImpl : public DllTransportPlugin, RefCounted { public: - DllTransportPluginImpl(); - ~DllTransportPluginImpl(); + DllTransportPluginImpl() = default; + ~DllTransportPluginImpl() = default; virtual uint32_t AddRef() const override; virtual uint32_t Release() const override; @@ -42,7 +55,7 @@ public: virtual const char* GetDebugName() override; virtual bool IsAvailable() override; - virtual void LoadDll(std::string_view Name) override; + virtual bool LoadDll(std::string_view Name) override; virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) override; private: @@ -51,12 +64,27 @@ private: std::vector<LoadedDll> m_Transports; }; -DllTransportPluginImpl::DllTransportPluginImpl() +DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginName(PluginName) { } -DllTransportPluginImpl::~DllTransportPluginImpl() +void +DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { + logging::level::LogLevel Level; + // clang-format off + switch (PluginLogLevel) + { + case LogLevel::Trace: Level = logging::level::Trace; break; + case LogLevel::Debug: Level = logging::level::Debug; break; + case LogLevel::Info: Level = logging::level::Info; break; + case LogLevel::Warn: Level = logging::level::Warn; break; + case LogLevel::Err: Level = logging::level::Err; break; + case LogLevel::Critical: Level = logging::level::Critical; break; + default: Level = logging::level::Off; break; + } + // clang-format on + ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t @@ -109,6 +137,7 @@ DllTransportPluginImpl::Shutdown() try { Transport.Plugin->Shutdown(); + Transport.Logger->Release(); } catch (const std::exception&) { @@ -143,42 +172,73 @@ DllTransportPluginImpl::ConfigureDll(std::string_view Name, const char* OptionTa } } -void +bool DllTransportPluginImpl::LoadDll(std::string_view Name) { RwLock::ExclusiveLockScope _(m_Lock); - ExtendableStringBuilder<128> DllPath; - DllPath << Name << ".dll"; + ExtendableStringBuilder<1024> DllPath; + DllPath << Name; + if (!Name.ends_with(".dll")) + { + DllPath << ".dll"; + } + + std::string FileName = std::filesystem::path(DllPath.c_str()).filename().replace_extension().string(); + HMODULE DllHandle = LoadLibraryA(DllPath.c_str()); if (!DllHandle) { - std::error_code Ec = MakeErrorCodeFromLastError(); - - throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath)); + ZEN_WARN("Failed to load transport DLL from '{}' due to '{}'", DllPath, GetLastErrorAsString()) + return false; } - TransportPlugin* CreateTransportPlugin(); + PfnGetTransportPluginVersion GetVersion = (PfnGetTransportPluginVersion)GetProcAddress(DllHandle, "GetTransportPluginVersion"); + PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); + + uint32_t APIVersion = 0; + uint32_t PluginVersion = 0; + + if (GetVersion) + { + GetVersion(&APIVersion, &PluginVersion); + } - PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); + const bool bValidApiVersion = APIVersion == kTransportApiVersion; - if (!CreatePlugin) + if (!GetVersion || !CreatePlugin || !bValidApiVersion) { std::error_code Ec = MakeErrorCodeFromLastError(); FreeLibrary(DllHandle); - throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath)); + if (GetVersion && !bValidApiVersion) + { + ZEN_WARN("Failed to load transport DLL from '{}' due to invalid API version {}, supported API version is {}", + DllPath, + APIVersion, + kTransportApiVersion) + } + else + { + ZEN_WARN("Failed to load transport DLL from '{}' due to not finding GetTransportPluginVersion or CreateTransportPlugin", + DllPath) + } + + return false; } LoadedDll NewDll; NewDll.Name = Name; NewDll.LoadedFromPath = DllPath.c_str(); - NewDll.Plugin = CreatePlugin(); + NewDll.Logger = new DllTransportLogger(FileName); + NewDll.Logger->AddRef(); + NewDll.Plugin = CreatePlugin(NewDll.Logger); m_Transports.emplace_back(std::move(NewDll)); + return true; } DllTransportPlugin* diff --git a/src/zenhttp/transports/dlltransport.h b/src/zenhttp/transports/dlltransport.h index 9346a10ce..c49f888da 100644 --- a/src/zenhttp/transports/dlltransport.h +++ b/src/zenhttp/transports/dlltransport.h @@ -15,7 +15,7 @@ namespace zen { class DllTransportPlugin : public TransportPlugin { public: - virtual void LoadDll(std::string_view Name) = 0; + virtual bool LoadDll(std::string_view Name) = 0; virtual void ConfigureDll(std::string_view Name, const char* OptionTag, const char* OptionValue) = 0; }; diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp index 7407c55dd..c06a50c95 100644 --- a/src/zenhttp/transports/winsocktransport.cpp +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -304,18 +304,20 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface) TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; Connection->Initialize(ConnectionInterface, ClientSocket); - m_WorkerThreadpool->ScheduleWork([Connection] { - try - { - Connection->HandleConnection(); - } - catch (std::exception& Ex) - { - ZEN_WARN("exception caught in connection loop: {}", Ex.what()); - } - - delete Connection; - }); + m_WorkerThreadpool->ScheduleWork( + [Connection] { + try + { + Connection->HandleConnection(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + }, + WorkerThreadPool::EMode::EnableBacklog); } else { diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 8393f399b..b6ffbe467 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -7,7 +7,7 @@ target('zenhttp') add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) - add_deps("zencore", "zenutil", "transport-sdk") + add_deps("zencore", "transport-sdk") add_packages( "vcpkg::asio", "vcpkg::cpr", diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 6b855c4db..a2679f92e 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -6,7 +6,7 @@ # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> -# include <zenutil/packageformat.h> +# include <zenhttp/packageformat.h> namespace zen { @@ -15,6 +15,7 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + forcelink_packageformat(); } } // namespace zen |