diff options
| author | Stefan Boberg <[email protected]> | 2025-09-30 19:07:51 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-09-30 19:07:51 +0200 |
| commit | 634181a04efff90def7a98d98eac7078e1d4e62d (patch) | |
| tree | 04678bba636a76d21f300ff6e73af4473274cf12 | |
| parent | use batching clang-format for quicker turnaround on validate actions (#529) (diff) | |
| download | zen-634181a04efff90def7a98d98eac7078e1d4e62d.tar.xz zen-634181a04efff90def7a98d98eac7078e1d4e62d.zip | |
HttpClient support for pluggable back-ends (#532)
refactored HttpClient to separate out cpr implementation into separate classes, with an abstract base class to allow plugging in multiple implementations in the future
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.cpp | 474 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 147 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 1035 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 151 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 1746 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 23 |
6 files changed, 1942 insertions, 1634 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp new file mode 100644 index 000000000..8e5136dff --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -0,0 +1,474 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcommon.h" + +#include <fmt/format.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory/memory.h> +#include <zencore/windows.h> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/basicfile.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <fcntl.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +namespace zen { + +using namespace std::literals; + +namespace detail { + + static std::atomic_uint32_t TempFileBaseIndex; + + TempPayloadFile::TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {} + TempPayloadFile::~TempPayloadFile() + { + ZEN_TRACE_CPU("TempPayloadFile::Close"); + try + { + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); + BOOL Success = CloseHandle(m_FileHandle); +#else + std::error_code Ec; + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); + if (Ec) + { + ZEN_WARN("Error reported on get file path from handle {} for temp payload unlink operation, reason '{}'", + m_FileHandle, + Ec.message()); + } + else + { + unlink(FilePath.c_str()); + } + int Fd = int(uintptr_t(m_FileHandle)); + bool Success = (close(Fd) == 0); +#endif + if (!Success) + { + ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); + } + + m_FileHandle = nullptr; + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what()); + } + } + + std::error_code TempPayloadFile::Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize) + { + ZEN_TRACE_CPU("TempPayloadFile::Open"); + ZEN_ASSERT(m_FileHandle == nullptr); + + std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) | + detail::TempFileBaseIndex.fetch_add(1); + + std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex); +#if ZEN_PLATFORM_WINDOWS + LPCWSTR lpFileName = FileName.c_str(); + const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE); + const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr; + const DWORD dwCreationDisposition = CREATE_ALWAYS; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + const HANDLE hTemplateFile = nullptr; + const HANDLE FileHandle = CreateFile(lpFileName, + dwDesiredAccess, + dwShareMode, + lpSecurityAttributes, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + return MakeErrorCodeFromLastError(); + } +#else // ZEN_PLATFORM_WINDOWS + int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC; + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + return MakeErrorCodeFromLastError(); + } + fchmod(Fd, 0666); + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif // ZEN_PLATFORM_WINDOWS + m_FileHandle = FileHandle; + + PrepareFileForScatteredWrite(m_FileHandle, FinalSize); + + return {}; + } + + std::error_code TempPayloadFile::Write(std::string_view DataString) + { + ZEN_TRACE_CPU("TempPayloadFile::Write"); + const uint8_t* DataPtr = (const uint8_t*)DataString.data(); + size_t DataSize = DataString.size(); + if (DataSize >= CacheBufferSize) + { + std::error_code Ec = Flush(); + if (Ec) + { + return Ec; + } + return AppendData(DataPtr, DataSize); + } + size_t CopySize = Min(DataSize, CacheBufferSize - m_CacheBufferOffset); + memcpy(&m_CacheBuffer[m_CacheBufferOffset], DataPtr, CopySize); + m_CacheBufferOffset += CopySize; + DataSize -= CopySize; + if (m_CacheBufferOffset == CacheBufferSize) + { + AppendData(m_CacheBuffer, CacheBufferSize); + if (DataSize > 0) + { + ZEN_ASSERT(DataSize < CacheBufferSize); + memcpy(m_CacheBuffer, DataPtr + CopySize, DataSize); + } + m_CacheBufferOffset = DataSize; + } + else + { + ZEN_ASSERT(DataSize == 0); + } + return {}; + } + + IoBuffer TempPayloadFile::DetachToIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::DetachToIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true); + Buffer.SetDeleteOnClose(true); + m_FileHandle = 0; + m_WriteOffset = 0; + return Buffer; + } + + IoBuffer TempPayloadFile::BorrowIoBuffer() + { + ZEN_TRACE_CPU("TempPayloadFile::BorrowIoBuffer"); + if (std::error_code Ec = Flush(); Ec) + { + ThrowSystemError(Ec.value(), Ec.message()); + } + ZEN_ASSERT(m_FileHandle != nullptr); + void* FileHandle = m_FileHandle; + IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, 0, m_WriteOffset); + return Buffer; + } + + void TempPayloadFile::ResetWritePos(uint64_t WriteOffset) + { + ZEN_TRACE_CPU("TempPayloadFile::ResetWritePos"); + Flush(); + m_WriteOffset = WriteOffset; + } + + std::error_code TempPayloadFile::Flush() + { + ZEN_TRACE_CPU("TempPayloadFile::Flush"); + if (m_CacheBufferOffset == 0) + { + return {}; + } + std::error_code Res = AppendData(m_CacheBuffer, m_CacheBufferOffset); + m_CacheBufferOffset = 0; + return Res; + } + + std::error_code TempPayloadFile::AppendData(const void* Data, uint64_t Size) + { + ZEN_TRACE_CPU("TempPayloadFile::AppendData"); + ZEN_ASSERT(m_FileHandle != nullptr); + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); + uint64_t NumberOfBytesWritten = 0; +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten); + } +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset); + bool Success = (BytesWritten > 0); + if (Success) + { + NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten); + } +#endif + + if (!Success) + { + return MakeErrorCodeFromLastError(); + } + + Size -= NumberOfBytesWritten; + m_WriteOffset += NumberOfBytesWritten; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten; + } + return {}; + } + + BufferedReadFileStream::BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize) + : m_FileHandle(FileHandle) + , m_FileSize(FileSize) + , m_FileEnd(FileOffset + FileSize) + , m_BufferSize(Min(BufferSize, FileSize)) + , m_FileOffset(FileOffset) + { + } + + BufferedReadFileStream::~BufferedReadFileStream() { Memory::Free(m_Buffer); } + void BufferedReadFileStream::Read(void* Data, uint64_t Size) + { + ZEN_ASSERT(Data != nullptr); + if (Size > m_BufferSize) + { + Read(Data, Size, m_FileOffset); + m_FileOffset += Size; + return; + } + uint8_t* WritePtr = ((uint8_t*)Data); + uint64_t Begin = m_FileOffset; + uint64_t End = m_FileOffset + Size; + ZEN_ASSERT(m_FileOffset >= m_BufferStart); + if (m_FileOffset < m_BufferEnd) + { + ZEN_ASSERT(m_Buffer != nullptr); + uint64_t Count = Min(m_BufferEnd, End) - m_FileOffset; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer + Begin - m_BufferStart, Count); + Begin += Count; + if (Begin == End) + { + m_FileOffset = End; + return; + } + } + if (End == m_FileEnd) + { + Read(WritePtr + Begin - m_FileOffset, End - Begin, Begin); + } + else + { + if (!m_Buffer) + { + m_BufferSize = Min(m_FileEnd - m_FileOffset, m_BufferSize); + m_Buffer = (uint8_t*)Memory::Alloc(gsl::narrow<size_t>(m_BufferSize)); + } + m_BufferStart = Begin; + m_BufferEnd = Min(Begin + m_BufferSize, m_FileEnd); + Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); + uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; + memcpy(WritePtr + Begin - m_FileOffset, m_Buffer, Count); + ZEN_ASSERT(Begin + Count == End); + } + m_FileOffset = End; + } + + void BufferedReadFileStream::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) + { + const uint64_t MaxChunkSize = 1u * 1024 * 1024; + std::error_code Ec; + ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec); + + if (Ec) + { + std::error_code DummyEc; + throw std::system_error( + Ec, + fmt::format("HttpClient::BufferedReadFileStream ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", + FileOffset, + BytesToRead, + PathFromHandle(m_FileHandle, DummyEc).generic_string(), + m_FileSize)); + } + } + + CompositeBufferReadStream::CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize) + : m_Data(Data) + , m_BufferSize(BufferSize) + , m_SegmentIndex(0) + , m_BytesLeftInSegment(0) + { + } + uint64_t CompositeBufferReadStream::Read(void* Data, uint64_t Size) + { + uint64_t Result = 0; + uint8_t* WritePtr = (uint8_t*)Data; + while ((Size > 0) && (m_SegmentIndex < m_Data.GetSegments().size())) + { + if (m_BytesLeftInSegment == 0) + { + const SharedBuffer& Segment = m_Data.GetSegments()[m_SegmentIndex]; + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Segment.AsIoBuffer().GetFileReference(FileRef)) + { + m_SegmentDiskBuffer = std::make_unique<BufferedReadFileStream>(FileRef.FileHandle, + FileRef.FileChunkOffset, + FileRef.FileChunkSize, + m_BufferSize); + } + else + { + m_SegmentMemoryBuffer = Segment.GetView(); + } + m_BytesLeftInSegment = Segment.GetSize(); + } + uint64_t BytesToRead = Min(m_BytesLeftInSegment, Size); + if (m_SegmentDiskBuffer) + { + m_SegmentDiskBuffer->Read(WritePtr, BytesToRead); + } + else + { + ZEN_ASSERT_SLOW(m_SegmentMemoryBuffer.GetSize() >= BytesToRead); + memcpy(WritePtr, m_SegmentMemoryBuffer.GetData(), BytesToRead); + m_SegmentMemoryBuffer.MidInline(BytesToRead); + } + WritePtr += BytesToRead; + Size -= BytesToRead; + Result += BytesToRead; + + m_BytesLeftInSegment -= BytesToRead; + if (m_BytesLeftInSegment == 0) + { + m_SegmentDiskBuffer.reset(); + m_SegmentMemoryBuffer.Reset(); + m_SegmentIndex++; + } + } + return Result; + } + +} // namespace detail + +} // namespace zen + +#if ZEN_WITH_TESTS +namespace zen { + +namespace testutil { + IoHash HashComposite(const CompositeBuffer& Payload) + { + IoHashStream Hasher; + const uint64_t PayloadSize = Payload.GetSize(); + std::vector<uint8_t> Buffer(64u * 1024u); + detail::CompositeBufferReadStream Stream(Payload, 137u * 1024u); + for (uint64_t Offset = 0; Offset < PayloadSize;) + { + uint64_t Count = Min(64u * 1024u, PayloadSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + }; + + IoHash HashFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize) + { + IoHashStream Hasher; + std::vector<uint8_t> Buffer(64u * 1024u); + detail::BufferedReadFileStream Stream(FileHandle, FileOffset, FileSize, 137u * 1024u); + for (uint64_t Offset = 0; Offset < FileSize;) + { + uint64_t Count = Min(64u * 1024u, FileSize - Offset); + Stream.Read(Buffer.data(), Count); + Hasher.Append(Buffer.data(), Count); + Offset += Count; + } + return Hasher.GetHash(); + } + +} // namespace testutil + +TEST_CASE("BufferedReadFileStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer DiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer1"); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + CHECK(DiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer Partial(DiskBuffer, 37 * 1024, 512 * 1024); + CHECK(Partial.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(Partial), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); + + IoBuffer SmallDiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(63 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK(SmallDiskBuffer.GetFileReference(FileRef)); + CHECK_EQ(IoHash::HashBuffer(SmallDiskBuffer), + testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); +} + +TEST_CASE("CompositeBufferReadStream") +{ + ScopedTemporaryDirectory TmpDir; + + IoBuffer MemoryBuffer1 = CreateRandomBlob(64); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer1)))); + + IoBuffer MemoryBuffer2 = CreateRandomBlob(561 * 1024); + CHECK_EQ(IoHash::HashBuffer(MemoryBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer2)))); + + IoBuffer DiskBuffer1 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(267 * 3 * 1024)), TmpDir.Path() / "diskbuffer1"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer1)))); + + IoBuffer DiskBuffer2 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(3 * 1024)), TmpDir.Path() / "diskbuffer2"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer2)))); + + IoBuffer DiskBuffer3 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer3"); + CHECK_EQ(IoHash::HashBuffer(DiskBuffer3), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer3)))); + + CompositeBuffer Data(SharedBuffer(std::move(MemoryBuffer1)), + SharedBuffer(std::move(DiskBuffer1)), + SharedBuffer(std::move(DiskBuffer2)), + SharedBuffer(std::move(MemoryBuffer2)), + SharedBuffer(std::move(DiskBuffer3))); + CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h new file mode 100644 index 000000000..9060cde48 --- /dev/null +++ b/src/zenhttp/clients/httpclientcommon.h @@ -0,0 +1,147 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compositebuffer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpclient.h> + +namespace zen { + +using namespace std::literals; + +class HttpClientBase +{ +public: + HttpClientBase(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + virtual ~HttpClientBase() = 0; + + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) = 0; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}) = 0; + + LoggerRef Log() { return m_Log; } + std::string_view GetBaseUri() const { return m_BaseUri; } + std::string_view GetSessionId() const { return m_SessionId; } + + bool Authenticate(); + +protected: + LoggerRef m_Log; + std::string m_BaseUri; + std::string m_SessionId; + const HttpClientSettings m_ConnectionSettings; + + const std::optional<HttpClientAccessToken> GetAccessToken(); + + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +namespace detail { + + class TempPayloadFile + { + public: + TempPayloadFile(const TempPayloadFile&) = delete; + TempPayloadFile& operator=(const TempPayloadFile&) = delete; + + TempPayloadFile(); + ~TempPayloadFile(); + + std::error_code Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize); + std::error_code Write(std::string_view DataString); + IoBuffer DetachToIoBuffer(); + IoBuffer BorrowIoBuffer(); + inline uint64_t GetSize() const { return m_WriteOffset; } + void ResetWritePos(uint64_t WriteOffset); + + private: + std::error_code Flush(); + std::error_code AppendData(const void* Data, uint64_t Size); + + void* m_FileHandle; + std::uint64_t m_WriteOffset; + static constexpr uint64_t CacheBufferSize = 512u * 1024u; + uint8_t m_CacheBuffer[CacheBufferSize]; + std::uint64_t m_CacheBufferOffset = 0; + }; + + class BufferedReadFileStream + { + public: + BufferedReadFileStream(const BufferedReadFileStream&) = delete; + BufferedReadFileStream& operator=(const BufferedReadFileStream&) = delete; + + BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize); + ~BufferedReadFileStream(); + + void Read(void* Data, uint64_t Size); + + private: + void Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset); + + void* m_FileHandle = nullptr; + const uint64_t m_FileSize = 0; + const uint64_t m_FileEnd = 0; + uint64_t m_BufferSize = 0; + uint8_t* m_Buffer = nullptr; + uint64_t m_BufferStart = 0; + uint64_t m_BufferEnd = 0; + uint64_t m_FileOffset = 0; + }; + + class CompositeBufferReadStream + { + public: + CompositeBufferReadStream(const CompositeBufferReadStream&) = delete; + CompositeBufferReadStream& operator=(const CompositeBufferReadStream&) = delete; + + CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize); + uint64_t Read(void* Data, uint64_t Size); + + private: + const CompositeBuffer& m_Data; + const uint64_t m_BufferSize; + size_t m_SegmentIndex; + std::unique_ptr<BufferedReadFileStream> m_SegmentDiskBuffer; + MemoryView m_SegmentMemoryBuffer; + uint64_t m_BytesLeftInSegment; + }; + +} // namespace detail + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp new file mode 100644 index 000000000..568106887 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -0,0 +1,1035 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcpr.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/packageformat.h> + +namespace zen { + +HttpClientBase* +CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +{ + return new CprHttpClient(BaseUri, ConnectionSettings); +} + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +// If we want to support different HTTP client implementations then we'll need to make this more abstract + +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) + { + switch ((cpr::ErrorCode)m_Error) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + return ResponseClass::kHttpCantConnectError; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return ResponseClass::kHttpNoHost; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return ResponseClass::kHttpTimeout; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_CACERT_ERROR: + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// CPR helpers + +static cpr::Body +AsCprBody(const CbObject& Obj) +{ + return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); +} + +static cpr::Body +AsCprBody(const IoBuffer& Obj) +{ + return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); +} + +////////////////////////////////////////////////////////////////////////// + +static HttpClient::Response +ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) +{ + // This ends up doing a memcpy, would be good to get rid of it by streaming results + // into buffer directly + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); + + if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) + { + const HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; +} + +static HttpClient::Response +CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); + if (HttpResponse.error) + { + if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT && + HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED) + { + ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse); + } + + // Client side failure code + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed, + .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), + .ErrorMessage = HttpResponse.error.message}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), + .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), + .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), + .ElapsedSeconds = HttpResponse.elapsed}; + } + else + { + return ResponseWithPayload( + SessionId, + HttpResponse, + WorkResponseCode, + Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size())); + } +} + +static bool +ShouldRetry(const cpr::Response& Response) +{ + switch (Response.error.code) + { + case cpr::ErrorCode::OK: + break; + case cpr::ErrorCode::INTERNAL_ERROR: + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch ((HttpResponseCode)Response.status_code) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +}; + +static bool +ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + + if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second); + if (!ExpectedContentSize.has_value()) + { + Response.error = + cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second)); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Response.error = cpr::Error( + /*CURLE_READ_ERROR*/ 26, + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second)); + return false; + } + } + + if (Response.status_code == (long)HttpResponseCode::PartialContent) + { + return true; + } + + if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, + fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString())); + return false; + } + } + } + + if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end()) + { + if (ContentType->second == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize)) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation"); + return false; + } + } + if (ContentType->second == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error))); + return false; + } + } + } + + return true; +} + +static cpr::Response +DoWithRetry( + std::string_view SessionId, + std::function<cpr::Response()>&& Func, + uint8_t RetryCount, + std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static cpr::Response +DoWithRetry(std::string_view SessionId, + std::function<cpr::Response()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile, + uint8_t RetryCount) +{ + uint8_t Attempt = 0; + cpr::Response Result = Func(); + while (Attempt < RetryCount) + { + if (!ShouldRetry(Result)) + { + if (Result.error || !IsHttpSuccessCode(Result.status_code)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); + Result = Func(); + } + return Result; +} + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) +: HttpClientBase(BaseUri, Connectionsettings) +{ +} + +CprHttpClient::~CprHttpClient() +{ + ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto CprSession : m_Sessions) + { + delete CprSession; + } + m_Sessions.clear(); + }); +} + +////////////////////////////////////////////////////////////////////////// + +CprHttpClient::Session +CprHttpClient::AllocSession(const std::string_view BaseUrl, + const std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_TRACE_CPU("CprHttpClient::AllocSession"); + cpr::Session* CprSession = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + CprSession = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (CprSession == nullptr) + { + CprSession = new cpr::Session(); + CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); + CprSession->SetTimeout(ConnectionSettings.Timeout); + if (ConnectionSettings.AssumeHttp2) + { + CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); + } + } + + if (!AdditionalHeader->empty()) + { + CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); + } + if (!SessionId.empty()) + { + CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); + } + if (AccessToken) + { + CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); + } + if (!Parameters->empty()) + { + cpr::Parameters Tmp; + for (auto It = Parameters->begin(); It != Parameters->end(); It++) + { + Tmp.Add({It->first, It->second}); + } + CprSession->SetParameters(Tmp); + } + else + { + CprSession->SetParameters({}); + } + + ExtendableStringBuilder<128> UrlBuffer; + UrlBuffer << BaseUrl << ResourcePath; + CprSession->SetUrl(UrlBuffer.c_str()); + + return Session(this, CprSession); +} + +void +CprHttpClient::ReleaseSession(cpr::Session* CprSession) +{ + ZEN_TRACE_CPU("CprHttpClient::ReleaseSession"); + CprSession->SetUrl({}); + CprSession->SetHeader({}); + CprSession->SetBody({}); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); +} + +CprHttpClient::Response +CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::TransactPackage"); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}}); + Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Put"); + + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, + Url, + m_ConnectionSettings, + {{"Content-Length", "0"}}, + Parameters, + m_SessionId, + GetAccessToken()); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Get(); + }, + m_ConnectionSettings.RetryCount, + [](cpr::Response& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + })); +} + +CprHttpClient::Response +CprHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Head(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + return Sess.Delete(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CprHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + return Sess.Post(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CprHttpClient::Response +CprHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + uint64_t Offset = 0; + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + Buffer.Read(buffer, size); + Offset += size; + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + } + Sess->SetBody(AsCprBody(Payload)); + return Sess.Put(); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Sess->UpdateHeader({HeaderContentType(ContentType)}); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { + size = Reader.Read(buffer, size); + return true; + }; + return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + }, + m_ConnectionSettings.RetryCount)); +} + +CprHttpClient::Response +CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CprHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { + size_t DelimiterPos = header.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string Key = header.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string Value = header.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + + return std::make_pair(Key, Value); + } + return std::make_pair(header, ""); + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + size_t RangeStartPos = RangeIt->second.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = + ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength = RequestedRangeEnd.value() - 1; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > 1024 * 1024) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + std::pair<std::string, std::string> Header = GetHeader(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile, + m_ConnectionSettings.RetryCount); + + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h new file mode 100644 index 000000000..ed9d10c27 --- /dev/null +++ b/src/zenhttp/clients/httpclientcpr.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/cprutils.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/body.h> +#include <cpr/session.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CprHttpClient : public HttpClientBase +{ +public: + CprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings = {}); + ~CprHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct Session + { + Session(CprHttpClient* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {} + ~Session() { Outer->ReleaseSession(CprSession); } + + inline cpr::Session* operator->() const { return CprSession; } + inline cpr::Response Get() + { + ZEN_TRACE_CPU("HttpClient::Impl::Get"); + cpr::Response Result = CprSession->Get(); + ZEN_TRACE("GET {}", Result); + return Result; + } + inline cpr::Response Download(cpr::WriteCallback&& Write, std::optional<cpr::HeaderCallback>&& Header = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Download"); + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } + cpr::Response Result = CprSession->Download(Write); + ZEN_TRACE("GET {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); + return Result; + } + inline cpr::Response Head() + { + ZEN_TRACE_CPU("HttpClient::Impl::Head"); + cpr::Response Result = CprSession->Head(); + ZEN_TRACE("HEAD {}", Result); + return Result; + } + inline cpr::Response Put(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Put"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Put(); + ZEN_TRACE("PUT {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + { + ZEN_TRACE_CPU("HttpClient::Impl::Post"); + if (Read) + { + CprSession->SetReadCallback(std::move(Read.value())); + } + cpr::Response Result = CprSession->Post(); + ZEN_TRACE("POST {}", Result); + CprSession->SetReadCallback({}); + return Result; + } + inline cpr::Response Delete() + { + ZEN_TRACE_CPU("HttpClient::Impl::Delete"); + cpr::Response Result = CprSession->Delete(); + ZEN_TRACE("DELETE {}", Result); + return Result; + } + + LoggerRef Log() { return Outer->Log(); } + + private: + CprHttpClient* Outer; + cpr::Session* CprSession; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(const std::string_view BaseUrl, + const std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + const std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken); + + RwLock m_SessionLock; + std::vector<cpr::Session*> m_Sessions; + + void ReleaseSession(cpr::Session*); +}; + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 5981d449a..3da9f91fc 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -1,12 +1,10 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenhttp/cprutils.h> #include <zenhttp/formatters.h> #include <zenhttp/httpclient.h> #include <zenhttp/httpserver.h> #include <zenhttp/packageformat.h> -#include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compositebuffer.h> @@ -21,903 +19,37 @@ #include <zencore/string.h> #include <zencore/trace.h> +#include "clients/httpclientcommon.h" + #if ZEN_WITH_TESTS -# include <zencore/basicfile.h> # include <zencore/testing.h> # include <zencore/testutils.h> #endif // ZEN_WITH_TESTS -ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/body.h> -#include <cpr/session.h> -ZEN_THIRD_PARTY_INCLUDES_END - -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -# include <fcntl.h> -# include <sys/stat.h> -# include <unistd.h> -#endif - -static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; - namespace zen { -using namespace std::literals; - -namespace detail { - - static std::atomic_uint32_t TempFileBaseIndex; - - class TempPayloadFile - { - public: - TempPayloadFile(const TempPayloadFile&) = delete; - TempPayloadFile& operator=(const TempPayloadFile&) = delete; - - TempPayloadFile() : m_FileHandle(nullptr), m_WriteOffset(0) {} - ~TempPayloadFile() - { - ZEN_TRACE_CPU("TempPayloadFile::Close"); - try - { - if (m_FileHandle) - { -#if ZEN_PLATFORM_WINDOWS - // Mark file for deletion when final handle is closed - FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; - - SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); - BOOL Success = CloseHandle(m_FileHandle); -#else - std::error_code Ec; - std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle, Ec); - if (Ec) - { - ZEN_WARN("Error reported on get file path from handle {} for temp payload unlink operation, reason '{}'", - m_FileHandle, - Ec.message()); - } - else - { - unlink(FilePath.c_str()); - } - int Fd = int(uintptr_t(m_FileHandle)); - bool Success = (close(Fd) == 0); -#endif - if (!Success) - { - ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); - } - - m_FileHandle = nullptr; - } - } - catch (const std::exception& Ex) - { - ZEN_ERROR("Failed deleting temp file {}. Reason '{}'", m_FileHandle, Ex.what()); - } - } - - std::error_code Open(const std::filesystem::path& TempFolderPath, uint64_t FinalSize) - { - ZEN_TRACE_CPU("TempPayloadFile::Open"); - ZEN_ASSERT(m_FileHandle == nullptr); - - std::uint64_t TmpIndex = ((std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) & 0xffffffffu) << 32) | - detail::TempFileBaseIndex.fetch_add(1); - - std::filesystem::path FileName = TempFolderPath / fmt::to_string(TmpIndex); -#if ZEN_PLATFORM_WINDOWS - LPCWSTR lpFileName = FileName.c_str(); - const DWORD dwDesiredAccess = (GENERIC_READ | GENERIC_WRITE | DELETE); - const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; - LPSECURITY_ATTRIBUTES lpSecurityAttributes = nullptr; - const DWORD dwCreationDisposition = CREATE_ALWAYS; - const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; - const HANDLE hTemplateFile = nullptr; - const HANDLE FileHandle = CreateFile(lpFileName, - dwDesiredAccess, - dwShareMode, - lpSecurityAttributes, - dwCreationDisposition, - dwFlagsAndAttributes, - hTemplateFile); - - if (FileHandle == INVALID_HANDLE_VALUE) - { - return MakeErrorCodeFromLastError(); - } -#else // ZEN_PLATFORM_WINDOWS - int OpenFlags = O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC; - int Fd = open(FileName.c_str(), OpenFlags, 0666); - if (Fd < 0) - { - return MakeErrorCodeFromLastError(); - } - fchmod(Fd, 0666); - - void* FileHandle = (void*)(uintptr_t(Fd)); -#endif // ZEN_PLATFORM_WINDOWS - m_FileHandle = FileHandle; - - PrepareFileForScatteredWrite(m_FileHandle, FinalSize); - - return {}; - } - - std::error_code Write(std::string_view DataString) - { - ZEN_TRACE_CPU("TempPayloadFile::Write"); - const uint8_t* DataPtr = (const uint8_t*)DataString.data(); - size_t DataSize = DataString.size(); - if (DataSize >= CacheBufferSize) - { - std::error_code Ec = Flush(); - if (Ec) - { - return Ec; - } - return AppendData(DataPtr, DataSize); - } - size_t CopySize = Min(DataSize, CacheBufferSize - m_CacheBufferOffset); - memcpy(&m_CacheBuffer[m_CacheBufferOffset], DataPtr, CopySize); - m_CacheBufferOffset += CopySize; - DataSize -= CopySize; - if (m_CacheBufferOffset == CacheBufferSize) - { - AppendData(m_CacheBuffer, CacheBufferSize); - if (DataSize > 0) - { - ZEN_ASSERT(DataSize < CacheBufferSize); - memcpy(m_CacheBuffer, DataPtr + CopySize, DataSize); - } - m_CacheBufferOffset = DataSize; - } - else - { - ZEN_ASSERT(DataSize == 0); - } - return {}; - } - - IoBuffer DetachToIoBuffer() - { - ZEN_TRACE_CPU("TempPayloadFile::DetachToIoBuffer"); - if (std::error_code Ec = Flush(); Ec) - { - ThrowSystemError(Ec.value(), Ec.message()); - } - ZEN_ASSERT(m_FileHandle != nullptr); - void* FileHandle = m_FileHandle; - IoBuffer Buffer(IoBuffer::File, FileHandle, 0, m_WriteOffset, /*IsWholeFile*/ true); - Buffer.SetDeleteOnClose(true); - m_FileHandle = 0; - m_WriteOffset = 0; - return Buffer; - } - - IoBuffer BorrowIoBuffer() - { - ZEN_TRACE_CPU("TempPayloadFile::BorrowIoBuffer"); - if (std::error_code Ec = Flush(); Ec) - { - ThrowSystemError(Ec.value(), Ec.message()); - } - ZEN_ASSERT(m_FileHandle != nullptr); - void* FileHandle = m_FileHandle; - IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, 0, m_WriteOffset); - return Buffer; - } - - uint64_t GetSize() const { return m_WriteOffset; } - void ResetWritePos(uint64_t WriteOffset) - { - ZEN_TRACE_CPU("TempPayloadFile::ResetWritePos"); - Flush(); - m_WriteOffset = WriteOffset; - } - - private: - std::error_code Flush() - { - ZEN_TRACE_CPU("TempPayloadFile::Flush"); - if (m_CacheBufferOffset == 0) - { - return {}; - } - std::error_code Res = AppendData(m_CacheBuffer, m_CacheBufferOffset); - m_CacheBufferOffset = 0; - return Res; - } - - std::error_code AppendData(const void* Data, uint64_t Size) - { - ZEN_TRACE_CPU("TempPayloadFile::AppendData"); - ZEN_ASSERT(m_FileHandle != nullptr); - const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; - - while (Size) - { - const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); - uint64_t NumberOfBytesWritten = 0; -#if ZEN_PLATFORM_WINDOWS - OVERLAPPED Ovl{}; - - Ovl.Offset = DWORD(m_WriteOffset & 0xffff'ffffu); - Ovl.OffsetHigh = DWORD(m_WriteOffset >> 32); - - DWORD dwNumberOfBytesWritten = 0; - - BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); - if (Success) - { - NumberOfBytesWritten = static_cast<uint64_t>(dwNumberOfBytesWritten); - } -#else - static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); - int Fd = int(uintptr_t(m_FileHandle)); - int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, m_WriteOffset); - bool Success = (BytesWritten > 0); - if (Success) - { - NumberOfBytesWritten = static_cast<uint64_t>(BytesWritten); - } -#endif - - if (!Success) - { - return MakeErrorCodeFromLastError(); - } - - Size -= NumberOfBytesWritten; - m_WriteOffset += NumberOfBytesWritten; - Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesWritten; - } - return {}; - } - - void* m_FileHandle; - std::uint64_t m_WriteOffset; - static constexpr uint64_t CacheBufferSize = 512u * 1024u; - uint8_t m_CacheBuffer[CacheBufferSize]; - std::uint64_t m_CacheBufferOffset = 0; - }; - - class BufferedReadFileStream - { - public: - BufferedReadFileStream(const BufferedReadFileStream&) = delete; - BufferedReadFileStream& operator=(const BufferedReadFileStream&) = delete; - - BufferedReadFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize, uint64_t BufferSize) - : m_FileHandle(FileHandle) - , m_FileSize(FileSize) - , m_FileEnd(FileOffset + FileSize) - , m_BufferSize(Min(BufferSize, FileSize)) - , m_FileOffset(FileOffset) - { - } - - ~BufferedReadFileStream() { Memory::Free(m_Buffer); } - void Read(void* Data, uint64_t Size) - { - ZEN_ASSERT(Data != nullptr); - if (Size > m_BufferSize) - { - Read(Data, Size, m_FileOffset); - m_FileOffset += Size; - return; - } - uint8_t* WritePtr = ((uint8_t*)Data); - uint64_t Begin = m_FileOffset; - uint64_t End = m_FileOffset + Size; - ZEN_ASSERT(m_FileOffset >= m_BufferStart); - if (m_FileOffset < m_BufferEnd) - { - ZEN_ASSERT(m_Buffer != nullptr); - uint64_t Count = Min(m_BufferEnd, End) - m_FileOffset; - memcpy(WritePtr + Begin - m_FileOffset, m_Buffer + Begin - m_BufferStart, Count); - Begin += Count; - if (Begin == End) - { - m_FileOffset = End; - return; - } - } - if (End == m_FileEnd) - { - Read(WritePtr + Begin - m_FileOffset, End - Begin, Begin); - } - else - { - if (!m_Buffer) - { - m_BufferSize = Min(m_FileEnd - m_FileOffset, m_BufferSize); - m_Buffer = (uint8_t*)Memory::Alloc(gsl::narrow<size_t>(m_BufferSize)); - } - m_BufferStart = Begin; - m_BufferEnd = Min(Begin + m_BufferSize, m_FileEnd); - Read(m_Buffer, m_BufferEnd - m_BufferStart, m_BufferStart); - uint64_t Count = Min(m_BufferEnd, End) - m_BufferStart; - memcpy(WritePtr + Begin - m_FileOffset, m_Buffer, Count); - ZEN_ASSERT(Begin + Count == End); - } - m_FileOffset = End; - } - - private: - void Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) - { - const uint64_t MaxChunkSize = 1u * 1024 * 1024; - std::error_code Ec; - ReadFile(m_FileHandle, Data, BytesToRead, FileOffset, MaxChunkSize, Ec); - - if (Ec) - { - std::error_code DummyEc; - throw std::system_error( - Ec, - fmt::format( - "HttpClient::BufferedReadFileStream ReadFile/pread failed (offset {:#x}, size {:#x}) file: '{}' (size {:#x})", - FileOffset, - BytesToRead, - PathFromHandle(m_FileHandle, DummyEc).generic_string(), - m_FileSize)); - } - } - - void* m_FileHandle = nullptr; - const uint64_t m_FileSize = 0; - const uint64_t m_FileEnd = 0; - uint64_t m_BufferSize = 0; - uint8_t* m_Buffer = nullptr; - uint64_t m_BufferStart = 0; - uint64_t m_BufferEnd = 0; - uint64_t m_FileOffset = 0; - }; - - class CompositeBufferReadStream - { - public: - CompositeBufferReadStream(const CompositeBuffer& Data, uint64_t BufferSize) - : m_Data(Data) - , m_BufferSize(BufferSize) - , m_SegmentIndex(0) - , m_BytesLeftInSegment(0) - { - } - uint64_t Read(void* Data, uint64_t Size) - { - uint64_t Result = 0; - uint8_t* WritePtr = (uint8_t*)Data; - while ((Size > 0) && (m_SegmentIndex < m_Data.GetSegments().size())) - { - if (m_BytesLeftInSegment == 0) - { - const SharedBuffer& Segment = m_Data.GetSegments()[m_SegmentIndex]; - IoBufferFileReference FileRef = {nullptr, 0, 0}; - if (Segment.AsIoBuffer().GetFileReference(FileRef)) - { - m_SegmentDiskBuffer = std::make_unique<BufferedReadFileStream>(FileRef.FileHandle, - FileRef.FileChunkOffset, - FileRef.FileChunkSize, - m_BufferSize); - } - else - { - m_SegmentMemoryBuffer = Segment.GetView(); - } - m_BytesLeftInSegment = Segment.GetSize(); - } - uint64_t BytesToRead = Min(m_BytesLeftInSegment, Size); - if (m_SegmentDiskBuffer) - { - m_SegmentDiskBuffer->Read(WritePtr, BytesToRead); - } - else - { - ZEN_ASSERT_SLOW(m_SegmentMemoryBuffer.GetSize() >= BytesToRead); - memcpy(WritePtr, m_SegmentMemoryBuffer.GetData(), BytesToRead); - m_SegmentMemoryBuffer.MidInline(BytesToRead); - } - WritePtr += BytesToRead; - Size -= BytesToRead; - Result += BytesToRead; - - m_BytesLeftInSegment -= BytesToRead; - if (m_BytesLeftInSegment == 0) - { - m_SegmentDiskBuffer.reset(); - m_SegmentMemoryBuffer.Reset(); - m_SegmentIndex++; - } - } - return Result; - } - - private: - const CompositeBuffer& m_Data; - const uint64_t m_BufferSize; - size_t m_SegmentIndex; - std::unique_ptr<BufferedReadFileStream> m_SegmentDiskBuffer; - MemoryView m_SegmentMemoryBuffer; - uint64_t m_BytesLeftInSegment; - }; - -} // namespace detail - -////////////////////////////////////////////////////////////////////////// -// -// CPR helpers - -static cpr::Body -AsCprBody(const CbObject& Obj) -{ - return cpr::Body((const char*)Obj.GetBuffer().GetData(), Obj.GetBuffer().GetSize()); -} - -static cpr::Body -AsCprBody(const IoBuffer& Obj) -{ - return cpr::Body((const char*)Obj.GetData(), Obj.GetSize()); -} - -////////////////////////////////////////////////////////////////////////// - -static HttpClient::Response -ResponseWithPayload(std::string_view SessionId, cpr::Response& HttpResponse, const HttpResponseCode WorkResponseCode, IoBuffer&& Payload) -{ - // This ends up doing a memcpy, would be good to get rid of it by streaming results - // into buffer directly - IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, HttpResponse.text.data(), HttpResponse.text.size()); - - if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) - { - const HttpContentType ContentType = ParseContentType(It->second); - - ResponseBuffer.SetContentType(ContentType); - } - - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) - { - ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); - } - - return HttpClient::Response{.StatusCode = WorkResponseCode, - .ResponsePayload = std::move(ResponseBuffer), - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; -} - -static HttpClient::Response -CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload = {}) -{ - const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); - if (HttpResponse.error) - { - if (HttpResponse.error.code != cpr::ErrorCode::OPERATION_TIMEDOUT && - HttpResponse.error.code != cpr::ErrorCode::CONNECTION_FAILURE && HttpResponse.error.code != cpr::ErrorCode::REQUEST_CANCELLED) - { - ZEN_WARN("HttpClient client failure (session: {}): {}", SessionId, HttpResponse); - } - - // Client side failure code - return HttpClient::Response{ - .StatusCode = WorkResponseCode, - .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size()), - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; - } - - if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) - { - return HttpClient::Response{.StatusCode = WorkResponseCode, - .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), - .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), - .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; - } - else - { - return ResponseWithPayload( - SessionId, - HttpResponse, - WorkResponseCode, - Payload ? std::move(Payload) : IoBufferBuilder::MakeCloneFromMemory(HttpResponse.text.data(), HttpResponse.text.size())); - } -} - -static bool -ShouldRetry(const cpr::Response& Response) -{ - switch (Response.error.code) - { - case cpr::ErrorCode::OK: - break; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return true; - default: - return false; - } - switch ((HttpResponseCode)Response.status_code) - { - case HttpResponseCode::RequestTimeout: - case HttpResponseCode::TooManyRequests: - case HttpResponseCode::InternalServerError: - case HttpResponseCode::BadGateway: - case HttpResponseCode::ServiceUnavailable: - case HttpResponseCode::GatewayTimeout: - return true; - default: - return false; - } -}; - -static bool -ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) -{ - ZEN_TRACE_CPU("ValidatePayload"); - IoBuffer ResponseBuffer = (Response.text.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() - : IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()); - - if (auto ContentLength = Response.header.find("Content-Length"); ContentLength != Response.header.end()) - { - std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLength->second); - if (!ExpectedContentSize.has_value()) - { - Response.error = - cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLength->second)); - return false; - } - if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) - { - Response.error = cpr::Error( - /*CURLE_READ_ERROR*/ 26, - fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLength->second)); - return false; - } - } - - if (Response.status_code == (long)HttpResponseCode::PartialContent) - { - return true; - } - - if (auto JupiterHash = Response.header.find("X-Jupiter-IoHash"); JupiterHash != Response.header.end()) - { - IoHash ExpectedPayloadHash; - if (IoHash::TryParse(JupiterHash->second, ExpectedPayloadHash)) - { - IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); - if (PayloadHash != ExpectedPayloadHash) - { - Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, - fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", - PayloadHash.ToHexString(), - ExpectedPayloadHash.ToHexString())); - return false; - } - } - } - - if (auto ContentType = Response.header.find("Content-Type"); ContentType != Response.header.end()) - { - if (ContentType->second == "application/x-ue-comp") - { - IoHash RawHash; - uint64_t RawSize; - if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, RawHash, RawSize)) - { - return true; - } - else - { - Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, "Compressed binary failed validation"); - return false; - } - } - if (ContentType->second == "application/x-ue-cb") - { - if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); - Error == CbValidateError::None) - { - return true; - } - else - { - Response.error = cpr::Error(/*CURLE_READ_ERROR*/ 26, fmt::format("Compact binary failed validation: {}", ToString(Error))); - return false; - } - } - } - - return true; -} - -static cpr::Response -DoWithRetry( - std::string_view SessionId, - std::function<cpr::Response()>&& Func, - uint8_t RetryCount, - std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }) -{ - uint8_t Attempt = 0; - cpr::Response Result = Func(); - while (Attempt < RetryCount) - { - if (!ShouldRetry(Result)) - { - if (Result.error || !IsHttpSuccessCode(Result.status_code)) - { - break; - } - if (Validate(Result)) - { - break; - } - } - Sleep(100 * (Attempt + 1)); - Attempt++; - ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); - Result = Func(); - } - return Result; -} - -static cpr::Response -DoWithRetry(std::string_view SessionId, - std::function<cpr::Response()>&& Func, - std::unique_ptr<detail::TempPayloadFile>& PayloadFile, - uint8_t RetryCount) -{ - uint8_t Attempt = 0; - cpr::Response Result = Func(); - while (Attempt < RetryCount) - { - if (!ShouldRetry(Result)) - { - if (Result.error || !IsHttpSuccessCode(Result.status_code)) - { - break; - } - if (ValidatePayload(Result, PayloadFile)) - { - break; - } - } - Sleep(100 * (Attempt + 1)); - Attempt++; - ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result)).ErrorMessage("Retry"), Attempt, RetryCount + 1); - Result = Func(); - } - return Result; -} - -static std::pair<std::string, std::string> -HeaderContentType(ZenContentType ContentType) -{ - return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); -} - -////////////////////////////////////////////////////////////////////////// - -struct HttpClient::Impl : public RefCounted -{ - Impl(LoggerRef Log); - ~Impl(); - - // Session allocation - - struct Session - { - Session(Impl* InOuter, cpr::Session* InSession) : Outer(InOuter), CprSession(InSession) {} - ~Session() { Outer->ReleaseSession(CprSession); } - - inline cpr::Session* operator->() const { return CprSession; } - inline cpr::Response Get() - { - ZEN_TRACE_CPU("HttpClient::Impl::Get"); - cpr::Response Result = CprSession->Get(); - ZEN_TRACE("GET {}", Result); - return Result; - } - inline cpr::Response Download(cpr::WriteCallback&& Write, std::optional<cpr::HeaderCallback>&& Header = {}) - { - ZEN_TRACE_CPU("HttpClient::Impl::Download"); - if (Header) - { - CprSession->SetHeaderCallback(std::move(Header.value())); - } - cpr::Response Result = CprSession->Download(Write); - ZEN_TRACE("GET {}", Result); - CprSession->SetHeaderCallback({}); - CprSession->SetWriteCallback({}); - return Result; - } - inline cpr::Response Head() - { - ZEN_TRACE_CPU("HttpClient::Impl::Head"); - cpr::Response Result = CprSession->Head(); - ZEN_TRACE("HEAD {}", Result); - return Result; - } - inline cpr::Response Put(std::optional<cpr::ReadCallback>&& Read = {}) - { - ZEN_TRACE_CPU("HttpClient::Impl::Put"); - if (Read) - { - CprSession->SetReadCallback(std::move(Read.value())); - } - cpr::Response Result = CprSession->Put(); - ZEN_TRACE("PUT {}", Result); - CprSession->SetReadCallback({}); - return Result; - } - inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) - { - ZEN_TRACE_CPU("HttpClient::Impl::Post"); - if (Read) - { - CprSession->SetReadCallback(std::move(Read.value())); - } - cpr::Response Result = CprSession->Post(); - ZEN_TRACE("POST {}", Result); - CprSession->SetReadCallback({}); - return Result; - } - inline cpr::Response Delete() - { - ZEN_TRACE_CPU("HttpClient::Impl::Delete"); - cpr::Response Result = CprSession->Delete(); - ZEN_TRACE("DELETE {}", Result); - return Result; - } - - LoggerRef Logger() { return Outer->Logger(); } - - private: - Impl* Outer; - cpr::Session* CprSession; - - Session(Session&&) = delete; - Session& operator=(Session&&) = delete; - }; - - Session AllocSession(const std::string_view BaseUrl, - const std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); - - LoggerRef Logger() { return m_Log; } - -private: - LoggerRef m_Log; - RwLock m_SessionLock; - std::vector<cpr::Session*> m_Sessions; - - void ReleaseSession(cpr::Session*); -}; - -HttpClient::Impl::Impl(LoggerRef Log) : m_Log(Log) -{ -} +extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings); -HttpClient::Impl::~Impl() -{ - ZEN_TRACE_CPU("HttpClient::Impl::~Impl"); - m_SessionLock.WithExclusiveLock([&] { - for (auto CprSession : m_Sessions) - { - delete CprSession; - } - m_Sessions.clear(); - }); -} - -HttpClient::Impl::Session -HttpClient::Impl::AllocSession(const std::string_view BaseUrl, - const std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - const std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) -{ - ZEN_TRACE_CPU("HttpClient::Impl::AllocSession"); - cpr::Session* CprSession = nullptr; - m_SessionLock.WithExclusiveLock([&] { - if (!m_Sessions.empty()) - { - CprSession = m_Sessions.back(); - m_Sessions.pop_back(); - } - }); - - if (CprSession == nullptr) - { - CprSession = new cpr::Session(); - CprSession->SetConnectTimeout(ConnectionSettings.ConnectTimeout); - CprSession->SetTimeout(ConnectionSettings.Timeout); - if (ConnectionSettings.AssumeHttp2) - { - CprSession->SetHttpVersion(cpr::HttpVersion{cpr::HttpVersionCode::VERSION_2_0_PRIOR_KNOWLEDGE}); - } - } - - if (!AdditionalHeader->empty()) - { - CprSession->SetHeader(cpr::Header(AdditionalHeader->begin(), AdditionalHeader->end())); - } - if (!SessionId.empty()) - { - CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); - } - if (AccessToken) - { - CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); - } - if (!Parameters->empty()) - { - cpr::Parameters Tmp; - for (auto It = Parameters->begin(); It != Parameters->end(); It++) - { - Tmp.Add({It->first, It->second}); - } - CprSession->SetParameters(Tmp); - } - else - { - CprSession->SetParameters({}); - } - - ExtendableStringBuilder<128> UrlBuffer; - UrlBuffer << BaseUrl << ResourcePath; - CprSession->SetUrl(UrlBuffer.c_str()); - - return Session(this, CprSession); -} - -void -HttpClient::Impl::ReleaseSession(cpr::Session* CprSession) -{ - ZEN_TRACE_CPU("HttpClient::Impl::ReleaseSession"); - CprSession->SetUrl({}); - CprSession->SetHeader({}); - CprSession->SetBody({}); - m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(CprSession); }); -} +using namespace std::literals; ////////////////////////////////////////////////////////////////////////// -HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& Connectionsettings) -: m_Log(zen::logging::Get(Connectionsettings.LogCategory)) +HttpClientBase::HttpClientBase(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +: m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) , m_BaseUri(BaseUri) -, m_ConnectionSettings(Connectionsettings) -, m_Impl(new Impl(m_Log)) +, m_ConnectionSettings(ConnectionSettings) { m_SessionId = GetSessionIdString(); } -HttpClient::~HttpClient() +HttpClientBase::~HttpClientBase() { } bool -HttpClient::Authenticate() +HttpClientBase::Authenticate() { - ZEN_TRACE_CPU("HttpClient::Authenticate"); + ZEN_TRACE_CPU("HttpClientBase::Authenticate"); std::optional<HttpClientAccessToken> Token = GetAccessToken(); if (!Token) { @@ -927,9 +59,9 @@ HttpClient::Authenticate() } const std::optional<HttpClientAccessToken> -HttpClient::GetAccessToken() +HttpClientBase::GetAccessToken() { - ZEN_TRACE_CPU("HttpClient::GetAccessToken"); + ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); if (!m_ConnectionSettings.AccessTokenProvider.has_value()) { return {}; @@ -950,607 +82,6 @@ HttpClient::GetAccessToken() return m_CachedAccessToken; } -HttpClient::Response -HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::TransactPackage"); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - - // First, list of offered chunks for filtering on the server end - - std::vector<IoHash> AttachmentsToSend; - std::span<const CbAttachment> Attachments = Package.GetAttachments(); - - const uint32_t RequestId = ++HttpClientRequestIdCounter; - auto RequestIdString = fmt::to_string(RequestId); - - if (Attachments.empty() == false) - { - CbObjectWriter Writer; - Writer.BeginArray("offer"); - - for (const CbAttachment& Attachment : Attachments) - { - Writer.AddHash(Attachment.GetHash()); - } - - Writer.EndArray(); - - BinaryWriter MemWriter; - Writer.Save(MemWriter); - - Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackageOffer), {"UE-Request", RequestIdString}}); - Sess->SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (FilterResponse.status_code == 200) - { - IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); - CbValidateError ValidationError = CbValidateError::None; - if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); - ValidationError == CbValidateError::None) - { - for (CbFieldView& Entry : ResponseObject["need"]) - { - ZEN_ASSERT(Entry.IsHash()); - AttachmentsToSend.push_back(Entry.AsHash()); - } - } - } - } - - // Prepare package for send - - CbPackage SendPackage; - SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); - - for (const IoHash& AttachmentCid : AttachmentsToSend) - { - const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); - - if (Attachment) - { - SendPackage.AddAttachment(*Attachment); - } - else - { - // This should be an error -- server asked to have something we can't find - } - } - - // Transmit package payload - - CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); - SharedBuffer FlatMessage = Message.Flatten(); - - Sess->UpdateHeader({HeaderContentType(HttpContentType::kCbPackage), {"UE-Request", RequestIdString}}); - Sess->SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (!IsHttpSuccessCode(FilterResponse.status_code)) - { - return {.StatusCode = HttpResponseCode(FilterResponse.status_code)}; - } - - IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); - - if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) - { - HttpContentType ContentType = ParseContentType(It->second); - - ResponseBuffer.SetContentType(ContentType); - } - - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; -} - -////////////////////////////////////////////////////////////////////////// -// -// Standard HTTP verbs -// - -HttpClient::Response -HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Put"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); - return Sess.Put(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) -{ - ZEN_TRACE_CPU("HttpClient::Put"); - - return CommonResponse(m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, - Url, - m_ConnectionSettings, - {{"Content-Length", "0"}}, - Parameters, - m_SessionId, - GetAccessToken()); - return Sess.Put(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) -{ - ZEN_TRACE_CPU("HttpClient::Get"); - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - return Sess.Get(); - }, - m_ConnectionSettings.RetryCount, - [](cpr::Response& Result) { - std::unique_ptr<detail::TempPayloadFile> NoTempFile; - return ValidatePayload(Result, NoTempFile); - })); -} - -HttpClient::Response -HttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Head"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - return Sess.Head(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Delete"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - return Sess.Delete(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) -{ - ZEN_TRACE_CPU("HttpClient::PostNoPayload"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - return Sess.Post(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) -{ - return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::PostWithPayload"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Sess->UpdateHeader({HeaderContentType(ContentType)}); - - IoBufferFileReference FileRef = {nullptr, 0, 0}; - if (Payload.GetFileReference(FileRef)) - { - uint64_t Offset = 0; - detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); - auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { - size = Min<size_t>(size, Payload.GetSize() - Offset); - Buffer.Read(buffer, size); - Offset += size; - return true; - }; - return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - } - Sess->SetBody(AsCprBody(Payload)); - return Sess.Post(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::PostObjectPayload"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); - return Sess.Post(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) -{ - return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); -} - -HttpClient::Response -HttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Post"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Sess->UpdateHeader({HeaderContentType(ContentType)}); - - detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); - auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { - size = Reader.Read(buffer, size); - return true; - }; - return Sess.Post(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Upload"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Sess->UpdateHeader({HeaderContentType(Payload.GetContentType())}); - - IoBufferFileReference FileRef = {nullptr, 0, 0}; - if (Payload.GetFileReference(FileRef)) - { - uint64_t Offset = 0; - detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); - auto ReadCallback = [&Payload, &Offset, &Buffer](char* buffer, size_t& size, intptr_t) { - size = Min<size_t>(size, Payload.GetSize() - Offset); - Buffer.Read(buffer, size); - Offset += size; - return true; - }; - return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - } - Sess->SetBody(AsCprBody(Payload)); - return Sess.Put(); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Upload"); - - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() { - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Sess->UpdateHeader({HeaderContentType(ContentType)}); - - detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); - auto ReadCallback = [&Reader](char* buffer, size_t& size, intptr_t) { - size = Reader.Read(buffer, size); - return true; - }; - return Sess.Put(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); - }, - m_ConnectionSettings.RetryCount)); -} - -HttpClient::Response -HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) -{ - ZEN_TRACE_CPU("HttpClient::Download"); - - std::string PayloadString; - std::unique_ptr<detail::TempPayloadFile> PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional<size_t> RequestedRangeStart = - ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > 1024 * 1024) - { - PayloadFile = std::make_unique<detail::TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Impl::Session Sess = - m_Impl->AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); - if (Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Impl::Session Sess = m_Impl->AllocSession(m_BaseUri, - Url, - m_ConnectionSettings, - HeadersWithRange, - {}, - m_SessionId, - GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile, - m_ConnectionSettings.RetryCount); - - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); -} - ////////////////////////////////////////////////////////////////////////// CbObject @@ -1662,107 +193,125 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix) ////////////////////////////////////////////////////////////////////////// -HttpClientError::ResponseClass -HttpClientError::GetResponseClass() const +HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings) +: m_BaseUri(BaseUri) +, m_ConnectionSettings(ConnectionSettings) { - if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) - { - switch ((cpr::ErrorCode)m_Error) - { - case cpr::ErrorCode::CONNECTION_FAILURE: - return ResponseClass::kHttpCantConnectError; - case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: - case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return ResponseClass::kHttpNoHost; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return ResponseClass::kHttpTimeout; - case cpr::ErrorCode::SSL_CONNECT_ERROR: - case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_CACERT_ERROR: - case cpr::ErrorCode::GENERIC_SSL_ERROR: - return ResponseClass::kHttpSLLError; - default: - return ResponseClass::kHttpOtherClientError; - } - } - else if (IsHttpSuccessCode(m_ResponseCode)) - { - return ResponseClass::kSuccess; - } - else - { - switch (m_ResponseCode) - { - case HttpResponseCode::Unauthorized: - return ResponseClass::kHttpUnauthorized; - case HttpResponseCode::NotFound: - return ResponseClass::kHttpNotFound; - case HttpResponseCode::Forbidden: - return ResponseClass::kHttpForbidden; - case HttpResponseCode::Conflict: - return ResponseClass::kHttpConflict; - case HttpResponseCode::InternalServerError: - return ResponseClass::kHttpInternalServerError; - case HttpResponseCode::ServiceUnavailable: - return ResponseClass::kHttpServiceUnavailable; - case HttpResponseCode::BadGateway: - return ResponseClass::kHttpBadGateway; - case HttpResponseCode::GatewayTimeout: - return ResponseClass::kHttpGatewayTimeout; - default: - if (m_ResponseCode >= HttpResponseCode::InternalServerError) - { - return ResponseClass::kHttpOtherServerError; - } - else - { - return ResponseClass::kHttpOtherClientError; - } - } - } + m_SessionId = GetSessionIdString(); + + m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings); } -////////////////////////////////////////////////////////////////////////// +HttpClient::~HttpClient() +{ + delete m_Inner; +} -#if ZEN_WITH_TESTS +HttpClient::Response +HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Put(Url, Payload, AdditionalHeader); +} -namespace testutil { - IoHash HashComposite(const CompositeBuffer& Payload) - { - IoHashStream Hasher; - const uint64_t PayloadSize = Payload.GetSize(); - std::vector<uint8_t> Buffer(64u * 1024u); - detail::CompositeBufferReadStream Stream(Payload, 137u * 1024u); - for (uint64_t Offset = 0; Offset < PayloadSize;) - { - uint64_t Count = Min(64u * 1024u, PayloadSize - Offset); - Stream.Read(Buffer.data(), Count); - Hasher.Append(Buffer.data(), Count); - Offset += Count; - } - return Hasher.GetHash(); - }; +HttpClient::Response +HttpClient::Put(std::string_view Url, const HttpClient::KeyValueMap& Parameters) +{ + return m_Inner->Put(Url, Parameters); +} - IoHash HashFileStream(void* FileHandle, uint64_t FileOffset, uint64_t FileSize) - { - IoHashStream Hasher; - std::vector<uint8_t> Buffer(64u * 1024u); - detail::BufferedReadFileStream Stream(FileHandle, FileOffset, FileSize, 137u * 1024u); - for (uint64_t Offset = 0; Offset < FileSize;) - { - uint64_t Count = Min(64u * 1024u, FileSize - Offset); - Stream.Read(Buffer.data(), Count); - Hasher.Append(Buffer.data(), Count); - Offset += Count; - } - return Hasher.GetHash(); - } +HttpClient::Response +HttpClient::Get(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) +{ + return m_Inner->Get(Url, AdditionalHeader, Parameters); +} + +HttpClient::Response +HttpClient::Head(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Head(Url, AdditionalHeader); +} -} // namespace testutil +HttpClient::Response +HttpClient::Delete(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Delete(Url, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) +{ + return m_Inner->Post(Url, AdditionalHeader, Parameters); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Post(Url, Payload, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Post(Url, Payload, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, CbPackage Payload, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Post(Url, Payload, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Upload(Url, Payload, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Upload(Url, Payload, ContentType, AdditionalHeader); +} + +HttpClient::Response +HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->Download(Url, TempFolderPath, AdditionalHeader); +} + +HttpClient::Response +HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const HttpClient::KeyValueMap& AdditionalHeader) +{ + return m_Inner->TransactPackage(Url, Package, AdditionalHeader); +} + +bool +HttpClient::Authenticate() +{ + return m_Inner->Authenticate(); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS TEST_CASE("responseformat") { @@ -1810,53 +359,6 @@ TEST_CASE("responseformat") } } -TEST_CASE("BufferedReadFileStream") -{ - ScopedTemporaryDirectory TmpDir; - - IoBuffer DiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer1"); - - IoBufferFileReference FileRef = {nullptr, 0, 0}; - CHECK(DiskBuffer.GetFileReference(FileRef)); - CHECK_EQ(IoHash::HashBuffer(DiskBuffer), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); - - IoBuffer Partial(DiskBuffer, 37 * 1024, 512 * 1024); - CHECK(Partial.GetFileReference(FileRef)); - CHECK_EQ(IoHash::HashBuffer(Partial), testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); - - IoBuffer SmallDiskBuffer = WriteToTempFile(CompositeBuffer(CreateRandomBlob(63 * 1024)), TmpDir.Path() / "diskbuffer2"); - CHECK(SmallDiskBuffer.GetFileReference(FileRef)); - CHECK_EQ(IoHash::HashBuffer(SmallDiskBuffer), - testutil::HashFileStream(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize)); -} - -TEST_CASE("CompositeBufferReadStream") -{ - ScopedTemporaryDirectory TmpDir; - - IoBuffer MemoryBuffer1 = CreateRandomBlob(64); - CHECK_EQ(IoHash::HashBuffer(MemoryBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer1)))); - - IoBuffer MemoryBuffer2 = CreateRandomBlob(561 * 1024); - CHECK_EQ(IoHash::HashBuffer(MemoryBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(MemoryBuffer2)))); - - IoBuffer DiskBuffer1 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(267 * 3 * 1024)), TmpDir.Path() / "diskbuffer1"); - CHECK_EQ(IoHash::HashBuffer(DiskBuffer1), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer1)))); - - IoBuffer DiskBuffer2 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(3 * 1024)), TmpDir.Path() / "diskbuffer2"); - CHECK_EQ(IoHash::HashBuffer(DiskBuffer2), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer2)))); - - IoBuffer DiskBuffer3 = WriteToTempFile(CompositeBuffer(CreateRandomBlob(496 * 5 * 1024)), TmpDir.Path() / "diskbuffer3"); - CHECK_EQ(IoHash::HashBuffer(DiskBuffer3), testutil::HashComposite(CompositeBuffer(SharedBuffer(DiskBuffer3)))); - - CompositeBuffer Data(SharedBuffer(std::move(MemoryBuffer1)), - SharedBuffer(std::move(DiskBuffer1)), - SharedBuffer(std::move(DiskBuffer2)), - SharedBuffer(std::move(MemoryBuffer2)), - SharedBuffer(std::move(DiskBuffer3))); - CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); -} - TEST_CASE("httpclient") { using namespace std::literals; diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index ec06aa229..aae7b94e5 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -108,6 +108,8 @@ private: const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; }; +class HttpClientBase; + class HttpClient { public: @@ -191,6 +193,11 @@ public: std::string ErrorMessage(std::string_view Prefix) const; }; + static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType) + { + return std::make_pair("Accept", MapContentTypeToString(ContentType)); + } + [[nodiscard]] Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Put(std::string_view Url, const KeyValueMap& Parameters = {}); [[nodiscard]] Response Get(std::string_view Url, const KeyValueMap& AdditionalHeader = {}, const KeyValueMap& Parameters = {}); @@ -220,27 +227,19 @@ public: [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader = {}); - static std::pair<std::string_view, std::string_view> Accept(ZenContentType ContentType) - { - return std::make_pair("Accept", MapContentTypeToString(ContentType)); - } - - LoggerRef Logger() { return m_Log; } + LoggerRef Log() { return m_Log; } std::string_view GetBaseUri() const { return m_BaseUri; } - bool Authenticate(); std::string_view GetSessionId() const { return m_SessionId; } + bool Authenticate(); + private: - const std::optional<HttpClientAccessToken> GetAccessToken(); - struct Impl; + HttpClientBase* m_Inner; LoggerRef m_Log; std::string m_BaseUri; std::string m_SessionId; const HttpClientSettings m_ConnectionSettings; - RwLock m_AccessTokenLock; - HttpClientAccessToken m_CachedAccessToken; - Ref<Impl> m_Impl; }; void httpclient_forcelink(); // internal |