diff options
Diffstat (limited to 'src/zenhttp/clients/httpclientcurl.cpp')
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 1947 |
1 files changed, 1947 insertions, 0 deletions
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp new file mode 100644 index 000000000..3cb749018 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -0,0 +1,1947 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcurl.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zenhttp/packageformat.h> +#include <algorithm> + +namespace zen { + +HttpClientBase* +CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) +{ + return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); +} + +static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; + +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback helpers + +struct WriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<WriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct HeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + // Trim trailing \r\n + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + // Trim whitespace + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct ReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<ReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +struct StreamReadCallbackData +{ + detail::CompositeBufferReadStream* Reader = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<StreamReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + return Data->Reader->Read(Buffer, MaxRead); +} + +struct FileReadCallbackData +{ + detail::BufferedReadFileStream* Buffer = nullptr; + uint64_t TotalSize = 0; + uint64_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<FileReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->TotalSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + Data->Buffer->Read(Buffer, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +static int +CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr) +{ + ZEN_UNUSED(Handle); + LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); + auto Log = [&]() -> LoggerRef { return LogRef; }; + + std::string_view DataView(Data, Size); + + // Trim trailing newlines + while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n')) + { + DataView.remove_suffix(1); + } + + switch (Type) + { + case CURLINFO_TEXT: + if (DataView.find("need more data"sv) == std::string_view::npos) + { + ZEN_INFO("TEXT: {}", DataView); + } + break; + case CURLINFO_HEADER_IN: + ZEN_INFO("HIN : {}", DataView); + break; + case CURLINFO_HEADER_OUT: + if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos) + { + std::string Copy(DataView); + auto BearerStart = TokenPos + 22; + auto BearerEnd = Copy.find_first_of("\r\n", BearerStart); + if (BearerEnd == std::string::npos) + { + BearerEnd = Copy.length(); + } + Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart)); + ZEN_INFO("HOUT: {}", Copy); + } + else + { + ZEN_INFO("HOUT: {}", DataView); + } + break; + default: + break; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +static curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<HttpClientAccessToken>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + std::string SessionHeader = fmt::format("UE-Session: {}", SessionId); + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken) + { + std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +static std::string +BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) +{ + std::string Url; + Url.reserve(BaseUrl.size() + ResourcePath.size() + 64); + Url.append(BaseUrl); + Url.append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size())); + char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size())); + Url += Separator; + Url += EncodedKey; + Url += '='; + Url += EncodedValue; + curl_free(EncodedKey); + curl_free(EncodedValue); + Separator = '&'; + } + } + + return Url; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::CurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction) +: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)) +{ +} + +CurlHttpClient::~CurlHttpClient() +{ + ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto* Handle : m_Sessions) + { + curl_easy_cleanup(Handle); + } + m_Sessions.clear(); + }); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::Perform() +{ + CurlResult Result; + + char ErrorBuffer[CURL_ERROR_SIZE] = {}; + curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer); + + Result.ErrorCode = curl_easy_perform(Handle); + + if (Result.ErrorCode != CURLE_OK) + { + Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode); + } + + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + Result.ElapsedSeconds = Elapsed; + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + Result.UploadedBytes = static_cast<int64_t>(UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + + return Result; +} + +bool +CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + +HttpClient::Response +CurlHttpClient::ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); + + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + const HttpContentType ContentType = ParseContentType(Value); + ResponseBuffer.SetContentType(ContentType); + break; + } + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + if (ShouldLogErrorCode(WorkResponseCode)) + { + ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}", + SessionId, + static_cast<int>(WorkResponseCode), + m_BaseUri); + } + } + + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Ranges = std::move(BoundaryPositions)}; +} + +HttpClient::Response +CurlHttpClient::CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode); + if (Result.ErrorCode != CURLE_OK) + { + const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); + if (!Quiet) + { + if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && + Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + SessionId, + static_cast<int>(Result.ErrorCode), + Result.ErrorMessage); + } + } + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) + { + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + else + { + return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); + } +} + +bool +CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + + IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + + // Find Content-Length in headers + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Length") + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value); + return false; + } + break; + } + } + + if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent)) + { + return true; + } + + // Check X-Jupiter-IoHash + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "X-Jupiter-IoHash") + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(Value, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; + } + } + break; + } + } + + // Validate content-type specific payload + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + if (Value == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (Value == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; + } + } + break; + } + } + + return true; +} + +bool +CurlHttpClient::ShouldRetry(const CurlResult& Result) +{ + switch (Result.ErrorCode) + { + case CURLE_OK: + break; + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(Result.StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Session +CurlHttpClient::AllocSession(std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader); + ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); + CURL* Handle = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + Handle = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (Handle == nullptr) + { + Handle = curl_easy_init(); + } + else + { + curl_easy_reset(Handle); + } + + // Unix domain socket + if (!ConnectionSettings.UnixSocketPath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str()); + } + + // Build URL with parameters + std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Timeouts + if (ConnectionSettings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count())); + } + if (ConnectionSettings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count())); + } + + // HTTP/2 + if (ConnectionSettings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // Verbose/debug + if (ConnectionSettings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); + curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log); + } + + // SSL options + if (ConnectionSettings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str()); + } + + // Disable signal handling for thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (ConnectionSettings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + + // Note: Headers are NOT set here. Each method builds its own header list + // (potentially adding method-specific headers like Content-Type) and is + // responsible for freeing it with curl_slist_free_all. + + return Session(this, Handle); +} + +void +CurlHttpClient::ReleaseSession(CURL* Handle) +{ + ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); + + // Free any header list that was set + // curl_easy_reset will be called on next AllocSession, which cleans up the handle state. + // We just push the handle back to the pool. + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Response +CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::TransactPackage"); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + std::vector<std::pair<std::string, std::string>> OfferExtraHeaders; + OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); + OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); + + std::string FilterBody; + WriteCallbackData WriteData{.Body = &FilterBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + std::vector<std::pair<std::string, std::string>> PkgExtraHeaders; + PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); + PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); + + std::string PkgBody; + WriteCallbackData WriteData{.Body = &PkgBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + return {.StatusCode = HttpResponseCode(Result.StatusCode)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); + + return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_NOBODY, 1L); + + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE"); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + // Rebuild headers with content type + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + // Track requested content length from Range header (sum all ranges) + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + // Header callback that detects Content-Length and switches to file-backed storage when needed + struct DownloadHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + HttpContentType* ContentTypeOut = nullptr; + }; + + DownloadHeaderCallbackData DlHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + DlHdrData.Headers = &ResponseHeaders; + DlHdrData.PayloadFile = &PayloadFile; + DlHdrData.PayloadString = &PayloadString; + DlHdrData.TempFolderPath = &TempFolderPath; + DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + DlHdrData.Log = m_Log; + DlHdrData.BoundaryParser = &BoundaryParser; + DlHdrData.IsMultiRange = &IsMultiRangeResponse; + DlHdrData.ContentTypeOut = &ContentType; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + else if (Key == "Content-Type"sv) + { + *Data->IsMultiRange = Data->BoundaryParser->Init(Value); + if (!*Data->IsMultiRange) + { + *Data->ContentTypeOut = ParseContentType(Value); + } + } + else if (Key == "Content-Range"sv) + { + if (!*Data->IsMultiRange) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value); + if (Range.second != 0) + { + Data->BoundaryParser->Boundaries.push_back( + HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = *Data->ContentTypeOut}); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData); + + // Write callback that directs data to file or string + struct DownloadWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + }; + + DownloadWriteCallbackData DlWriteData; + DlWriteData.PayloadString = &PayloadString; + DlWriteData.PayloadFile = &PayloadFile; + DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + DlWriteData.TempFolderPath = &TempFolderPath; + DlWriteData.Log = m_Log; + DlWriteData.BoundaryParser = &BoundaryParser; + DlWriteData.IsMultiRange = &IsMultiRangeResponse; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->IsMultiRange) + { + Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes)); + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + // Handle resume logic + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const CurlResult& R) -> bool { + for (const auto& [K, V] : R.Headers) + { + if (K == "Content-Range") + { + return true; + } + if (K == "Accept-Ranges" && V == "bytes") + { + return true; + } + } + return false; + }; + + auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(R)) + { + return SupportsRanges(R); + } + return false; + }; + + if (ShouldResumeCheck(Res)) + { + // Find Content-Length + std::string ContentLengthValue; + for (const auto& [K, V] : Res.Headers) + { + if (K == "Content-Length") + { + ContentLengthValue = V; + break; + } + } + + if (!ContentLengthValue.empty()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + break; // No progress, abort + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session ResumeSess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + CURL* ResumeH = ResumeSess.Get(); + + curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()); + curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList); + curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); + + std::vector<std::pair<std::string, std::string>> ResumeHeaders; + + struct ResumeHeaderCbData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + }; + + ResumeHeaderCbData ResumeHdrData; + ResumeHdrData.Headers = &ResumeHeaders; + ResumeHdrData.PayloadFile = &PayloadFile; + ResumeHdrData.PayloadString = &PayloadString; + + auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<ResumeHeaderCbData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Range"sv) + { + if (Value.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) + { + const std::optional<uint64_t> Start = + ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() + : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) + { + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + } + } + return 0; + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(ResumeH, + CURLOPT_HEADERFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb)); + curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData); + curl_easy_setopt(ResumeH, + CURLOPT_WRITEFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData); + + Res = ResumeSess.Perform(); + Res.Headers = std::move(ResumeHeaders); + + curl_slist_free_all(ResumeHdrList); + } while (ShouldResumeCheck(Res)); + } + } + } + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + curl_slist_free_all(DlHeaders); + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Result), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); +} + +} // namespace zen |