diff options
Diffstat (limited to 'src/zenhttp')
| -rw-r--r-- | src/zenhttp/clients/httpclientcommon.h | 5 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.cpp | 112 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcpr.h | 19 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.cpp | 235 | ||||
| -rw-r--r-- | src/zenhttp/clients/httpclientcurl.h | 5 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 7 | ||||
| -rw-r--r-- | src/zenhttp/httpclient_test.cpp | 44 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 5 | ||||
| -rw-r--r-- | src/zenhttp/packageformat.cpp | 20 |
9 files changed, 377 insertions, 75 deletions
diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 5ed946541..e95e3a253 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -36,7 +36,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) = 0; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index f3082e0a2..a0f5cc38f 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -795,22 +795,97 @@ CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTyp } CprHttpClient::Response -CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +CprHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); - return CommonResponse( + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + cpr::Response Response = DoWithRetry( m_SessionId, - DoWithRetry(m_SessionId, - [&]() { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + [&]() { + PayloadString.clear(); + PayloadFile.reset(); - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); - return Sess.Post(); - }), - {}); + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + cpr::Response Response = Sess.Post({}, cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); } CprHttpClient::Response @@ -1012,13 +1087,18 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF { std::vector<std::pair<std::string, std::string>> ReceivedHeaders; auto HeaderCallback = [&](std::string header, intptr_t) { + if (RequestedContentLength != (uint64_t)-1 && RequestedContentLength > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + ZEN_DEBUG("Multirange request"); + } const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); - if (Header.first == "Content-Length"sv) + const std::string Key(Header.first); + if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); if (ContentLength.has_value()) { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) { PayloadFile = std::make_unique<detail::TempPayloadFile>(); std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); @@ -1036,7 +1116,7 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF } } } - else if (Header.first == "Content-Type") + else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) { IsMultiRangeResponse = BoundaryParser.Init(Header.second); if (!IsMultiRangeResponse) @@ -1044,7 +1124,7 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF ContentType = ParseContentType(Header.second); } } - else if (Header.first == "Content-Range") + else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) { if (!IsMultiRangeResponse) { @@ -1121,7 +1201,7 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF ReceivedHeaders.emplace_back(std::move(Header)); } - if (Header.first == "Content-Range"sv) + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Range") == 0) { if (Header.second.starts_with("bytes "sv)) { diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 752d91add..009e6fb7a 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -38,7 +38,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) override; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -104,15 +107,27 @@ private: CprSession->SetReadCallback({}); return Result; } - inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}, + std::optional<cpr::WriteCallback>&& Write = {}, + std::optional<cpr::HeaderCallback>&& Header = {}) { ZEN_TRACE_CPU("HttpClient::Impl::Post"); if (Read) { CprSession->SetReadCallback(std::move(Read.value())); } + if (Write) + { + CprSession->SetWriteCallback(std::move(Write.value())); + } + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } cpr::Response Result = CprSession->Post(); ZEN_TRACE("POST {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); CprSession->SetReadCallback({}); return Result; } diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index 3cb749018..341adc5f7 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -413,7 +413,7 @@ CurlHttpClient::ResponseWithPayload(std::string_view SessionId, for (const auto& [Key, Value] : Result.Headers) { - if (Key == "Content-Type") + if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) { const HttpContentType ContentType = ParseContentType(Value); ResponseBuffer.SetContentType(ContentType); @@ -522,7 +522,7 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp // Find Content-Length in headers for (const auto& [Key, Value] : Result.Headers) { - if (Key == "Content-Length") + if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) { std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); if (!ExpectedContentSize.has_value()) @@ -549,7 +549,7 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp // Check X-Jupiter-IoHash for (const auto& [Key, Value] : Result.Headers) { - if (Key == "X-Jupiter-IoHash") + if (StrCaseCompare(Key.c_str(), "X-Jupiter-IoHash") == 0) { IoHash ExpectedPayloadHash; if (IoHash::TryParse(Value, ExpectedPayloadHash)) @@ -571,7 +571,7 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp // Validate content-type specific payload for (const auto& [Key, Value] : Result.Headers) { - if (Key == "Content-Type") + if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) { if (Value == "application/x-ue-comp") { @@ -933,7 +933,16 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); - return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer}; + for (const auto& [Key, Value] : Result.Headers) + { + if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + { + ResponseBuffer.SetContentType(ParseContentType(Value)); + break; + } + } + + return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -1270,45 +1279,177 @@ CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTy } CurlHttpClient::Response -CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +CurlHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); - return CommonResponse( + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + CurlResult Result = DoWithRetry( m_SessionId, - DoWithRetry( - m_SessionId, - [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); + [&]() -> CurlResult { + PayloadString.clear(); + PayloadFile.reset(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); - curl_easy_setopt(H, CURLOPT_POST, 1L); - curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); - curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + struct PostHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + }; - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); + PostHeaderCallbackData PostHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + PostHdrData.Headers = &ResponseHeaders; + PostHdrData.PayloadFile = &PayloadFile; + PostHdrData.PayloadString = &PayloadString; + PostHdrData.TempFolderPath = &TempFolderPath; + PostHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + PostHdrData.Log = m_Log; - curl_slist_free_all(Headers); - return Result; - }), - {}); + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (StrCaseCompare(std::string(Key).c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (!Data->TempFolderPath->empty() && ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &PostHdrData); + + struct PostWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + }; + + PostWriteCallbackData PostWriteData; + PostWriteData.PayloadString = &PayloadString; + PostWriteData.PayloadFile = &PayloadFile; + PostWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + PostWriteData.TempFolderPath = &TempFolderPath; + PostWriteData.Log = m_Log; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &PostWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + curl_slist_free_all(Headers); + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, std::move(Result), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, {}); } CurlHttpClient::Response @@ -1616,19 +1757,21 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp size_t ColonPos = Line.find(':'); if (ColonPos != std::string_view::npos) { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); + std::string_view KeyView = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); - while (!Key.empty() && Key.back() == ' ') + while (!KeyView.empty() && KeyView.back() == ' ') { - Key.remove_suffix(1); + KeyView.remove_suffix(1); } while (!Value.empty() && Value.front() == ' ') { Value.remove_prefix(1); } - if (Key == "Content-Length"sv) + const std::string Key(KeyView); + + if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Value); if (ContentLength.has_value()) @@ -1652,7 +1795,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } } } - else if (Key == "Content-Type"sv) + else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) { *Data->IsMultiRange = Data->BoundaryParser->Init(Value); if (!*Data->IsMultiRange) @@ -1660,7 +1803,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp *Data->ContentTypeOut = ParseContentType(Value); } } - else if (Key == "Content-Range"sv) + else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) { if (!*Data->IsMultiRange) { @@ -1751,13 +1894,13 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto SupportsRanges = [](const CurlResult& R) -> bool { for (const auto& [K, V] : R.Headers) { - if (K == "Content-Range") + if (StrCaseCompare(K.c_str(), "Content-Range") == 0) { return true; } - if (K == "Accept-Ranges" && V == "bytes") + if (StrCaseCompare(K.c_str(), "Accept-Ranges") == 0) { - return true; + return V == "bytes"sv; } } return false; @@ -1781,7 +1924,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::string ContentLengthValue; for (const auto& [K, V] : Res.Headers) { - if (K == "Content-Length") + if (StrCaseCompare(K.c_str(), "Content-Length") == 0) { ContentLengthValue = V; break; @@ -1865,7 +2008,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp Value.remove_prefix(1); } - if (Key == "Content-Range"sv) + if (StrCaseCompare(std::string(Key).c_str(), "Content-Range") == 0) { if (Value.starts_with("bytes "sv)) { diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index 2a49ff308..871877863 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -36,7 +36,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) override; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 9baf4346e..deeeb6c85 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -440,9 +440,12 @@ HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType C } HttpClient::Response -HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, + CbObject Payload, + const HttpClient::KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { - return m_Inner->Post(Url, Payload, AdditionalHeader); + return m_Inner->Post(Url, Payload, AdditionalHeader, TempFolderPath); } HttpClient::Response diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 2d949c546..5f3ad2455 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -514,6 +514,50 @@ TEST_CASE("httpclient.download") } } +TEST_CASE("httpclient.post-streaming") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory PostDir; + + SUBCASE("POST CbObject with TempFolderPath stays in memory when response is small") + { + HttpClient Client = Fixture.MakeClient(); + + CbObjectWriter Writer; + Writer.AddBool("streaming", false); + Writer.AddString("mode", "memory"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path()); + CHECK(Resp.IsSuccess()); + IoBufferFileReference _; + CHECK(!Resp.ResponsePayload.GetFileReference(_)); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["streaming"].AsBool() == false); + CHECK_EQ(RoundTripped["mode"].AsString(), "memory"); + } + + SUBCASE("POST CbObject with TempFolderPath streams to file when response exceeds MaximumInMemoryDownloadSize") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + CbObjectWriter Writer; + Writer.AddBool("streaming", true); + Writer.AddString("mode", "file"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path()); + CHECK(Resp.IsSuccess()); + IoBufferFileReference _; + CHECK(Resp.ResponsePayload.GetFileReference(_)); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["streaming"].AsBool() == true); + CHECK_EQ(RoundTripped["mode"].AsString(), "file"); + } +} + TEST_CASE("httpclient.status-codes") { TestServerFixture Fixture; diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 2e21e3bd6..03c98af7e 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -315,7 +315,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}); - [[nodiscard]] Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}); [[nodiscard]] Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Post(std::string_view Url, const CompositeBuffer& Payload, diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index cbfe4d889..9c62c1f2d 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -575,13 +575,21 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint } else if (AttachmentSize > 0) { - // Make a copy of the buffer so the attachments don't reference the entire payload - IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); - ZEN_ASSERT(AttachmentBufferCopy); - ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); - AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); + IoBufferFileReference TestIfFileRef; + if (AttachmentBuffer.GetFileReference(TestIfFileRef)) + { + Attachments.emplace_back(CbAttachment(SharedBuffer{std::move(AttachmentBuffer)}, Entry.AttachmentHash)); + } + else + { + // Make a copy of the buffer so the attachments don't reference the entire payload + IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + ZEN_ASSERT(AttachmentBufferCopy); + ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); + AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); - Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash)); + Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash)); + } } else { |