aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/asynchttpclient_test.cpp315
-rw-r--r--src/zenhttp/clients/asynchttpclient.cpp1033
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp293
-rw-r--r--src/zenhttp/clients/httpclientcurl.h1
-rw-r--r--src/zenhttp/clients/httpclientcurlhelpers.h293
-rw-r--r--src/zenhttp/httpclient.cpp2
-rw-r--r--src/zenhttp/httpclient_test.cpp30
-rw-r--r--src/zenhttp/httpclientauth.cpp18
-rw-r--r--src/zenhttp/httpserver.cpp125
-rw-r--r--src/zenhttp/include/zenhttp/asynchttpclient.h123
-rw-r--r--src/zenhttp/include/zenhttp/httpclientauth.h3
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h14
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h29
-rw-r--r--src/zenhttp/include/zenhttp/httpstats.h6
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h8
-rw-r--r--src/zenhttp/include/zenhttp/localrefpolicy.h21
-rw-r--r--src/zenhttp/include/zenhttp/packageformat.h25
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h2
-rw-r--r--src/zenhttp/monitoring/httpstats.cpp3
-rw-r--r--src/zenhttp/packageformat.cpp548
-rw-r--r--src/zenhttp/servers/httpasio.cpp38
-rw-r--r--src/zenhttp/servers/httpparser.cpp414
-rw-r--r--src/zenhttp/servers/httpparser.h8
-rw-r--r--src/zenhttp/servers/httpplugin.cpp32
-rw-r--r--src/zenhttp/servers/httpsys.cpp45
-rw-r--r--src/zenhttp/servers/wsasio.cpp2
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp6
-rw-r--r--src/zenhttp/servers/wstest.cpp37
-rw-r--r--src/zenhttp/xmake.lua2
-rw-r--r--src/zenhttp/zenhttp.cpp3
30 files changed, 3044 insertions, 435 deletions
diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp
new file mode 100644
index 000000000..151863370
--- /dev/null
+++ b/src/zenhttp/asynchttpclient_test.cpp
@@ -0,0 +1,315 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/asynchttpclient.h>
+#include <zenhttp/httpserver.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/iobuffer.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+# include "servers/httpasio.h"
+
+# include <atomic>
+# include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+// Reusable test service for async client tests
+
+class AsyncHttpClientTestService : public HttpService
+{
+public:
+ AsyncHttpClientTestService()
+ {
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ IoBuffer Body = HttpReq.ReadPayload();
+ HttpContentType CT = HttpReq.RequestContentType();
+ HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body);
+ },
+ HttpVerb::kPost | HttpVerb::kPut);
+
+ m_Router.RegisterRoute(
+ "echo/method",
+ [](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ std::string_view Method = ToString(HttpReq.RequestVerb());
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method);
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "nocontent",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "json",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}");
+ },
+ HttpVerb::kGet);
+ }
+
+ virtual const char* BaseUri() const override { return "/api/async-test/"; }
+ virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); }
+
+private:
+ HttpRequestRouter m_Router;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct AsyncTestServerFixture
+{
+ AsyncHttpClientTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ Ref<HttpServer> Server;
+ std::thread ServerThread;
+ int Port = -1;
+
+ AsyncTestServerFixture()
+ {
+ Server = CreateHttpAsioServer(AsioConfig{});
+ Port = Server->Initialize(0, TmpDir.Path());
+ ZEN_ASSERT(Port != -1);
+ Server->RegisterService(TestService);
+ ServerThread = std::thread([this]() { Server->Run(false); });
+ }
+
+ ~AsyncTestServerFixture()
+ {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ }
+
+ AsyncHttpClient MakeClient(HttpClientSettings Settings = {}) { return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), Settings); }
+
+ AsyncHttpClient MakeClient(asio::io_context& IoContext, HttpClientSettings Settings = {})
+ {
+ return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), IoContext, Settings);
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+// Tests
+
+TEST_SUITE_BEGIN("http.asynchttpclient");
+
+TEST_CASE("asynchttpclient.future.verbs")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("GET returns 200 with expected body")
+ {
+ auto Future = Client.Get("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "GET");
+ }
+
+ SUBCASE("POST dispatches correctly")
+ {
+ auto Future = Client.Post("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "POST");
+ }
+
+ SUBCASE("PUT dispatches correctly")
+ {
+ auto Future = Client.Put("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "PUT");
+ }
+
+ SUBCASE("DELETE dispatches correctly")
+ {
+ auto Future = Client.Delete("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "DELETE");
+ }
+
+ SUBCASE("HEAD returns 200 with empty body")
+ {
+ auto Future = Client.Head("/api/async-test/echo/method");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), ""sv);
+ }
+}
+
+TEST_CASE("asynchttpclient.future.get")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ SUBCASE("simple GET with text response")
+ {
+ auto Future = Client.Get("/api/async-test/hello");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK);
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("GET returning JSON")
+ {
+ auto Future = Client.Get("/api/async-test/json");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "{\"ok\":true}");
+ }
+
+ SUBCASE("GET 204 NoContent")
+ {
+ auto Future = Client.Get("/api/async-test/nocontent");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent);
+ }
+}
+
+TEST_CASE("asynchttpclient.future.post.with.payload")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::string_view PayloadStr = "async payload data";
+ IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size());
+ Payload.SetContentType(ZenContentType::kText);
+
+ auto Future = Client.Post("/api/async-test/echo", Payload);
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "async payload data");
+}
+
+TEST_CASE("asynchttpclient.future.put.with.payload")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::string_view PutStr = "put payload";
+ IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size());
+ Payload.SetContentType(ZenContentType::kText);
+
+ auto Future = Client.Put("/api/async-test/echo", Payload);
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "put payload");
+}
+
+TEST_CASE("asynchttpclient.callback")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ std::promise<HttpClient::Response> Promise;
+ auto Future = Promise.get_future();
+
+ Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); });
+
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+}
+
+TEST_CASE("asynchttpclient.concurrent.requests")
+{
+ AsyncTestServerFixture Fixture;
+ AsyncHttpClient Client = Fixture.MakeClient();
+
+ // Fire multiple requests concurrently
+ auto Future1 = Client.Get("/api/async-test/hello");
+ auto Future2 = Client.Get("/api/async-test/json");
+ auto Future3 = Client.Post("/api/async-test/echo/method");
+ auto Future4 = Client.Delete("/api/async-test/echo/method");
+
+ auto Resp1 = Future1.get();
+ auto Resp2 = Future2.get();
+ auto Resp3 = Future3.get();
+ auto Resp4 = Future4.get();
+
+ CHECK(Resp1.IsSuccess());
+ CHECK_EQ(Resp1.AsText(), "hello world");
+
+ CHECK(Resp2.IsSuccess());
+ CHECK_EQ(Resp2.AsText(), "{\"ok\":true}");
+
+ CHECK(Resp3.IsSuccess());
+ CHECK_EQ(Resp3.AsText(), "POST");
+
+ CHECK(Resp4.IsSuccess());
+ CHECK_EQ(Resp4.AsText(), "DELETE");
+}
+
+TEST_CASE("asynchttpclient.external.io_context")
+{
+ AsyncTestServerFixture Fixture;
+
+ asio::io_context IoContext;
+ auto WorkGuard = asio::make_work_guard(IoContext);
+ std::thread IoThread([&IoContext]() { IoContext.run(); });
+
+ {
+ AsyncHttpClient Client = Fixture.MakeClient(IoContext);
+
+ auto Future = Client.Get("/api/async-test/hello");
+ auto Resp = Future.get();
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ WorkGuard.reset();
+ IoThread.join();
+}
+
+TEST_CASE("asynchttpclient.connection.error")
+{
+ // Connect to a port where nothing is listening
+ AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)});
+
+ auto Future = Client.Get("/should-fail");
+ auto Resp = Future.get();
+
+ CHECK_FALSE(Resp.IsSuccess());
+ CHECK(Resp.Error.has_value());
+ CHECK(Resp.Error->IsConnectionError());
+}
+
+TEST_SUITE_END();
+
+void
+asynchttpclient_test_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp
new file mode 100644
index 000000000..4b189af36
--- /dev/null
+++ b/src/zenhttp/clients/asynchttpclient.cpp
@@ -0,0 +1,1033 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/asynchttpclient.h>
+
+#include "httpclientcurlhelpers.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#include <asio/steady_timer.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// TransferContext: per-transfer state associated with each CURL easy handle
+
+struct TransferContext
+{
+ AsyncHttpCallback Callback;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ CurlWriteCallbackData WriteData;
+ CurlHeaderCallbackData HeaderData;
+ curl_slist* HeaderList = nullptr;
+
+ // For PUT/POST with payload: keep the data alive until transfer completes
+ IoBuffer PayloadBuffer;
+ CurlReadCallbackData ReadData;
+
+ TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback))
+ {
+ WriteData.Body = &Body;
+ HeaderData.Headers = &ResponseHeaders;
+ }
+
+ ~TransferContext()
+ {
+ if (HeaderList)
+ {
+ curl_slist_free_all(HeaderList);
+ }
+ }
+
+ TransferContext(const TransferContext&) = delete;
+ TransferContext& operator=(const TransferContext&) = delete;
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// AsyncHttpClient::Impl
+
+struct AsyncHttpClient::Impl
+{
+ Impl(std::string_view BaseUri, const HttpClientSettings& Settings)
+ : m_BaseUri(BaseUri)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ , m_Strand(asio::make_strand(m_IoContext))
+ , m_Timer(m_Strand)
+ {
+ Init();
+ m_WorkGuard.emplace(m_IoContext.get_executor());
+ m_IoThread = std::thread([this]() {
+ SetCurrentThreadName("async_http");
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (const std::exception& Ex)
+ {
+ ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what());
+ }
+ });
+ }
+
+ Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings)
+ : m_BaseUri(BaseUri)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ , m_Strand(asio::make_strand(m_IoContext))
+ , m_Timer(m_Strand)
+ {
+ Init();
+ }
+
+ ~Impl()
+ {
+ // Clean up curl state on the strand where all curl_multi operations
+ // are serialized. Use a promise to block until the cleanup handler
+ // has actually executed - essential for the external io_context case
+ // where we don't own the run loop.
+ std::promise<void> Done;
+ std::future<void> DoneFuture = Done.get_future();
+
+ asio::post(m_Strand, [this, &Done]() {
+ m_ShuttingDown = true;
+ m_Timer.cancel();
+
+ // Release all tracked sockets (don't close - curl owns the fds).
+ for (auto& [Fd, Info] : m_Sockets)
+ {
+ if (Info->Socket.is_open())
+ {
+ Info->Socket.cancel();
+ Info->Socket.release();
+ }
+ }
+ m_Sockets.clear();
+
+ for (auto& [Handle, Ctx] : m_Transfers)
+ {
+ curl_multi_remove_handle(m_Multi, Handle);
+ curl_easy_cleanup(Handle);
+ }
+ m_Transfers.clear();
+
+ for (CURL* Handle : m_HandlePool)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_HandlePool.clear();
+
+ Done.set_value();
+ });
+
+ // For owned io_context: release work guard so run() can return after
+ // processing the cleanup handler above.
+ m_WorkGuard.reset();
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ else
+ {
+ // External io_context: wait for the cleanup handler to complete.
+ DoneFuture.wait();
+ }
+
+ if (m_Multi)
+ {
+ curl_multi_cleanup(m_Multi);
+ }
+ }
+
+ LoggerRef Log() { return m_Log; }
+
+ void Init()
+ {
+ m_Multi = curl_multi_init();
+ if (!m_Multi)
+ {
+ throw std::runtime_error("curl_multi_init failed");
+ }
+
+ SetupMultiCallbacks();
+
+ if (m_Settings.SessionId == Oid::Zero)
+ {
+ m_SessionId = std::string(GetSessionIdString());
+ }
+ else
+ {
+ m_SessionId = m_Settings.SessionId.ToString();
+ }
+ }
+
+ // -- Handle pool -----------------------------------------------------
+
+ CURL* AllocHandle()
+ {
+ if (!m_HandlePool.empty())
+ {
+ CURL* Handle = m_HandlePool.back();
+ m_HandlePool.pop_back();
+ curl_easy_reset(Handle);
+ return Handle;
+ }
+ CURL* Handle = curl_easy_init();
+ if (!Handle)
+ {
+ throw std::runtime_error("curl_easy_init failed");
+ }
+ return Handle;
+ }
+
+ void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); }
+
+ // -- Configure a handle with common settings -------------------------
+ // Called only from DoAsync* lambdas running on the strand.
+
+ void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters)
+ {
+ // Build URL
+ ExtendableStringBuilder<256> Url;
+ BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Unix domain socket
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath);
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str());
+ }
+
+ // Timeouts
+ if (m_Settings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_Settings.ConnectTimeout.count()));
+ }
+ if (m_Settings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (m_Settings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // SSL
+ if (m_Settings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!m_Settings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str());
+ }
+
+ // Verbose/debug
+ if (m_Settings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ }
+
+ // Thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (m_Settings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+ }
+
+ // -- Access token ----------------------------------------------------
+
+ std::optional<std::string> GetAccessToken()
+ {
+ if (!m_Settings.AccessTokenProvider.has_value())
+ {
+ return {};
+ }
+ {
+ RwLock::SharedLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ }
+ RwLock::ExclusiveLockScope _(m_AccessTokenLock);
+ if (!m_CachedAccessToken.NeedsRefresh())
+ {
+ return m_CachedAccessToken.GetValue();
+ }
+ HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()();
+ if (!NewToken.IsValid())
+ {
+ ZEN_WARN("AsyncHttpClient: failed to refresh access token, retrying once");
+ NewToken = m_Settings.AccessTokenProvider.value()();
+ }
+ if (NewToken.IsValid())
+ {
+ m_CachedAccessToken = NewToken;
+ return m_CachedAccessToken.GetValue();
+ }
+ ZEN_WARN("AsyncHttpClient: access token provider returned invalid token");
+ return {};
+ }
+
+ // -- Submit a transfer -----------------------------------------------
+
+ void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer");
+ // Setup write/header callbacks
+ curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData);
+ curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData);
+
+ m_Transfers[Handle] = std::move(Ctx);
+
+ CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle);
+ if (Mc != CURLM_OK)
+ {
+ auto Stolen = std::move(m_Transfers[Handle]);
+ m_Transfers.erase(Handle);
+ ReleaseHandle(Handle);
+
+ HttpClient::Response ErrorResponse;
+ ErrorResponse.Error =
+ HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError,
+ .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))};
+ asio::post(m_IoContext,
+ [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); });
+ return;
+ }
+ }
+
+ // -- Socket-action integration ---------------------------------------
+ //
+ // curl_multi drives I/O via two callbacks:
+ // - SocketCallback: curl tells us which sockets to watch for read/write
+ // - TimerCallback: curl tells us when to fire a timeout
+ //
+ // On each socket event or timeout we call curl_multi_socket_action(),
+ // then drain completed transfers via curl_multi_info_read().
+
+ // Per-socket state: wraps the native fd in an ASIO socket for async_wait.
+ struct SocketInfo
+ {
+ asio::ip::tcp::socket Socket;
+ int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT
+
+ explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {}
+ };
+
+ // Static thunks registered with curl_multi ----------------------------
+
+ static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr)
+ {
+ ZEN_UNUSED(Easy);
+ auto* Self = static_cast<Impl*>(UserPtr);
+ Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr));
+ return 0;
+ }
+
+ static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr)
+ {
+ ZEN_UNUSED(Multi);
+ auto* Self = static_cast<Impl*>(UserPtr);
+ Self->OnCurlTimer(TimeoutMs);
+ return 0;
+ }
+
+ void SetupMultiCallbacks()
+ {
+ curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback);
+ curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this);
+ curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback);
+ curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this);
+ }
+
+ // Called by curl when socket watch state changes ---------------------
+
+ void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info)
+ {
+ if (Action == CURL_POLL_REMOVE)
+ {
+ if (Info)
+ {
+ // Cancel pending async_wait ops before releasing the fd.
+ // curl owns the fd, so we must release() rather than close().
+ Info->Socket.cancel();
+ if (Info->Socket.is_open())
+ {
+ Info->Socket.release();
+ }
+ m_Sockets.erase(Fd);
+ }
+ return;
+ }
+
+ if (!Info)
+ {
+ // New socket - wrap the native fd in an ASIO socket.
+ auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext));
+ Info = It->second.get();
+
+ asio::error_code Ec;
+ // Determine protocol from the fd (v4 vs v6). Default to v4.
+ Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec);
+ if (Ec)
+ {
+ // Try v6 as fallback
+ Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec);
+ }
+ if (Ec)
+ {
+ ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message());
+ m_Sockets.erase(Fd);
+ return;
+ }
+
+ curl_multi_assign(m_Multi, Fd, Info);
+ }
+
+ Info->WatchFlags = Action;
+ SetSocketWatch(Fd, Info);
+ }
+
+ void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info)
+ {
+ // Cancel any pending wait before issuing a new one.
+ Info->Socket.cancel();
+
+ if (Info->WatchFlags & CURL_POLL_IN)
+ {
+ Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ OnSocketReady(Fd, CURL_CSELECT_IN);
+ }));
+ }
+
+ if (Info->WatchFlags & CURL_POLL_OUT)
+ {
+ Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ OnSocketReady(Fd, CURL_CSELECT_OUT);
+ }));
+ }
+ }
+
+ void OnSocketReady(curl_socket_t Fd, int CurlAction)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady");
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning);
+ CheckCompleted();
+
+ // Re-arm the watch if the socket is still tracked.
+ auto It = m_Sockets.find(Fd);
+ if (It != m_Sockets.end())
+ {
+ SetSocketWatch(Fd, It->second.get());
+ }
+ }
+
+ // Called by curl when it wants a timeout ------------------------------
+
+ void OnCurlTimer(long TimeoutMs)
+ {
+ m_Timer.cancel();
+
+ if (TimeoutMs < 0)
+ {
+ // curl says "no timeout needed"
+ return;
+ }
+
+ if (TimeoutMs == 0)
+ {
+ // curl wants immediate action - run it directly on the strand.
+ asio::post(m_Strand, [this]() {
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning);
+ CheckCompleted();
+ });
+ return;
+ }
+
+ m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs));
+ m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) {
+ if (Ec || m_ShuttingDown)
+ {
+ return;
+ }
+ ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout");
+ int StillRunning = 0;
+ curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning);
+ CheckCompleted();
+ }));
+ }
+
+ // Drain completed transfers from curl_multi --------------------------
+
+ void CheckCompleted()
+ {
+ int MsgsLeft = 0;
+ CURLMsg* Msg = nullptr;
+ while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr)
+ {
+ if (Msg->msg != CURLMSG_DONE)
+ {
+ continue;
+ }
+
+ CURL* Handle = Msg->easy_handle;
+ CURLcode Result = Msg->data.result;
+
+ curl_multi_remove_handle(m_Multi, Handle);
+
+ auto It = m_Transfers.find(Handle);
+ if (It == m_Transfers.end())
+ {
+ ReleaseHandle(Handle);
+ continue;
+ }
+
+ std::unique_ptr<TransferContext> Ctx = std::move(It->second);
+ m_Transfers.erase(It);
+
+ CompleteTransfer(Handle, Result, std::move(Ctx));
+ }
+ }
+
+ void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx)
+ {
+ ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer");
+ // Extract result info
+ long StatusCode = 0;
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+
+ ReleaseHandle(Handle);
+
+ // Build response
+ HttpClient::Response Response;
+ Response.StatusCode = HttpResponseCode(StatusCode);
+ Response.UploadedBytes = static_cast<int64_t>(UpBytes);
+ Response.DownloadedBytes = static_cast<int64_t>(DownBytes);
+ Response.ElapsedSeconds = Elapsed;
+ Response.Header = BuildHeaderMap(Ctx->ResponseHeaders);
+
+ if (CurlResult != CURLE_OK)
+ {
+ const char* ErrorMsg = curl_easy_strerror(CurlResult);
+
+ if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg);
+ }
+
+ if (!Ctx->Body.empty())
+ {
+ Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size());
+ }
+
+ Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)};
+ }
+ else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty())
+ {
+ // No payload
+ }
+ else
+ {
+ IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size());
+ ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders);
+
+ const HttpResponseCode Code = HttpResponseCode(StatusCode);
+ if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound)
+ {
+ ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri);
+ }
+
+ Response.ResponsePayload = std::move(PayloadBuffer);
+ }
+
+ // Dispatch the user callback off the strand so a slow callback
+ // cannot starve the curl_multi poll loop.
+ asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable {
+ try
+ {
+ Cb(std::move(Response));
+ }
+ catch (const std::exception& Ex)
+ {
+ auto Log = [&]() -> LoggerRef { return LogRef; };
+ ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what());
+ }
+ });
+ }
+
+ // -- Async verb implementations --------------------------------------
+
+ void DoAsyncGet(std::string Url,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Get");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Head");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Delete");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE");
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPost(std::string Url,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Post");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_POST, 1L);
+ curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPostWithPayload(std::string Url,
+ IoBuffer Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Payload = std::move(Payload),
+ ContentType,
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, {});
+ curl_easy_setopt(Handle, CURLOPT_POST, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->PayloadBuffer = std::move(Payload);
+ Ctx->HeaderList =
+ BuildHeaderList(AdditionalHeader,
+ m_SessionId,
+ GetAccessToken(),
+ {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))});
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ // Set up read callback for payload data
+ Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData());
+ Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize();
+ Ctx->ReadData.Offset = 0;
+
+ curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize()));
+ curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPutWithPayload(std::string Url,
+ IoBuffer Payload,
+ AsyncHttpCallback Callback,
+ HttpClient::KeyValueMap AdditionalHeader,
+ HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand,
+ [this,
+ Url = std::move(Url),
+ Payload = std::move(Payload),
+ Callback = std::move(Callback),
+ AdditionalHeader = std::move(AdditionalHeader),
+ Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Put");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+ Ctx->PayloadBuffer = std::move(Payload);
+ Ctx->HeaderList = BuildHeaderList(
+ AdditionalHeader,
+ m_SessionId,
+ GetAccessToken(),
+ {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))});
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData());
+ Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize();
+ Ctx->ReadData.Offset = 0;
+
+ curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize()));
+ curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters)
+ {
+ asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable {
+ ZEN_TRACE_CPU("AsyncHttpClient::Put");
+ if (m_ShuttingDown)
+ {
+ return;
+ }
+ CURL* Handle = AllocHandle();
+ ConfigureHandle(Handle, Url, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ auto Ctx = std::make_unique<TransferContext>(std::move(Callback));
+
+ HttpClient::KeyValueMap ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList);
+
+ SubmitTransfer(Handle, std::move(Ctx));
+ });
+ }
+
+ // -- Members ---------------------------------------------------------
+
+ std::string m_BaseUri;
+ HttpClientSettings m_Settings;
+ LoggerRef m_Log;
+ std::string m_SessionId;
+ std::string m_UnixSocketPathUtf8;
+
+ // io_context and strand - all curl_multi operations are serialized on the
+ // strand, making this safe even when the io_context has multiple threads.
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ asio::strand<asio::io_context::executor_type> m_Strand;
+ std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // curl_multi and socket-action state
+ CURLM* m_Multi = nullptr;
+ std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers;
+ std::vector<CURL*> m_HandlePool;
+ std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets;
+ asio::steady_timer m_Timer;
+ bool m_ShuttingDown = false;
+
+ // Access token cache
+ RwLock m_AccessTokenLock;
+ HttpClientAccessToken m_CachedAccessToken;
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// AsyncHttpClient public API
+
+AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(BaseUri, Settings))
+{
+}
+
+AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings))
+{
+}
+
+AsyncHttpClient::~AsyncHttpClient() = default;
+
+// -- Callback-based API --------------------------------------------------
+
+void
+AsyncHttpClient::AsyncGet(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPost(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader)
+{
+ m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader);
+}
+
+void
+AsyncHttpClient::AsyncPut(std::string_view Url,
+ const IoBuffer& Payload,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters);
+}
+
+void
+AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters)
+{
+ m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), Parameters);
+}
+
+// -- Future-based API ----------------------------------------------------
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncGet(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncHead(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncDelete(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ Payload,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPost(
+ Url,
+ Payload,
+ ContentType,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPut(
+ Url,
+ Payload,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ AdditionalHeader,
+ Parameters);
+ return Future;
+}
+
+std::future<HttpClient::Response>
+AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ auto Promise = std::make_shared<std::promise<Response>>();
+ auto Future = Promise->get_future();
+ AsyncPut(
+ Url,
+ [Promise](Response R) { Promise->set_value(std::move(R)); },
+ Parameters);
+ return Future;
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
index d150b44c6..56b9c39c5 100644
--- a/src/zenhttp/clients/httpclientcurl.cpp
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpclientcurl.h"
+#include "httpclientcurlhelpers.h"
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
@@ -29,153 +30,7 @@ static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
//////////////////////////////////////////////////////////////////////////
-static HttpClientErrorCode
-MapCurlError(CURLcode Code)
-{
- switch (Code)
- {
- case CURLE_OK:
- return HttpClientErrorCode::kOK;
- case CURLE_COULDNT_CONNECT:
- return HttpClientErrorCode::kConnectionFailure;
- case CURLE_COULDNT_RESOLVE_HOST:
- return HttpClientErrorCode::kHostResolutionFailure;
- case CURLE_COULDNT_RESOLVE_PROXY:
- return HttpClientErrorCode::kProxyResolutionFailure;
- case CURLE_RECV_ERROR:
- return HttpClientErrorCode::kNetworkReceiveError;
- case CURLE_SEND_ERROR:
- return HttpClientErrorCode::kNetworkSendFailure;
- case CURLE_OPERATION_TIMEDOUT:
- return HttpClientErrorCode::kOperationTimedOut;
- case CURLE_SSL_CONNECT_ERROR:
- return HttpClientErrorCode::kSSLConnectError;
- case CURLE_SSL_CERTPROBLEM:
- return HttpClientErrorCode::kSSLCertificateError;
- case CURLE_PEER_FAILED_VERIFICATION:
- return HttpClientErrorCode::kSSLCACertError;
- case CURLE_SSL_CIPHER:
- case CURLE_SSL_ENGINE_NOTFOUND:
- case CURLE_SSL_ENGINE_SETFAILED:
- return HttpClientErrorCode::kGenericSSLError;
- case CURLE_ABORTED_BY_CALLBACK:
- return HttpClientErrorCode::kRequestCancelled;
- default:
- return HttpClientErrorCode::kOtherError;
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-//
-// Curl callback helpers
-
-struct WriteCallbackData
-{
- std::string* Body = nullptr;
- std::function<bool()>* CheckIfAbortFunction = nullptr;
-};
-
-static size_t
-CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<WriteCallbackData*>(UserData);
- size_t TotalBytes = Size * Nmemb;
-
- if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
- {
- return 0; // Signal abort to curl
- }
-
- Data->Body->append(Ptr, TotalBytes);
- return TotalBytes;
-}
-
-struct HeaderCallbackData
-{
- std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
-};
-
-// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
-// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
-static std::optional<std::pair<std::string_view, std::string_view>>
-ParseHeaderLine(std::string_view Line)
-{
- while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
- {
- Line.remove_suffix(1);
- }
-
- if (Line.empty())
- {
- return std::nullopt;
- }
-
- size_t ColonPos = Line.find(':');
- if (ColonPos == std::string_view::npos)
- {
- return std::nullopt;
- }
-
- std::string_view Key = Line.substr(0, ColonPos);
- std::string_view Value = Line.substr(ColonPos + 1);
-
- while (!Key.empty() && Key.back() == ' ')
- {
- Key.remove_suffix(1);
- }
- while (!Value.empty() && Value.front() == ' ')
- {
- Value.remove_prefix(1);
- }
-
- return std::pair{Key, Value};
-}
-
-static size_t
-CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<HeaderCallbackData*>(UserData);
- size_t TotalBytes = Size * Nmemb;
-
- if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
- {
- auto& [Key, Value] = *Header;
- Data->Headers->emplace_back(std::string(Key), std::string(Value));
- }
-
- return TotalBytes;
-}
-
-struct ReadCallbackData
-{
- const uint8_t* DataPtr = nullptr;
- size_t DataSize = 0;
- size_t Offset = 0;
- std::function<bool()>* CheckIfAbortFunction = nullptr;
-};
-
-static size_t
-CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
-{
- auto* Data = static_cast<ReadCallbackData*>(UserData);
- size_t MaxRead = Size * Nmemb;
-
- if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
- {
- return CURL_READFUNC_ABORT;
- }
-
- size_t Remaining = Data->DataSize - Data->Offset;
- size_t ToRead = std::min(MaxRead, Remaining);
-
- if (ToRead > 0)
- {
- memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
- Data->Offset += ToRead;
- }
-
- return ToRead;
-}
+// Curl callback helpers and shared utilities are in httpclientcurlhelpers.h
struct StreamReadCallbackData
{
@@ -281,120 +136,6 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi
//////////////////////////////////////////////////////////////////////////
-static std::pair<std::string, std::string>
-HeaderContentType(ZenContentType ContentType)
-{
- return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
-}
-
-static curl_slist*
-BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
- std::string_view SessionId,
- const std::optional<std::string>& AccessToken,
- const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
-{
- curl_slist* Headers = nullptr;
-
- for (const auto& [Key, Value] : *AdditionalHeader)
- {
- ExtendableStringBuilder<64> HeaderLine;
- HeaderLine << Key << ": " << Value;
- Headers = curl_slist_append(Headers, HeaderLine.c_str());
- }
-
- if (!SessionId.empty())
- {
- ExtendableStringBuilder<64> SessionHeader;
- SessionHeader << "UE-Session: " << SessionId;
- Headers = curl_slist_append(Headers, SessionHeader.c_str());
- }
-
- if (AccessToken.has_value())
- {
- ExtendableStringBuilder<128> AuthHeader;
- AuthHeader << "Authorization: " << AccessToken.value();
- Headers = curl_slist_append(Headers, AuthHeader.c_str());
- }
-
- for (const auto& [Key, Value] : ExtraHeaders)
- {
- ExtendableStringBuilder<128> HeaderLine;
- HeaderLine << Key << ": " << Value;
- Headers = curl_slist_append(Headers, HeaderLine.c_str());
- }
-
- return Headers;
-}
-
-static HttpClient::KeyValueMap
-BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
-{
- HttpClient::KeyValueMap HeaderMap;
- for (const auto& [Key, Value] : Headers)
- {
- HeaderMap->insert_or_assign(Key, Value);
- }
- return HeaderMap;
-}
-
-// Scans response headers for Content-Type and applies it to the buffer.
-static void
-ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
-{
- for (const auto& [Key, Value] : Headers)
- {
- if (StrCaseCompare(Key, "Content-Type") == 0)
- {
- Buffer.SetContentType(ParseContentType(Value));
- break;
- }
- }
-}
-
-static void
-AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
-{
- static constexpr char HexDigits[] = "0123456789ABCDEF";
- static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
-
- for (char C : Input)
- {
- if (Unreserved.Contains(C))
- {
- Out.Append(C);
- }
- else
- {
- uint8_t Byte = static_cast<uint8_t>(C);
- char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
- Out.Append(std::string_view(Encoded, 3));
- }
- }
-}
-
-static void
-BuildUrlWithParameters(StringBuilderBase& Url,
- std::string_view BaseUrl,
- std::string_view ResourcePath,
- const HttpClient::KeyValueMap& Parameters)
-{
- Url.Append(BaseUrl);
- Url.Append(ResourcePath);
-
- if (!Parameters->empty())
- {
- char Separator = '?';
- for (const auto& [Key, Value] : *Parameters)
- {
- Url.Append(Separator);
- AppendUrlEncoded(Url, Key);
- Url.Append('=');
- AppendUrlEncoded(Url, Value);
- Separator = '&';
- }
- }
-}
-
//////////////////////////////////////////////////////////////////////////
CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
@@ -440,9 +181,9 @@ CurlHttpClient::CurlResult
CurlHttpClient::Session::PerformWithResponseCallbacks()
{
std::string Body;
- WriteCallbackData WriteData{.Body = &Body,
+ CurlWriteCallbackData WriteData{.Body = &Body,
.CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr};
- HeaderCallbackData HdrData{};
+ CurlHeaderCallbackData HdrData{};
std::vector<std::pair<std::string, std::string>> ResponseHeaders;
HdrData.Headers = &ResponseHeaders;
@@ -487,6 +228,13 @@ CurlHttpClient::Session::Perform()
curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+ char* EffectiveUrl = nullptr;
+ curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl);
+ if (EffectiveUrl)
+ {
+ Result.Url = EffectiveUrl;
+ }
+
return Result;
}
@@ -553,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId,
if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
{
- ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'",
SessionId,
+ Result.Url,
static_cast<int>(Result.ErrorCode),
Result.ErrorMessage);
}
@@ -702,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result)
case CURLE_RECV_ERROR:
case CURLE_SEND_ERROR:
case CURLE_OPERATION_TIMEDOUT:
+ case CURLE_PARTIAL_FILE:
return true;
default:
return false;
@@ -748,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult
{
if (Result.ErrorCode != CURLE_OK)
{
- ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}",
+ ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}",
SessionId,
static_cast<int>(MapCurlError(Result.ErrorCode)),
Result.ErrorMessage,
+ static_cast<int>(Result.ErrorCode),
Attempt,
m_ConnectionSettings.RetryCount + 1);
}
@@ -998,9 +749,9 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu
curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
- ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
- .DataSize = Payload.GetSize(),
- .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
@@ -1367,9 +1118,9 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV
return Sess.PerformWithResponseCallbacks();
}
- ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
- .DataSize = Payload.GetSize(),
- .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
index bdeb46633..ea9193e65 100644
--- a/src/zenhttp/clients/httpclientcurl.h
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -73,6 +73,7 @@ private:
int64_t DownloadedBytes = 0;
CURLcode ErrorCode = CURLE_OK;
std::string ErrorMessage;
+ std::string Url;
};
struct Session
diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h
new file mode 100644
index 000000000..0605a30f6
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurlhelpers.h
@@ -0,0 +1,293 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Shared helpers for curl-based HTTP client implementations (sync and async).
+// This is an internal header, not part of the public API.
+
+#include <zencore/string.h>
+
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Error mapping
+
+inline HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback data structures and callbacks
+
+struct CurlWriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+inline size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct CurlHeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value.
+// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines).
+inline std::optional<std::pair<std::string_view, std::string_view>>
+ParseHeaderLine(std::string_view Line)
+{
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return std::nullopt;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ return std::nullopt;
+ }
+
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ return std::pair{Key, Value};
+}
+
+inline size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)))
+ {
+ auto& [Key, Value] = *Header;
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct CurlReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+inline size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<CurlReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// URL and header construction
+
+inline void
+AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input)
+{
+ static constexpr char HexDigits[] = "0123456789ABCDEF";
+ static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~");
+
+ for (char C : Input)
+ {
+ if (Unreserved.Contains(C))
+ {
+ Out.Append(C);
+ }
+ else
+ {
+ uint8_t Byte = static_cast<uint8_t>(C);
+ char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]};
+ Out.Append(std::string_view(Encoded, 3));
+ }
+ }
+}
+
+inline void
+BuildUrlWithParameters(StringBuilderBase& Url,
+ std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClient::KeyValueMap& Parameters)
+{
+ Url.Append(BaseUrl);
+ Url.Append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ Url.Append(Separator);
+ AppendUrlEncoded(Url, Key);
+ Url.Append('=');
+ AppendUrlEncoded(Url, Value);
+ Separator = '&';
+ }
+ }
+}
+
+inline std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+inline curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<std::string>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ ExtendableStringBuilder<64> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ ExtendableStringBuilder<64> SessionHeader;
+ SessionHeader << "UE-Session: " << SessionId;
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken.has_value())
+ {
+ ExtendableStringBuilder<128> AuthHeader;
+ AuthHeader << "Authorization: " << AccessToken.value();
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ ExtendableStringBuilder<128> HeaderLine;
+ HeaderLine << Key << ": " << Value;
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+inline HttpClient::KeyValueMap
+BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+ return HeaderMap;
+}
+
+// Scans response headers for Content-Type and applies it to the buffer.
+inline void
+ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers)
+{
+ for (const auto& [Key, Value] : Headers)
+ {
+ if (StrCaseCompare(Key, "Content-Type") == 0)
+ {
+ Buffer.SetContentType(ParseContentType(Value));
+ break;
+ }
+ }
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index ace7a3c7f..3da8a9220 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -520,7 +520,7 @@ MeasureLatency(HttpClient& Client, std::string_view Url)
ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url));
// Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable.
- // Bail out immediately — retrying will just burn the connect timeout each time.
+ // Bail out immediately - retrying will just burn the connect timeout each time.
if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError())
{
break;
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index af653cbb2..67cbaea9b 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -194,7 +194,7 @@ public:
"slow",
[](HttpRouterRequest& Req) {
Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
- Sleep(2000);
+ Sleep(100);
Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response");
});
},
@@ -750,7 +750,9 @@ TEST_CASE("httpclient.error-handling")
{
SUBCASE("Connection refused")
{
- HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClientSettings Settings;
+ Settings.ConnectTimeout = std::chrono::milliseconds(200);
+ HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {});
HttpClient::Response Resp = Client.Get("/api/test/hello");
CHECK(!Resp.IsSuccess());
CHECK(Resp.Error.has_value());
@@ -760,7 +762,7 @@ TEST_CASE("httpclient.error-handling")
{
TestServerFixture Fixture;
HttpClientSettings Settings;
- Settings.Timeout = std::chrono::milliseconds(500);
+ Settings.Timeout = std::chrono::milliseconds(50);
HttpClient Client = Fixture.MakeClient(Settings);
HttpClient::Response Resp = Client.Get("/api/test/slow");
@@ -970,7 +972,9 @@ TEST_CASE("httpclient.measurelatency")
SUBCASE("Failed measurement against unreachable port")
{
- HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {});
+ HttpClientSettings Settings;
+ Settings.ConnectTimeout = std::chrono::milliseconds(200);
+ HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {});
LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello");
CHECK(!Result.Success);
CHECK(!Result.FailureReason.empty());
@@ -1144,7 +1148,7 @@ struct FaultTcpServer
~FaultTcpServer()
{
// io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this
- // thread — ASIO I/O objects are not safe for concurrent access and the io_context
+ // thread - ASIO I/O objects are not safe for concurrent access and the io_context
// thread may be touching the acceptor in StartAccept().
m_IoContext.stop();
if (m_Thread.joinable())
@@ -1498,7 +1502,7 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
std::atomic<bool> StallActive{true};
FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) {
DrainHttpRequest(Socket);
- // Stop reading body — TCP window will fill and client send will stall
+ // Stop reading body - TCP window will fill and client send will stall
while (StallActive.load())
{
std::this_thread::sleep_for(std::chrono::milliseconds(50));
@@ -1735,21 +1739,21 @@ TEST_CASE("httpclient.uri_decoding")
TestServerFixture Fixture;
HttpClient Client = Fixture.MakeClient();
- // URI without encoding — should pass through unchanged
+ // URI without encoding - should pass through unchanged
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt");
}
- // Percent-encoded space — server should see decoded path
+ // Percent-encoded space - server should see decoded path
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt");
}
- // Percent-encoded slash (%2F) — should be decoded to /
+ // Percent-encoded slash (%2F) - should be decoded to /
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt");
REQUIRE(Resp.IsSuccess());
@@ -1763,21 +1767,21 @@ TEST_CASE("httpclient.uri_decoding")
CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt");
}
- // No capture — echo/uri route returns just RelativeUri
+ // No capture - echo/uri route returns just RelativeUri
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "echo/uri");
}
- // Literal percent that is not an escape (%ZZ) — should be kept as-is
+ // Literal percent that is not an escape (%ZZ) - should be kept as-is
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt");
REQUIRE(Resp.IsSuccess());
CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt");
}
- // Query params — raw values are returned as-is from GetQueryParams
+ // Query params - raw values are returned as-is from GetQueryParams
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test");
REQUIRE(Resp.IsSuccess());
@@ -1788,7 +1792,7 @@ TEST_CASE("httpclient.uri_decoding")
{
HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3");
REQUIRE(Resp.IsSuccess());
- // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly
+ // GetQueryParams returns raw (still-encoded) values - callers must Decode() explicitly
CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3");
}
diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp
index c42841922..0432e50ef 100644
--- a/src/zenhttp/httpclientauth.cpp
+++ b/src/zenhttp/httpclientauth.cpp
@@ -94,7 +94,8 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Unattended,
bool Quiet,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
Stopwatch Timer;
@@ -117,8 +118,9 @@ namespace zen { namespace httpclientauth {
}
});
- const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}",
+ const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}",
OidcExecutablePath,
+ IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl",
CloudHost,
AuthTokenPath,
Unattended ? "true"sv : "false"sv);
@@ -193,7 +195,7 @@ namespace zen { namespace httpclientauth {
}
else
{
- ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode);
+ ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode);
}
return HttpClientAccessToken{};
}
@@ -202,9 +204,10 @@ namespace zen { namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden)
+ bool Hidden,
+ bool IsHordeUrl)
{
- HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
if (InitialToken.IsValid())
{
return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath),
@@ -212,12 +215,13 @@ namespace zen { namespace httpclientauth {
Token = InitialToken,
Quiet,
Unattended,
- Hidden]() mutable {
+ Hidden,
+ IsHordeUrl]() mutable {
if (!Token.NeedsRefresh())
{
return std::move(Token);
}
- return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden);
+ return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl);
};
}
return {};
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index e05c9815f..03117ee6c 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -266,10 +266,10 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
return false;
}
- const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim));
- const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1));
+ const auto Start = ParseInt<uint64_t>(Token.substr(0, Delim));
+ const auto End = ParseInt<uint64_t>(Token.substr(Delim + 1));
- if (Start.has_value() && End.has_value() && End.value() > Start.value())
+ if (Start.has_value() && End.has_value() && End.value() >= Start.value())
{
Ranges.push_back({.Start = Start.value(), .End = End.value()});
}
@@ -286,6 +286,45 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
return Count != Ranges.size();
}
+MultipartByteRangesResult
+BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges)
+{
+ Oid::String_t BoundaryStr;
+ Oid::NewOid().ToString(BoundaryStr);
+ std::string_view Boundary(BoundaryStr, Oid::StringLength);
+
+ const uint64_t TotalSize = Data.GetSize();
+
+ std::vector<IoBuffer> Parts;
+ Parts.reserve(Ranges.size() * 2 + 1);
+
+ for (const HttpRange& Range : Ranges)
+ {
+ uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1;
+ if (RangeEnd >= TotalSize || Range.Start > RangeEnd)
+ {
+ return {};
+ }
+
+ uint64_t RangeSize = 1 + (RangeEnd - Range.Start);
+
+ std::string PartHeader = fmt::format("\r\n--{}\r\nContent-Type: application/octet-stream\r\nContent-Range: bytes {}-{}/{}\r\n\r\n",
+ Boundary,
+ Range.Start,
+ RangeEnd,
+ TotalSize);
+ Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(PartHeader.data(), PartHeader.size()));
+
+ IoBuffer RangeData(Data, Range.Start, RangeSize);
+ Parts.push_back(RangeData);
+ }
+
+ std::string ClosingBoundary = fmt::format("\r\n--{}--", Boundary);
+ Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(ClosingBoundary.data(), ClosingBoundary.size()));
+
+ return {.Parts = std::move(Parts), .ContentType = fmt::format("multipart/byteranges; boundary={}", Boundary)};
+}
+
//////////////////////////////////////////////////////////////////////////
const std::string_view
@@ -479,6 +518,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
return Ref<IHttpPackageHandler>();
}
+bool
+HttpService::AcceptsLocalFileReferences() const
+{
+ return false;
+}
+
+const ILocalRefPolicy*
+HttpService::GetLocalRefPolicy() const
+{
+ return nullptr;
+}
+
//////////////////////////////////////////////////////////////////////////
HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
@@ -552,6 +603,56 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType
}
void
+HttpServerRequest::WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges)
+{
+ if (Ranges.empty())
+ {
+ WriteResponse(HttpResponseCode::OK, ContentType, IoBuffer(Data));
+ return;
+ }
+
+ if (Ranges.size() == 1)
+ {
+ const HttpRange& Range = Ranges[0];
+ const uint64_t TotalSize = Data.GetSize();
+ // ~uint64_t(0) is the sentinel meaning "end of file" (suffix range).
+ const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1;
+
+ if (RangeEnd >= TotalSize || Range.Start > RangeEnd)
+ {
+ m_ContentRangeHeader = fmt::format("bytes */{}", TotalSize);
+ WriteResponse(HttpResponseCode::RangeNotSatisfiable);
+ return;
+ }
+
+ const uint64_t RangeSize = 1 + (RangeEnd - Range.Start);
+ IoBuffer RangeBuf(Data, Range.Start, RangeSize);
+
+ m_ContentRangeHeader = fmt::format("bytes {}-{}/{}", Range.Start, RangeEnd, TotalSize);
+ WriteResponse(HttpResponseCode::PartialContent, ContentType, std::move(RangeBuf));
+ return;
+ }
+
+ // Multi-range
+ MultipartByteRangesResult MultipartResult = BuildMultipartByteRanges(Data, Ranges);
+ if (MultipartResult.Parts.empty())
+ {
+ m_ContentRangeHeader = fmt::format("bytes */{}", Data.GetSize());
+ WriteResponse(HttpResponseCode::RangeNotSatisfiable);
+ return;
+ }
+ WriteResponse(HttpResponseCode::PartialContent, std::move(MultipartResult.ContentType), std::span<IoBuffer>(MultipartResult.Parts));
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(ParseContentType(CustomContentType) == HttpContentType::kUnknownContentType);
+ m_ContentTypeOverride = CustomContentType;
+ WriteResponse(ResponseCode, HttpContentType::kBinary, Blobs);
+}
+
+void
HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload)
{
std::span<const SharedBuffer> Segments = Payload.GetSegments();
@@ -705,7 +806,10 @@ HttpServerRequest::ReadPayloadPackage()
{
if (IoBuffer Payload = ReadPayload())
{
- return ParsePackageMessage(std::move(Payload));
+ ParseFlags Flags =
+ (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault;
+ const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr;
+ return ParsePackageMessage(std::move(Payload), {}, Flags, Policy);
}
return {};
@@ -816,7 +920,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
// Strip the separator slash left over after the service prefix is removed.
// When a service has BaseUri "/foo", the prefix length is set to len("/foo") = 4.
- // Stripping 4 chars from "/foo/bar" yields "/bar" — the path separator becomes
+ // Stripping 4 chars from "/foo/bar" yields "/bar" - the path separator becomes
// the first character of the relative URI. Remove it so patterns like "bar" or
// "{id}" match without needing to account for the leading slash.
if (!Uri.empty() && Uri.front() == '/')
@@ -1273,7 +1377,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP
return PackageHandlerRef->CreateTarget(Cid, Size);
};
- CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
+ ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences())
+ ? ParseFlags::kAllowLocalReferences
+ : ParseFlags::kDefault;
+ const ILocalRefPolicy* PkgPolicy =
+ EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr;
+ CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy);
PackageHandlerRef->OnRequestComplete();
}
@@ -1512,7 +1621,7 @@ TEST_CASE("http.common")
},
HttpVerb::kGet);
- // Single-segment literal with leading slash — simulates real server RelativeUri
+ // Single-segment literal with leading slash - simulates real server RelativeUri
{
Reset();
TestHttpServerRequest req{Service, "/activity_counters"sv};
@@ -1532,7 +1641,7 @@ TEST_CASE("http.common")
CHECK_EQ(Captures[0], "hello"sv);
}
- // Two-segment route with leading slash — first literal segment
+ // Two-segment route with leading slash - first literal segment
{
Reset();
TestHttpServerRequest req{Service, "/prefix/world"sv};
diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h
new file mode 100644
index 000000000..cb41626b9
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/asynchttpclient.h
@@ -0,0 +1,123 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zenhttp/httpclient.h>
+
+#include <functional>
+#include <future>
+#include <memory>
+
+namespace asio {
+class io_context;
+}
+
+namespace zen {
+
+/// Completion callback for async HTTP operations.
+using AsyncHttpCallback = std::function<void(HttpClient::Response)>;
+
+/** Asynchronous HTTP client backed by curl_multi and ASIO.
+ *
+ * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process
+ * transfers without blocking the caller. All curl_multi operations are
+ * serialized on an internal strand; callers may issue requests from any
+ * thread, and the io_context may have multiple threads.
+ *
+ * Two construction modes:
+ * - Owned io_context: creates an internal thread (self-contained).
+ * - External io_context: caller runs the event loop.
+ *
+ * Completion callbacks are dispatched on the io_context (not the internal
+ * strand), so a slow callback will not block the curl poll loop. Future-
+ * based wrappers (Get, Post, ...) return a std::future<Response> for
+ * callers that prefer blocking on a result.
+ */
+class AsyncHttpClient
+{
+public:
+ using Response = HttpClient::Response;
+ using KeyValueMap = HttpClient::KeyValueMap;
+
+ /// Construct with an internally-owned io_context and thread.
+ explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {});
+
+ /// Construct with an externally-managed io_context. The io_context must
+ /// outlive this client and must be running (via run()) on at least one thread.
+ AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {});
+
+ ~AsyncHttpClient();
+
+ AsyncHttpClient(const AsyncHttpClient&) = delete;
+ AsyncHttpClient& operator=(const AsyncHttpClient&) = delete;
+
+ // -- Callback-based API ----------------------------------------------
+
+ void AsyncGet(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPost(std::string_view Url,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPost(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {});
+
+ void AsyncPut(std::string_view Url,
+ const IoBuffer& Payload,
+ AsyncHttpCallback Callback,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {});
+
+ // -- Future-based API ------------------------------------------------
+
+ [[nodiscard]] std::future<Response> Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {});
+
+ [[nodiscard]] std::future<Response> Put(std::string_view Url,
+ const IoBuffer& Payload,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {});
+
+ [[nodiscard]] std::future<Response> Put(std::string_view Url, const KeyValueMap& Parameters = {});
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+void asynchttpclient_test_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h
index ce646ebd7..9220a50b6 100644
--- a/src/zenhttp/include/zenhttp/httpclientauth.h
+++ b/src/zenhttp/include/zenhttp/httpclientauth.h
@@ -33,7 +33,8 @@ namespace httpclientauth {
std::string_view CloudHost,
bool Quiet,
bool Unattended,
- bool Hidden);
+ bool Hidden,
+ bool IsHordeUrl = false);
} // namespace httpclientauth
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
index f9a99f3cc..1d921600d 100644
--- a/src/zenhttp/include/zenhttp/httpcommon.h
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -19,8 +19,8 @@ class StringBuilderBase;
struct HttpRange
{
- uint32_t Start = ~uint32_t(0);
- uint32_t End = ~uint32_t(0);
+ uint64_t Start = ~uint64_t(0);
+ uint64_t End = ~uint64_t(0);
};
using HttpRanges = std::vector<HttpRange>;
@@ -30,6 +30,16 @@ extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeSt
std::string_view ReasonStringForHttpResultCode(int HttpCode);
bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges);
+struct MultipartByteRangesResult
+{
+ std::vector<IoBuffer> Parts;
+ std::string ContentType;
+};
+
+// Build a multipart/byteranges response body from the given data and ranges.
+// Generates a unique boundary per call. Returns empty Parts if any range is out of bounds.
+MultipartByteRangesResult BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges);
+
enum class HttpVerb : uint8_t
{
kGet = 1 << 0,
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 5eaed6004..955b8ed15 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -12,6 +12,7 @@
#include <zencore/string.h>
#include <zencore/uid.h>
#include <zenhttp/httpcommon.h>
+#include <zenhttp/localrefpolicy.h>
#include <zentelemetry/hyperloglog.h>
#include <zentelemetry/stats.h>
@@ -121,11 +122,13 @@ public:
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0;
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload);
+ void WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs);
void WriteResponse(HttpResponseCode ResponseCode, CbObject Data);
void WriteResponse(HttpResponseCode ResponseCode, CbArray Array);
void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package);
void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString);
void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob);
+ void WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges);
virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0;
@@ -151,6 +154,8 @@ protected:
std::string_view m_QueryString;
mutable uint32_t m_RequestId = ~uint32_t(0);
mutable Oid m_SessionId = Oid::Zero;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
inline void SetIsHandled() { m_Flags |= kIsHandled; }
@@ -193,9 +198,16 @@ public:
HttpService() = default;
virtual ~HttpService() = default;
- virtual const char* BaseUri() const = 0;
- virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
- virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+ [[nodiscard]] virtual const char* BaseUri() const = 0;
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+
+ /// Whether this service accepts local file references in inbound packages from local clients.
+ [[nodiscard]] virtual bool AcceptsLocalFileReferences() const;
+
+ /// Returns the local ref policy for validating file paths in inbound local references.
+ /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed).
+ [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const;
// Internals
@@ -290,12 +302,12 @@ public:
std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; }
- /** Track active WebSocket connections — called by server implementations on upgrade/close. */
+ /** Track active WebSocket connections - called by server implementations on upgrade/close. */
void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); }
void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); }
uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); }
- /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */
+ /** Track WebSocket frame and byte counters - called by WS connection implementations per frame. */
void OnWebSocketFrameReceived(uint64_t Bytes)
{
m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed);
@@ -317,7 +329,7 @@ private:
int m_EffectiveHttpsPort = 0;
std::string m_ExternalHost;
metrics::Meter m_RequestMeter;
- metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error — sufficient for client counting
+ metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error - sufficient for client counting
metrics::HyperLogLog<12> m_ClientSessions;
std::string m_DefaultRedirect;
std::atomic<uint64_t> m_ActiveWebSocketConnections{0};
@@ -510,7 +522,8 @@ private:
bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-void http_forcelink(); // internal
-void websocket_forcelink(); // internal
+void http_forcelink(); // internal
+void httpparser_forcelink(); // internal
+void websocket_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h
index bce771c75..51ab2e06e 100644
--- a/src/zenhttp/include/zenhttp/httpstats.h
+++ b/src/zenhttp/include/zenhttp/httpstats.h
@@ -23,11 +23,11 @@ namespace zen {
class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler
{
public:
- /// Construct without an io_context — optionally uses a dedicated push thread
+ /// Construct without an io_context - optionally uses a dedicated push thread
/// for WebSocket stats broadcasting.
explicit HttpStatsService(bool EnableWebSockets = false);
- /// Construct with an external io_context — uses an asio timer instead
+ /// Construct with an external io_context - uses an asio timer instead
/// of a dedicated thread for WebSocket stats broadcasting.
/// The caller must ensure the io_context outlives this service and that
/// its run loop is active.
@@ -43,7 +43,7 @@ public:
virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override;
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override;
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override;
void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override;
void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override;
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
index 9c3b909a2..fd2f79171 100644
--- a/src/zenhttp/include/zenhttp/httpwsclient.h
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -26,7 +26,7 @@ namespace zen {
* Callback interface for WebSocket client events
*
* Separate from the server-side IWebSocketHandler because the caller
- * already owns the HttpWsClient — no Ref<WebSocketConnection> needed.
+ * already owns the HttpWsClient - no Ref<WebSocketConnection> needed.
*/
class IWsClientHandler
{
@@ -85,9 +85,9 @@ private:
/// it is treated as a plain host:port and gets the ws:// prefix.
///
/// Examples:
-/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws"
-/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo"
-/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar"
+/// HttpToWsUrl("http://host:8080", "/orch/ws") -> "ws://host:8080/orch/ws"
+/// HttpToWsUrl("https://host", "/foo") -> "wss://host/foo"
+/// HttpToWsUrl("host:8080", "/bar") -> "ws://host:8080/bar"
std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path);
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h
new file mode 100644
index 000000000..0b37f9dc7
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/localrefpolicy.h
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <filesystem>
+
+namespace zen {
+
+/// Policy interface for validating local file reference paths in inbound CbPackage messages.
+/// Implementations should throw std::invalid_argument if the path is not allowed.
+class ILocalRefPolicy
+{
+public:
+ virtual ~ILocalRefPolicy() = default;
+
+ /// Validate that a local file reference path is allowed.
+ /// Throws std::invalid_argument if the path escapes the allowed root.
+ virtual void ValidatePath(const std::filesystem::path& Path) const = 0;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h
index 1a5068580..66e3f6e55 100644
--- a/src/zenhttp/include/zenhttp/packageformat.h
+++ b/src/zenhttp/include/zenhttp/packageformat.h
@@ -5,6 +5,7 @@
#include <zencore/compactbinarypackage.h>
#include <zencore/iobuffer.h>
#include <zencore/iohash.h>
+#include <zenhttp/localrefpolicy.h>
#include <functional>
#include <gsl/gsl-lite.hpp>
@@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr);
-CbPackage ParsePackageMessage(
- IoBuffer Payload,
- std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+
+enum class ParseFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags);
+
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
return IoBuffer{Size};
- });
+ },
+ ParseFlags Flags = ParseFlags::kDefault,
+ const ILocalRefPolicy* Policy = nullptr);
bool IsPackageMessage(IoBuffer Payload);
bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
@@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe
class CbPackageReader
{
public:
- CbPackageReader();
+ CbPackageReader(ParseFlags Flags = ParseFlags::kDefault);
~CbPackageReader();
void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+ void SetLocalRefPolicy(const ILocalRefPolicy* Policy);
/** Process compact binary package data stream
@@ -149,6 +162,8 @@ private:
kReadingBuffers
} m_CurrentState = State::kInitialState;
+ ParseFlags m_Flags;
+ const ILocalRefPolicy* m_LocalRefPolicy = nullptr;
std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
std::vector<IoBuffer> m_PayloadBuffers;
std::vector<CbAttachmentEntry> m_AttachmentEntries;
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
index 710579faa..2d25515d3 100644
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -59,7 +59,7 @@ class IWebSocketHandler
public:
virtual ~IWebSocketHandler() = default;
- virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0;
+ virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0;
virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0;
virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0;
};
diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp
index 7e6207e56..5ad5ebcc7 100644
--- a/src/zenhttp/monitoring/httpstats.cpp
+++ b/src/zenhttp/monitoring/httpstats.cpp
@@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request)
//
void
-HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection)
+HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri)
{
+ ZEN_UNUSED(RelativeUri);
ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen");
ZEN_INFO("Stats WebSocket client connected");
diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp
index 9c62c1f2d..267ce386c 100644
--- a/src/zenhttp/packageformat.cpp
+++ b/src/zenhttp/packageformat.cpp
@@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:");
typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t;
+/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security.
+/// If no policy is set, file-path local refs are rejected (fail-closed).
+static void
+ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path)
+{
+ if (Policy)
+ {
+ Policy->ValidatePath(Path);
+ }
+ else
+ {
+ throw std::invalid_argument("local file reference rejected: no validation policy");
+ }
+}
+
+// Validates the CbPackageHeader magic and attachment count. Returns the total
+// chunk count (AttachmentCount + 1, including the implicit root object).
+static uint32_t
+ValidatePackageHeader(const CbPackageHeader& Hdr)
+{
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ throw std::invalid_argument(
+ fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic));
+ }
+ // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against
+ // UINT32_MAX wrapping to 0, which would bypass subsequent size checks.
+ if (Hdr.AttachmentCount == UINT32_MAX)
+ {
+ throw std::invalid_argument("invalid CbPackage, attachment count overflow");
+ }
+ return Hdr.AttachmentCount + 1;
+}
+
+struct ValidatedLocalRef
+{
+ bool Valid = false;
+ const CbAttachmentReferenceHeader* Header = nullptr;
+ std::string_view Path;
+ std::string Error;
+};
+
+// Validates that the attachment buffer contains a well-formed local reference
+// header and path. On failure, Valid is false and Error contains the reason.
+static ValidatedLocalRef
+ValidateLocalRef(const IoBuffer& AttachmentBuffer)
+{
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader))
+ {
+ return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())};
+ }
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+
+ if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)
+ {
+ return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})",
+ sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength,
+ AttachmentBuffer.Size())};
+ }
+
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)};
+}
+
IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle);
std::vector<IoBuffer>
@@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload)
}
CbPackage
-ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+ParsePackageMessage(IoBuffer Payload,
+ std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer,
+ ParseFlags Flags,
+ const ILocalRefPolicy* Policy)
{
ZEN_TRACE_CPU("ParsePackageMessage");
@@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
BinaryReader Reader(Payload);
- const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
- if (Hdr->HeaderMagic != kCbPkgMagic)
- {
- throw std::invalid_argument(
- fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic));
- }
+ const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData());
+ const uint32_t ChunkCount = ValidatePackageHeader(*Hdr);
Reader.Skip(sizeof(CbPackageHeader));
- const uint32_t ChunkCount = Hdr->AttachmentCount + 1;
-
- if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount)
+ // Widen to uint64_t so the multiplication cannot wrap on 32-bit.
+ const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount;
+ if (Reader.Remaining() < AttachmentTableSize)
{
throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)",
sizeof(CbAttachmentEntry) * ChunkCount,
@@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
{
- // Marshal local reference - a "pointer" to the chunk backing file
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+ if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i));
+ }
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+ // Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
- std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash)));
+ continue;
+ }
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::string_view PathView = LocalRef.Path;
IoBuffer FullFileBuffer;
@@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint
}
else
{
+ ApplyLocalRefPolicy(Policy, Path);
FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
}
}
if (FullFileBuffer)
{
- IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ // Guard against offset+size overflow or exceeding the file bounds.
+ const uint64_t FileSize = FullFileBuffer.GetSize();
+ if (AttachRefHdr->PayloadByteOffset > FileSize ||
+ AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset)
+ {
+ MalformedAttachments.push_back(
+ std::make_pair(i,
+ fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}",
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ FileSize,
+ Entry.AttachmentHash)));
+ continue;
+ }
+
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize
? FullFileBuffer
: IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa
return OutPackage.TryLoad(Response);
}
-CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+CbPackageReader::CbPackageReader(ParseFlags Flags)
+: m_Flags(Flags)
+, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
{
}
@@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci
m_CreateBuffer = CreateBuffer;
}
+void
+CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy)
+{
+ m_LocalRefPolicy = Policy;
+}
+
uint64_t
CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
{
@@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
return sizeof m_PackageHeader;
case State::kReadingHeader:
- ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
- memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
- ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
- m_CurrentState = State::kReadingAttachmentEntries;
- m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
- return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+ {
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(ChunkCount);
+ return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry);
+ }
case State::kReadingAttachmentEntries:
ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
@@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
{
// Marshal local reference - a "pointer" to the chunk backing file
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
-
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer);
+ if (!LocalRef.Valid)
+ {
+ throw std::invalid_argument(LocalRef.Error);
+ }
- std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+ const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header;
+ std::filesystem::path Path(Utf8ToWide(LocalRef.Path));
- std::filesystem::path Path{PathView};
+ if (!LocalRef.Path.starts_with(HandlePrefix))
+ {
+ ApplyLocalRefPolicy(m_LocalRefPolicy, Path);
+ }
IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
@@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
AttachRefHdr->PayloadByteSize));
}
+ // MakeFromFile silently clamps offset+size to the file size. Detect this
+ // to avoid returning a short buffer that could cause subtle downstream issues.
+ if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize)
+ {
+ throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize,
+ ChunkReference.GetSize()));
+ }
+
return ChunkReference;
};
@@ -732,6 +843,13 @@ CbPackageReader::Finalize()
{
IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+ if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences))
+ {
+ throw std::invalid_argument(
+ fmt::format("package contains local reference (attachment #{}) but local references are not allowed",
+ CurrentAttachmentIndex));
+ }
+
if (CurrentAttachmentIndex == 0)
{
// Root object
@@ -815,6 +933,13 @@ CbPackageReader::Finalize()
TEST_SUITE_BEGIN("http.packageformat");
+/// Permissive policy that allows any path, for use in tests that exercise local ref
+/// functionality but are not testing path validation.
+struct PermissiveLocalRefPolicy : public ILocalRefPolicy
+{
+ void ValidatePath(const std::filesystem::path&) const override {}
+};
+
TEST_CASE("CbPackage.Serialization")
{
// Make a test package
@@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef")
RemainingBytes -= ByteCount;
};
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackageReader Reader(ParseFlags::kAllowLocalReferences);
+ Reader.SetLocalRefPolicy(&AllowAllPolicy);
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedHeader")
+{
+ // Payload too small for a CbPackageHeader
+ uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa};
+ IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagic")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xDEADBEEF;
+ Hdr.AttachmentCount = 0;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflow")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable")
+{
+ // Valid header but not enough data for the attachment entries
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = 10;
+ IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr));
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.TruncatedAttachmentData")
+{
+ // Valid header + one attachment entry claiming more data than available
+ std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry));
+
+ CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data());
+ Hdr->HeaderMagic = kCbPkgMagic;
+ Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object)
+
+ CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader));
+ Entry->PayloadSize = 9999; // way more than available
+ Entry->Flags = CbAttachmentEntry::kIsObject;
+ Entry->AttachmentHash = IoHash();
+
+ IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size());
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault")
+{
+ // Build a valid package with local refs backed by compressed-format files,
+ // then verify it's rejected with default ParseFlags and accepted when allowed.
+ ScopedTemporaryDirectory TempDir;
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ // Compress data and write to disk, then create file-backed compressed attachments.
+ // The files must contain compressed-format data because ParsePackageMessage expects it
+ // when resolving local refs.
+ CompressedBuffer Comp1 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ CompressedBuffer Comp2 =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+
+ IoHash Hash1 = Comp1.DecodeRawHash();
+ IoHash Hash2 = Comp2.DecodeRawHash();
+
+ {
+ IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer();
+ IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(Path1, Buf1);
+ WriteFile(Path2, Buf2);
+ }
+
+ // Create attachments from file-backed buffers so FormatPackageMessage uses local refs
+ CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1};
+ CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Default flags should reject local refs
+ CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument);
+
+ // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip
+ // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader)
+ PermissiveLocalRefPolicy AllowAllPolicy;
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 2);
+}
+
+TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader")
+{
+ // Same test but via CbPackageReader
+ ScopedTemporaryDirectory TempDir;
+ auto FilePath = TempDir.Path() / "testdata";
+
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata"));
+ WriteFile(FilePath, Buf);
+ }
+
+ IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath);
+ CbAttachment Attach{SharedBuffer(FileBuffer)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ // Default flags should reject
CbPackageReader Reader;
uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
@@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef")
CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
}
- Reader.Finalize();
+ CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.BadMagicViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = 0xBADCAFE;
+ Hdr.AttachmentCount = 0;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader")
+{
+ CbPackageHeader Hdr{};
+ Hdr.HeaderMagic = kCbPkgMagic;
+ Hdr.AttachmentCount = UINT32_MAX;
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot")
+{
+ // A file outside the allowed root should be rejected by the policy
+ ScopedTemporaryDirectory AllowedRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "secret.dat";
+ {
+ IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret"));
+ WriteFile(OutsidePath, Buf);
+ }
+
+ // Create file-backed compressed attachment from outside root
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // Policy rooted at AllowedRoot should reject the file in OutsideDir
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot")
+{
+ // A file inside the allowed root should be accepted by the policy
+ ScopedTemporaryDirectory TempRoot;
+
+ auto FilePath = TempRoot.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy);
+ CHECK(Result.GetObject());
+ CHECK(Result.GetAttachments().size() == 1);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal")
+{
+ // A file path containing ".." that resolves outside root should be rejected
+ ScopedTemporaryDirectory TempRoot;
+ ScopedTemporaryDirectory OutsideDir;
+
+ auto OutsidePath = OutsideDir.Path() / "evil.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(OutsidePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ struct TestPolicy : public ILocalRefPolicy
+ {
+ std::string Root;
+ void ValidatePath(const std::filesystem::path& Path) const override
+ {
+ std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string();
+ if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0)
+ {
+ throw std::invalid_argument("path outside root");
+ }
+ }
+ } Policy;
+ // Root is TempRoot, but the file lives in OutsideDir
+ Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string();
+
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument);
+}
+
+TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed")
+{
+ // When local refs are allowed but no policy is provided, file-path refs should be rejected
+ ScopedTemporaryDirectory TempDir;
+
+ auto FilePath = TempDir.Path() / "data.dat";
+
+ CompressedBuffer Comp =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ IoHash Hash = Comp.DecodeRawHash();
+ {
+ IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer();
+ WriteFile(FilePath, Buf);
+ }
+
+ CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("data", Attach);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach);
+ Pkg.SetObject(Cbo.Save());
+
+ IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer();
+
+ // kAllowLocalReferences but nullptr policy => fail-closed
+ CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument);
}
TEST_SUITE_END();
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 7972777b8..a1a775ba3 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -625,6 +625,8 @@ public:
void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
/**
* Initialize the response for sending a payload made up of multiple blobs
@@ -768,10 +770,18 @@ public:
{
ZEN_MEMSCOPE(GetHttpasioTag());
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -898,7 +908,9 @@ private:
bool m_AllowZeroCopyFileSend = true;
State m_State = State::kUninitialized;
HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
asio::buffer(ResponseStr->data(), ResponseStr->size()),
asio::bind_executor(
m_Strand,
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()](
+ const asio::error_code& Ec,
+ std::size_t) {
if (Ec)
{
ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
@@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest()
Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ std::string_view FullUrl = Conn->m_RequestData.Url();
+ std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
}));
@@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest()
return;
}
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
if (!m_RequestData.IsKeepAlive())
@@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->SetKeepAlive(m_Request.IsKeepAlive());
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 918b55dc6..8b07c7905 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -8,6 +8,13 @@
#include <limits>
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <cstring>
+# include <string>
+# include <string_view>
+#endif
+
namespace zen {
using namespace std::literals;
@@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W
// HttpRequestParser
//
-http_parser_settings HttpRequestParser::s_ParserSettings{
- .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
- .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
- .on_status =
- [](http_parser* p, const char* Data, size_t ByteCount) {
- ZEN_UNUSED(p, Data, ByteCount);
- return 0;
- },
- .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
- .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
+// clang-format off
+llhttp_settings_t HttpRequestParser::s_ParserSettings = []() {
+ llhttp_settings_t S;
+ llhttp_settings_init(&S);
+ S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); };
+ S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); };
+ S.on_status = [](llhttp_t*, const char*, size_t) { return 0; };
+ S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); };
+ S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); };
+ S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); };
+ S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); };
+ S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); };
+ return S;
+}();
+// clang-format on
HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection)
{
- http_parser_init(&m_Parser, HTTP_REQUEST);
+ llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings);
m_Parser.data = this;
ResetState();
@@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser()
size_t
HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize)
{
- const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
-
- http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
-
- if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize);
+ if (Err == HPE_OK)
{
- ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
- return ~0ull;
+ return DataSize;
}
- return ConsumedBytes;
+ if (Err == HPE_PAUSED_UPGRADE)
+ {
+ return DataSize;
+ }
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser));
+ return ~0ull;
}
int
@@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_UrlRange.Length == 0)
@@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
if (m_HeaderEntries.empty())
@@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes)
if (RemainingBufferSpace < Bytes)
{
ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
+ return -1;
}
ZEN_ASSERT_SLOW(!m_HeaderEntries.empty());
@@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete()
}
}
- m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+ m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser);
- switch (m_Parser.method)
+ switch (llhttp_get_method(&m_Parser))
{
case HTTP_GET:
m_RequestVerb = HttpVerb::kGet;
@@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete()
break;
default:
- ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser))));
break;
}
@@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes)
{
ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes",
(m_BodyPosition + Bytes) - m_BodyBuffer.Size());
- return 1;
+ return -1;
}
memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
m_BodyPosition += Bytes;
- if (http_body_is_final(&m_Parser))
- {
- if (m_BodyPosition != m_BodyBuffer.Size())
- {
- ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
- return 1;
- }
- }
-
return 0;
}
@@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete()
catch (const AssertException& AssertEx)
{
ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription());
- return 1;
+ return -1;
}
catch (const std::system_error& SystemError)
{
@@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete()
ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value());
}
ResetState();
- return 1;
+ return -1;
}
catch (const std::bad_alloc& BadAlloc)
{
ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what());
ResetState();
- return 1;
+ return -1;
}
catch (const std::exception& Ex)
{
ZEN_ERROR("failed processing http request: '{}'", Ex.what());
ResetState();
- return 1;
+ return -1;
}
}
@@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const
return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0;
}
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace {
+
+ struct MockCallbacks : HttpRequestParserCallbacks
+ {
+ int HandleRequestCount = 0;
+ int TerminateCount = 0;
+
+ HttpRequestParser* Parser = nullptr;
+
+ HttpVerb LastVerb{};
+ std::string LastUrl;
+ std::string LastQueryString;
+ std::string LastBody;
+ bool LastKeepAlive = false;
+ bool LastIsWebSocketUpgrade = false;
+ std::string LastSecWebSocketKey;
+ std::string LastUpgradeHeader;
+ HttpContentType LastContentType{};
+
+ void HandleRequest() override
+ {
+ ++HandleRequestCount;
+ if (Parser)
+ {
+ LastVerb = Parser->RequestVerb();
+ LastUrl = std::string(Parser->Url());
+ LastQueryString = std::string(Parser->QueryString());
+ LastKeepAlive = Parser->IsKeepAlive();
+ LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade();
+ LastSecWebSocketKey = std::string(Parser->SecWebSocketKey());
+ LastUpgradeHeader = std::string(Parser->UpgradeHeader());
+ LastContentType = Parser->ContentType();
+
+ IoBuffer Body = Parser->Body();
+ if (Body.Size() > 0)
+ {
+ LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size());
+ }
+ else
+ {
+ LastBody.clear();
+ }
+ }
+ }
+
+ void TerminateConnection() override { ++TerminateCount; }
+ };
+
+} // anonymous namespace
+
+TEST_SUITE_BEGIN("http.httpparser");
+
+TEST_CASE("httpparser.basic_get")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kGet);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK(Mock.LastKeepAlive);
+}
+
+TEST_CASE("httpparser.post_with_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 13\r\n"
+ "Content-Type: application/json\r\n"
+ "\r\n"
+ "{\"key\":\"val\"}";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, HttpVerb::kPost);
+ CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}");
+ CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON);
+}
+
+TEST_CASE("httpparser.pipelined_requests")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n"
+ "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 2);
+ CHECK_EQ(Mock.LastUrl, "/second");
+}
+
+TEST_CASE("httpparser.partial_header")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc";
+ std::string Chunk2 = "alhost\r\n\r\n";
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, Chunk2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+}
+
+TEST_CASE("httpparser.partial_body")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Headers =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 10\r\n"
+ "\r\n";
+ std::string BodyPart1 = "hello";
+ std::string BodyPart2 = "world";
+
+ std::string Chunk1 = Headers + BodyPart1;
+
+ size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size());
+ CHECK_NE(Consumed1, ~0ull);
+ CHECK_EQ(Consumed1, Chunk1.size());
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+
+ size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size());
+ CHECK_NE(Consumed2, ~0ull);
+ CHECK_EQ(Consumed2, BodyPart2.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "helloworld");
+}
+
+TEST_CASE("httpparser.invalid_request")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Garbage = "NOT_HTTP garbage data\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 0);
+}
+
+TEST_CASE("httpparser.body_overflow")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes,
+ // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY"
+ // as a new HTTP request which fails.
+ std::string Request =
+ "POST /api HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Content-Length: 2\r\n"
+ "\r\n"
+ "TOO_LONG_BODY";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastBody, "TO");
+}
+
+TEST_CASE("httpparser.websocket_upgrade")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+ CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ==");
+ CHECK_EQ(Mock.LastUpgradeHeader, "websocket");
+}
+
+TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string HttpPart =
+ "GET /ws HTTP/1.1\r\n"
+ "Host: localhost\r\n"
+ "Upgrade: websocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+ "Sec-WebSocket-Version: 13\r\n"
+ "\r\n";
+
+ // Append fake WebSocket frame bytes after the HTTP message
+ std::string Request = HttpPart;
+ Request.push_back('\x81');
+ Request.push_back('\x05');
+ Request.append("hello");
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_NE(Consumed, ~0ull);
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK(Mock.LastIsWebSocketUpgrade);
+}
+
+TEST_CASE("httpparser.keep_alive_detection")
+{
+ SUBCASE("HTTP/1.1 default keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK(Mock.LastKeepAlive);
+ }
+
+ SUBCASE("Connection: close disables keep-alive")
+ {
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
+ Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_FALSE(Mock.LastKeepAlive);
+ }
+}
+
+TEST_CASE("httpparser.all_verbs")
+{
+ struct VerbTest
+ {
+ const char* Method;
+ HttpVerb Expected;
+ };
+
+ VerbTest Tests[] = {
+ {"GET", HttpVerb::kGet},
+ {"POST", HttpVerb::kPost},
+ {"PUT", HttpVerb::kPut},
+ {"DELETE", HttpVerb::kDelete},
+ {"HEAD", HttpVerb::kHead},
+ {"COPY", HttpVerb::kCopy},
+ {"OPTIONS", HttpVerb::kOptions},
+ };
+
+ for (const VerbTest& Test : Tests)
+ {
+ CAPTURE(Test.Method);
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n";
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastVerb, Test.Expected);
+ }
+}
+
+TEST_CASE("httpparser.query_string")
+{
+ MockCallbacks Mock;
+ HttpRequestParser Parser(Mock);
+ Mock.Parser = &Parser;
+
+ std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n";
+
+ size_t Consumed = Parser.ConsumeData(Request.data(), Request.size());
+ CHECK_EQ(Consumed, Request.size());
+ CHECK_EQ(Mock.HandleRequestCount, 1);
+ CHECK_EQ(Mock.LastUrl, "/path");
+ CHECK_EQ(Mock.LastQueryString, "key=val&other=123");
+}
+
+TEST_SUITE_END();
+
+void
+httpparser_forcelink()
+{
+}
+
+#endif // ZEN_WITH_TESTS
+
} // namespace zen
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 23ad9d8fb..4ff216248 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -8,7 +8,7 @@
#include <EASTL/fixed_vector.h>
ZEN_THIRD_PARTY_INCLUDES_START
-#include <http_parser.h>
+#include <llhttp.h>
ZEN_THIRD_PARTY_INCLUDES_END
#include <atomic>
@@ -100,7 +100,7 @@ private:
Oid m_SessionId{};
IoBuffer m_BodyBuffer;
uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
+ llhttp_t m_Parser;
eastl::fixed_vector<char, 512> m_HeaderData;
std::string m_NormalizedUrl;
@@ -114,8 +114,8 @@ private:
int OnBody(const char* Data, size_t Bytes);
int OnMessageComplete();
- static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
- static http_parser_settings s_ParserSettings;
+ static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); }
+ static llhttp_settings_t s_ParserSettings;
};
} // namespace zen
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 31b0315d4..b0fb020e0 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -185,13 +185,17 @@ public:
const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; }
void SuppressPayload() { m_ResponseBuffers.resize(1); }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
std::string_view GetHeaders();
private:
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
uint64_t m_ContentLength = 0;
std::vector<IoBuffer> m_ResponseBuffers;
ExtendableStringBuilder<160> m_Headers;
@@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders()
if (m_Headers.Size() == 0)
{
+ std::string_view ContentTypeStr =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
+
m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Type: " << ContentTypeStr << "\r\n"
<< "Content-Length: " << ContentLength() << "\r\n"sv;
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv;
+ }
+
if (!m_IsKeepAlive)
{
m_Headers << "Connection: close\r\n"sv;
@@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary));
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten
ZEN_MEMSCOPE(GetHttppluginTag());
m_Response.reset(new HttpPluginResponse(ContentType));
+ if (!m_ContentTypeOverride.empty())
+ {
+ m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 2cad97725..67b1230a0 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -464,6 +464,8 @@ public:
inline int64_t GetResponseBodySize() const { return m_TotalDataSize; }
void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; }
+ void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); }
+ void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); }
private:
eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks;
@@ -473,6 +475,8 @@ private:
uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
bool m_IsInitialResponse = true;
HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::string m_ContentTypeOverride;
+ std::string m_ContentRangeHeader;
eastl::fixed_vector<IoBuffer, 16> m_DataBuffers;
std::string m_LocationHeader;
@@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
- std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+ std::string_view ContentTypeString =
+ m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride);
ContentTypeHeader->pRawValue = ContentTypeString.data();
ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
@@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size();
}
+ // Content-Range header (for 206 Partial Content single-range responses)
+
+ if (!m_ContentRangeHeader.empty())
+ {
+ PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange];
+ ContentRangeHeader->pRawValue = m_ContentRangeHeader.data();
+ ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size();
+ }
+
std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
HttpResponse.StatusCode = m_ResponseCode;
@@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort)
else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED)
{
// Port may be owned by another process's wildcard registration (access denied)
- // or actively in use (sharing violation) — retry on a different port
+ // or actively in use (sharing violation) - retry on a different port
ShouldRetryNextPort = true;
}
else
@@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy
HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+ if (!m_ContentTypeOverride.empty())
+ {
+ Response->SetContentTypeOverride(std::move(m_ContentTypeOverride));
+ }
+ if (!m_ContentRangeHeader.empty())
+ {
+ Response->SetContentRangeHeader(std::move(m_ContentRangeHeader));
+ }
+
if (SuppressBody())
{
Response->SuppressResponseBody();
@@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
&Transaction().Server()));
Ref<WebSocketConnection> WsConnRef(WsConn.Get());
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ ExtendableStringBuilder<128> UrlUtf8;
+ WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath,
+ gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))},
+ UrlUtf8);
+ int PrefixLen = Service->UriPrefixLength();
+ std::string_view RelativeUri{UrlUtf8.ToView()};
+ RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size())));
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri);
WsConn->Start();
return nullptr;
@@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT
ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult);
- // WebSocket upgrade failed — return nullptr since ServerRequest()
+ // WebSocket upgrade failed - return nullptr since ServerRequest()
// was never populated (no InvokeRequestHandler call)
return nullptr;
}
- // Service doesn't support WebSocket or missing key — fall through to normal handling
+ // Service doesn't support WebSocket or missing key - fall through to normal handling
}
}
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index 5ae48f5b3..078c21ea1 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
index af320172d..8520e9f60 100644
--- a/src/zenhttp/servers/wshttpsys.cpp
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown()
return;
}
- // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED
+ // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
@@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData()
}
case WebSocketOpcode::kPong:
- // Unsolicited pong — ignore per RFC 6455
+ // Unsolicited pong - ignore per RFC 6455
break;
case WebSocketOpcode::kClose:
@@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
- // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED
+ // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED
HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr);
}
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 59c46a418..a58037fec 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -5,6 +5,7 @@
# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zencore/timer.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/httpwsclient.h>
@@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec")
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
- // Server frames are unmasked — TryParseFrame should handle them
+ // Server frames are unmasked - TryParseFrame should handle them
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());
CHECK(Result.IsValid);
@@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec")
{
std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});
- // Pass only 1 byte — not enough for a frame header
+ // Pass only 1 byte - not enough for a frame header
WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);
CHECK_FALSE(Result.IsValid);
CHECK_EQ(Result.BytesConsumed, 0u);
@@ -335,8 +336,9 @@ namespace {
}
// IWebSocketHandler
- void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
+ void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override
{
+ ZEN_UNUSED(RelativeUri);
m_OpenCount.fetch_add(1);
m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
@@ -463,7 +465,7 @@ namespace {
if (!Done.load())
{
- // Timeout — cancel the read
+ // Timeout - cancel the read
asio::error_code Ec;
Sock.cancel(Ec);
}
@@ -476,6 +478,23 @@ namespace {
return Result;
}
+ static void WaitForServerListening(int Port)
+ {
+ Stopwatch Timer;
+ while (Timer.GetElapsedTimeMs() < 5'000)
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Probe(IoCtx);
+ asio::error_code Ec;
+ Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec);
+ if (!Ec)
+ {
+ return;
+ }
+ Sleep(10);
+ }
+ }
+
} // anonymous namespace
TEST_CASE("websocket.integration")
@@ -501,8 +520,8 @@ TEST_CASE("websocket.integration")
Server->Close();
});
- // Give server a moment to start accepting
- Sleep(100);
+ // Wait for server to start accepting
+ WaitForServerListening(Port);
SUBCASE("handshake succeeds with 101")
{
@@ -692,7 +711,7 @@ TEST_CASE("websocket.integration")
std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
- // Should NOT get 101 — should fall through to normal request handling
+ // Should NOT get 101 - should fall through to normal request handling
CHECK(Response.find("101") == std::string::npos);
Sock.close();
@@ -813,7 +832,7 @@ TEST_CASE("websocket.client")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close")
{
@@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket")
Server->Close();
});
- Sleep(100);
+ WaitForServerListening(Port);
SUBCASE("connect, echo, close over unix socket")
{
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index 7b050ae35..67a01403d 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -9,7 +9,7 @@ target('zenhttp')
add_files("servers/wshttpsys.cpp", {unity_ignored=true})
add_includedirs("include", {public=true})
add_deps("zencore", "zentelemetry", "transport-sdk", "asio")
- add_packages("http_parser", "json11", "libcurl")
+ add_packages("llhttp", "json11", "libcurl")
add_options("httpsys")
if is_plat("linux", "macosx") then
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index 3ac8eea8d..e15aa4d30 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -4,6 +4,7 @@
#if ZEN_WITH_TESTS
+# include <zenhttp/asynchttpclient.h>
# include <zenhttp/httpclient.h>
# include <zenhttp/httpserver.h>
# include <zenhttp/packageformat.h>
@@ -16,7 +17,9 @@ zenhttp_forcelinktests()
{
http_forcelink();
httpclient_forcelink();
+ httpparser_forcelink();
httpclient_test_forcelink();
+ asynchttpclient_test_forcelink();
forcelink_packageformat();
passwordsecurity_forcelink();
websocket_forcelink();