diff options
| author | Stefan Boberg <[email protected]> | 2025-09-30 19:07:51 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-09-30 19:07:51 +0200 |
| commit | 634181a04efff90def7a98d98eac7078e1d4e62d (patch) | |
| tree | 04678bba636a76d21f300ff6e73af4473274cf12 /src/zenhttp/clients | |
| parent | use batching clang-format for quicker turnaround on validate actions (#529) (diff) | |
| download | zen-634181a04efff90def7a98d98eac7078e1d4e62d.tar.xz zen-634181a04efff90def7a98d98eac7078e1d4e62d.zip | |
HttpClient support for pluggable back-ends (#532)
refactored HttpClient to separate out cpr implementation into separate classes, with an abstract base class to allow plugging in multiple implementations in the future
Diffstat (limited to 'src/zenhttp/clients')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.cpp | 474 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 147 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 1035 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 151 |
4 files changed, 1807 insertions, 0 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp new file mode 100644 index 000000000..8e5136dff --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -0,0 +1,474 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcommon.h" + +#include <fmt/format.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory/memory.h> +#include <zencore/windows.h> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/basicfile.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <fcntl.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +namespace zen { + +using namespace std::literals; + +namespace detail { + + static std::atomic_uint32_t TempFileBaseIndex; + + TempPayloadFile::TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {} + TempPayloadFile::~TempPayloadFile() + { + ZEN_TRACE_CPU("TempPayloadFile::Close"); + try + { + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); + BOOL Success = CloseHandle(m_FileHandle); +#else + std::error_code Ec; + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); + if (Ec) + { + ZEN_WARN("Error reported on get file path from handle {} for temp payload unlink operation, reason '{}'", + m_FileHandle, + Ec.message()); + } + else + { + unlink(FilePath.c_str()); + } + int Fd = int(uintptr_t(m_FileHandle)); + bool Success = (close(Fd) == 0); +#endif + if (!Success) + { + ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); + } + + m_FileHandle = nullptr; + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what()); + } + } + + std::error_code TempPayloadFile::Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize) + { + ZEN_TRACE_CPU("TempPayloadFile::Open"); + ZEN_ASSERT(m_FileHandle == nullptr); + + std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) | + detail::TempFileBaseIndex.fetch_add(1); + + std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex); +#if ZEN_PLATFORM_WINDOWS + LPCWSTR lpFileName = FileName.c_str(); + const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE); + const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr; + const DWORD dwCreationDisposition = CREATE_ALWAYS; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + const HANDLE hTemplateFile = nullptr; + const HANDLE FileHandle = CreateFile(lpFileName, + dwDesiredAccess, + dwShareMode, + lpSecurityAttributes, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + return MakeErrorCodeFromLastError(); + } +#else // ZEN_PLATFORM_WINDOWS + int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC; + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + return MakeErrorCodeFromLastError(); + } + fchmod(Fd, 0666); + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif // ZEN_PLATFORM_WINDOWS + m_FileHandle = FileHandle; + + PrepareFileForScatteredWrite(m_FileHandle, FinalSize); + + return {}; + } + + std::error_code TempPayloadFile::Write(std::string_view DataString) + { + ZEN_TRACE_CPU("TempPayloadFile::Write"); + const uint8_t* DataPtr = (const uint8_t*)DataString.data(); + size_t DataSize = DataString.size(); + if (DataSize >= CacheBufferSize) + { + std::error_code Ec = Flush(); + if (Ec) + { + return Ec; + } + return AppendData(DataPtr, DataSize); + } + size_t CopySize = Min(DataSize, CacheBufferSize - m_CacheBufferOffset); + memcpy(&m_CacheBuffer[m_CacheBufferOffset], DataPtr, CopySize); + m_CacheBufferOffset += CopySize; + DataSize -= CopySize; + if (m_CacheBufferOffset == CacheBufferSize) + { + AppendData(m_CacheBuffer, CacheBufferSize); + if (DataSize > 0) + { + ZEN_ASSERT(DataSize < CacheBufferSize); + memcpy(m_CacheBuffer, DataPtr + CopySize, DataSize); + } + m_CacheBufferOffset = DataSize; + } + else + { + ZEN_ASSERT(DataSize == 0); + } + return {}; + } + + IoBuffer TempPayloadFile::DetachToIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::DetachToIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true); + Buffer.SetDeleteOnClose(true); + m_FileHandle = 0; + m_WriteOffset = 0; + return Buffer; + } + + IoBuffer TempPayloadFile::BorrowIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::BorrowIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, 0, m_WriteOffset); + return Buffer; + } + + void TempPayloadFile::ResetWritePos(uint64_t WriteOffset) + { + ZEN_TRACE_CPU("TempPayloadFile::ResetWritePos"); + Flush(); + m_WriteOffset = WriteOffset; + } + + std::error_code TempPayloadFile::Flush() + { + ZEN_TRACE_CPU("TempPayloadFile::Flush"); + if (m_CacheBufferOffset == 0) + { + return {}; + } + std::error_code Res = AppendData(m_CacheBuffer, m_CacheBufferOffset); + m_CacheBufferOffset = 0; + return Res; + } + + std::error_code TempPayloadFile::AppendData(const void* Data, uint64_t Size) + { + ZEN_TRACE_CPU("TempPayloadFile::AppendData"); + ZEN_ASSERT(m_FileHandle != nullptr); + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); + uint64_t NumberOfBytesWritten = 0; +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten); + } +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset); + bool Success = (BytesWritten > 0); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten); + } +#endif + + if (!Success) + { + return MakeErrorCodeFromLastError(); + } + + Size -= NumberOfBytesWritten; + m_WriteOffset += NumberOfBytesWritten; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten; + } + return {}; + } + + BufferedReadFileStream::BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize) + : m_FileHandle(FileHandle) + , m_FileSize(FileSize) + , m_FileEnd(FileOffset + FileSize) + , m_BufferSize(Min(BufferSize, FileSize)) + , m_FileOffset(FileOffset) + { + } + + BufferedReadFileStream::~BufferedReadFileStream() { Memory::Free(m_Buffer); } + void BufferedReadFileStream::Read(void* Data, uint64_t Size) + { + ZEN_ASSERT(Data != nullptr); + if (Size > m_BufferSize) + { + Read(Data, Size, m_FileOffset); + m_FileOffset += Size; + return; + } + uint8_t* WritePtr = ((uint8_t*)Data); + uint64_t Begin = m_FileOffset; + uint64_t End = m_FileOffset + Size; + ZEN_ASSERT(m_FileOffset >= m_BufferStart); + if (m_FileOffset < m_BufferEnd) + { + ZEN_ASSERT(m_Buffer != nullptr); + uint64_t Count = Min(m_BufferEnd, End) - m_FileOffset; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer + Begin - m_BufferStart, Count); + Begin += Count; + if (Begin == End) + { + m_FileOffset = End; + return; + } + } + if (End == m_FileEnd) + { + Read(WritePtr + Begin - m_FileOffset, End - Begin, Begin); + } + else + { + if (!m_Buffer) + { + m_BufferSize = Min(m_FileEnd - m_FileOffset, m_BufferSize); + m_Buffer = (uint8_t*)Memory::Alloc(gsl::narrow<size_t>(m_BufferSize)); + } + m_BufferStart = Begin; + m_BufferEnd = Min(Begin + m_BufferSize, m_FileEnd); + Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); + uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer, Count); + ZEN_ASSERT(Begin + Count == End); + } + m_FileOffset = End; + } + + void BufferedReadFileStream::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) + { + const uint64_t MaxChunkSize = 1u * 1024 * 1024; + std::error_code Ec; + ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec); + + if (Ec) + { + std::error_code DummyEc; + throw std::system_error( + Ec, + fmt::format("HttpClient::BufferedReadFileStream ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + FileOffset, + BytesToRead, + PathFromHandle(m_FileHandle, DummyEc).generic_string(), + m_FileSize)); + } + } + + CompositeBufferReadStream::CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize) + : m_Data(Data) + , m_BufferSize(BufferSize) + , m_SegmentIndex(0) + , m_BytesLeftInSegment(0) + { + } + uint64_t CompositeBufferReadStream::Read(void* Data, uint64_t Size) + { + uint64_t Result = 0; + uint8_t* WritePtr = (uint8_t*)Data; + while ((Size > 0) && (m_SegmentIndex < m_Data.GetSegments().size())) + { + if (m_BytesLeftInSegment == 0) + { + const SharedBuffer& Segment = m_Data.GetSegments()[m_SegmentIndex]; + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Segment.AsIoBuffer().GetFileReference(FileRef)) + { + m_SegmentDiskBuffer = std::make_unique<BufferedReadFileStream>(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + m_BufferSize); + } + else + { + m_SegmentMemoryBuffer = Segment.GetView(); + } + m_BytesLeftInSegment = Segment.GetSize(); + } + uint64_t BytesToRead = Min(m_BytesLeftInSegment, Size); + if (m_SegmentDiskBuffer) + { + m_SegmentDiskBuffer->Read(WritePtr, BytesToRead); + } + else + { + ZEN_ASSERT_SLOW(m_SegmentMemoryBuffer.GetSize() >= BytesToRead); + memcpy(WritePtr, m_SegmentMemoryBuffer.GetData(), BytesToRead); + m_SegmentMemoryBuffer.MidInline(BytesToRead); + } + WritePtr += BytesToRead; + Size -= BytesToRead; + Result += BytesToRead; + + m_BytesLeftInSegment -= BytesToRead; + if (m_BytesLeftInSegment == 0) + { + m_SegmentDiskBuffer.reset(); + m_SegmentMemoryBuffer.Reset(); + m_SegmentIndex++; + } + } + return Result; + } + +} // namespace detail + +} // namespace zen + +#if ZEN_WITH_TESTS +namespace zen { + +namespace testutil { + IoHash HashComposite(const CompositeBuffer& Payload) + { + IoHashStream Hasher; + const uint64_t PayloadSize = Payload.GetSize(); + std::vector<uint8_t> Buffer(64u * 1024u); + detail::CompositeBufferReadStream Stream(Payload, 137u * 1024u); + for (uint64_t Offset = 0; Offset < PayloadSize;) + { + uint64_t Count = Min(64u * 1024u, PayloadSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + }; + + IoHash HashFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize) + { + IoHashStream Hasher; + std::vector<uint8_t> Buffer(64u * 1024u); + detail::BufferedReadFileStream Stream(FileHandle, FileOffset, FileSize, 137u * 1024u); + for (uint64_t Offset = 0; Offset < FileSize;) + { + uint64_t Count = Min(64u * 1024u, FileSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + } + +} // namespace testutil + +TEST_CASE("BufferedReadFileStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer DiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer1"); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + CHECK(DiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer Partial(DiskBuffer, 37 * 1024, 512 * 1024); + CHECK(Partial.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(Partial), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer SmallDiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(63 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK(SmallDiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(SmallDiskBuffer), + testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); +} + +TEST_CASE("CompositeBufferReadStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer MemoryBuffer1 = CreateRandomBlob(64); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer1)))); + + IoBuffer MemoryBuffer2 = CreateRandomBlob(561 * 1024); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer2)))); + + IoBuffer DiskBuffer1 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(267 * 3 * 1024)), TmpDir.Path() / "diskbuffer1"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer1)))); + + IoBuffer DiskBuffer2 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(3 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer2)))); + + IoBuffer DiskBuffer3 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer3"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer3), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer3)))); + + CompositeBuffer Data(SharedBuffer(std::move(MemoryBuffer1)), + SharedBuffer(std::move(DiskBuffer1)), + SharedBuffer(std::move(DiskBuffer2)), + SharedBuffer(std::move(MemoryBuffer2)), + SharedBuffer(std::move(DiskBuffer3))); + CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h new file mode 100644 index 000000000..9060cde48 --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.h @@ -0,0 +1,147 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compositebuffer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpclient.h> + +namespace zen { + +using namespace std::literals; + +class HttpClientBase +{ +public: + HttpClientBase(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + virtual ~HttpClientBase() = 0; + + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}) = 0; + + LoggerRef Log() { return m_Log; } + std::string_view GetBaseUri() const { return m_BaseUri; } + std::string_view GetSessionId() const { return m_SessionId; } + + bool Authenticate(); + +protected: + LoggerRef m_Log; + std::string m_BaseUri; + std::string m_SessionId; + const HttpClientSettings m_ConnectionSettings; + + const std::optional<HttpClientAccessToken> GetAccessToken(); + + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +namespace detail { + + class TempPayloadFile + { + public: + TempPayloadFile(const TempPayloadFile&) = delete; + TempPayloadFile& operator=(const TempPayloadFile&) = delete; + + TempPayloadFile(); + ~TempPayloadFile(); + + std::error_code Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize); + std::error_code Write(std::string_view DataString); + IoBuffer DetachToIoBuffer(); + IoBuffer BorrowIoBuffer(); + inline uint64_t GetSize() const { return m_WriteOffset; } + void ResetWritePos(uint64_t WriteOffset); + + private: + std::error_code Flush(); + std::error_code AppendData(const void* Data, uint64_t Size); + + void* m_FileHandle; + std::uint64_t m_WriteOffset; + static constexpr uint64_t CacheBufferSize = 512u * 1024u; + uint8_t m_CacheBuffer[CacheBufferSize]; + std::uint64_t m_CacheBufferOffset = 0; + }; + + class BufferedReadFileStream + { + public: + BufferedReadFileStream(const BufferedReadFileStream&) = delete; + BufferedReadFileStream& operator=(const BufferedReadFileStream&) = delete; + + BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize); + ~BufferedReadFileStream(); + + void Read(void* Data, uint64_t Size); + + private: + void Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset); + + void* m_FileHandle = nullptr; + const uint64_t m_FileSize = 0; + const uint64_t m_FileEnd = 0; + uint64_t m_BufferSize = 0; + uint8_t* m_Buffer = nullptr; + uint64_t m_BufferStart = 0; + uint64_t m_BufferEnd = 0; + uint64_t m_FileOffset = 0; + }; + + class CompositeBufferReadStream + { + public: + CompositeBufferReadStream(const CompositeBufferReadStream&) = delete; + CompositeBufferReadStream& operator=(const CompositeBufferReadStream&) = delete; + + CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize); + uint64_t Read(void* Data, uint64_t Size); + + private: + const CompositeBuffer& m_Data; + const uint64_t m_BufferSize; + size_t m_SegmentIndex; + std::unique_ptr<BufferedReadFileStream> m_SegmentDiskBuffer; + MemoryView m_SegmentMemoryBuffer; + uint64_t m_BytesLeftInSegment; + }; + +} // namespace detail + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp new file mode 100644 index 000000000..568106887 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -0,0 +1,1035 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcpr.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/packageformat.h> + +namespace zen { + +HttpClientBase* +CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +{ + return new CprHttpClient(BaseUri, ConnectionSettings); +} + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +// If we want to support different HTTP client implementations then we'll need to make this more abstract + +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) + { + switch ((cpr::ErrorCode)m_Error) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + return ResponseClass::kHttpCantConnectError; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return ResponseClass::kHttpNoHost; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return ResponseClass::kHttpTimeout; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_CACERT_ERROR: + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// CPR helpers + +static cpr::Body +AsCprBody(const CbObject& Obj) +{ + return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); +} + +static cpr::Body +AsCprBody(const IoBuffer& Obj) +{ + return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); +} + +////////////////////////////////////////////////////////////////////////// + +static HttpClient::Response +ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) +{ + // This ends up doing a memcpy, would be good to get rid of it by streaming results + // into buffer directly + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); + + if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) + { + const HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; +} + +static HttpClient::Response +CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); + if (HttpResponse.error) + { + if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT && + HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED) + { + ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse); + } + + // Client side failure code + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed, + .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), + .ErrorMessage = HttpResponse.error.message}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; + } + else + { + return ResponseWithPayload( + SessionId, + HttpResponse, + WorkResponseCode, + Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size())); + } +} + +static bool +ShouldRetry(const cpr::Response& Response) +{ + switch (Response.error.code) + { + case cpr::ErrorCode::OK: + break; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch ((HttpResponseCode)Response.status_code) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +}; + +static bool +ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + + if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second); + if (!ExpectedContentSize.has_value()) + { + Response.error = + cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second)); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Response.error = cpr::Error( + /*CURLE_READ_ERROR*/ 26, + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second)); + return false; + } + } + + if (Response.status_code == (long)HttpResponseCode::PartialContent) + { + return true; + } + + if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, + fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString())); + return false; + } + } + } + + if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end()) + { + if (ContentType->second == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize)) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation"); + return false; + } + } + if (ContentType->second == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error))); + return false; + } + } + } + + return true; +} + +static cpr::Response +DoWithRetry( + std::string_view SessionId, + std::function<cpr::Response()>&& Func, + uint8_t RetryCount, + std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static cpr::Response +DoWithRetry(std::string_view SessionId, + std::function<cpr::Response()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile, + uint8_t RetryCount) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) +: HttpClientBase(BaseUri, Connectionsettings) +{ +} + +CprHttpClient::~CprHttpClient() +{ + ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto CprSession : m_Sessions) + { + delete CprSession; + } + m_Sessions.clear(); + }); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::Session +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_TRACE_CPU("CprHttpClient::AllocSession"); + cpr::Session* CprSession = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + CprSession = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (CprSession == nullptr) + { + CprSession = new cpr::Session(); + CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); + CprSession->SetTimeout(ConnectionSettings.Timeout); + if (ConnectionSettings.AssumeHttp2) + { + CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); + } + } + + if (!AdditionalHeader->empty()) + { + CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); + } + if (!SessionId.empty()) + { + CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); + } + if (AccessToken) + { + CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + } + if (!Parameters->empty()) + { + cpr::Parameters Tmp; + for (auto It = Parameters->begin(); It != Parameters->end(); It++) + { + Tmp.Add({It->first, It->second}); + } + CprSession->SetParameters(Tmp); + } + else + { + CprSession->SetParameters({}); + } + + ExtendableStringBuilder<128> UrlBuffer; + UrlBuffer << BaseUrl << ResourcePath; + CprSession->SetUrl(UrlBuffer.c_str()); + + return Session(this, CprSession); +} + +void +CprHttpClient::ReleaseSession(cpr::Session* CprSession) +{ + ZEN_TRACE_CPU("CprHttpClient::ReleaseSession"); + CprSession->SetUrl({}); + CprSession->SetHeader({}); + CprSession->SetBody({}); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); +} + +CprHttpClient::Response +CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::TransactPackage"); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, + Url, + m_ConnectionSettings, + {{"Content-Length", "0"}}, + Parameters, + m_SessionId, + GetAccessToken()); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Get(); + }, + m_ConnectionSettings.RetryCount, + [](cpr::Response& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + })); +} + +CprHttpClient::Response +CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Head(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Delete(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { + size_t DelimiterPos = header.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string Key = header.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string Value = header.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + + return std::make_pair(Key, Value); + } + return std::make_pair(header, ""); + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + size_t RangeStartPos = RangeIt->second.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = + ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength = RequestedRangeEnd.value() - 1; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > 1024 * 1024) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile, + m_ConnectionSettings.RetryCount); + + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h new file mode 100644 index 000000000..ed9d10c27 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/cprutils.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/body.h> +#include <cpr/session.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CprHttpClient : public HttpClientBase +{ +public: + CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + ~CprHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct Session + { + Session(CprHttpClient* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {} + ~Session() { Outer->ReleaseSession(CprSession); } + + inline cpr::Session* operator->() const { return CprSession; } + inline cpr::Response Get() + { + ZEN_TRACE_CPU("HttpClient::Impl::Get"); + cpr::Response Result = CprSession->Get(); + ZEN_TRACE("GET {}", Result); + return Result; + } + inline cpr::Response Download(cpr::WriteCallback&& Write, std::optional<cpr::HeaderCallback>&& Header = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Download"); + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } + cpr::Response Result = CprSession->Download(Write); + ZEN_TRACE("GET {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); + return Result; + } + inline cpr::Response Head() + { + ZEN_TRACE_CPU("HttpClient::Impl::Head"); + cpr::Response Result = CprSession->Head(); + ZEN_TRACE("HEAD {}", Result); + return Result; + } + inline cpr::Response Put(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Put"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Put(); + ZEN_TRACE("PUT {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Post"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Post(); + ZEN_TRACE("POST {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Delete() + { + ZEN_TRACE_CPU("HttpClient::Impl::Delete"); + cpr::Response Result = CprSession->Delete(); + ZEN_TRACE("DELETE {}", Result); + return Result; + } + + LoggerRef Log() { return Outer->Log(); } + + private: + CprHttpClient* Outer; + cpr::Session* CprSession; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken); + + RwLock m_SessionLock; + std::vector<cpr::Session*> m_Sessions; + + void ReleaseSession(cpr::Session*); +}; + +} // namespace zen |