diff options
Diffstat (limited to 'src/zenhttp')
34 files changed, 2998 insertions, 415 deletions
diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp new file mode 100644 index 000000000..151863370 --- /dev/null +++ b/src/zenhttp/asynchttpclient_test.cpp @@ -0,0 +1,315 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/asynchttpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Reusable test service for async client tests + +class AsyncHttpClientTestService : public HttpService +{ +public: + AsyncHttpClientTestService() + { + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}"); + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/async-test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + +private: + HttpRequestRouter m_Router; +}; + +////////////////////////////////////////////////////////////////////////// + +struct AsyncTestServerFixture +{ + AsyncHttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + AsyncTestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(0, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~AsyncTestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + AsyncHttpClient MakeClient(HttpClientSettings Settings = {}) { return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), Settings); } + + AsyncHttpClient MakeClient(asio::io_context& IoContext, HttpClientSettings Settings = {}) + { + return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), IoContext, Settings); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.asynchttpclient"); + +TEST_CASE("asynchttpclient.future.verbs") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + auto Future = Client.Get("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + auto Future = Client.Post("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + auto Future = Client.Put("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + auto Future = Client.Delete("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + auto Future = Client.Head("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("asynchttpclient.future.get") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET returning JSON") + { + auto Future = Client.Get("/api/async-test/json"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"ok\":true}"); + } + + SUBCASE("GET 204 NoContent") + { + auto Future = Client.Get("/api/async-test/nocontent"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("asynchttpclient.future.post.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PayloadStr = "async payload data"; + IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Post("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "async payload data"); +} + +TEST_CASE("asynchttpclient.future.put.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PutStr = "put payload"; + IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Put("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload"); +} + +TEST_CASE("asynchttpclient.callback") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::promise<HttpClient::Response> Promise; + auto Future = Promise.get_future(); + + Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); +} + +TEST_CASE("asynchttpclient.concurrent.requests") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + // Fire multiple requests concurrently + auto Future1 = Client.Get("/api/async-test/hello"); + auto Future2 = Client.Get("/api/async-test/json"); + auto Future3 = Client.Post("/api/async-test/echo/method"); + auto Future4 = Client.Delete("/api/async-test/echo/method"); + + auto Resp1 = Future1.get(); + auto Resp2 = Future2.get(); + auto Resp3 = Future3.get(); + auto Resp4 = Future4.get(); + + CHECK(Resp1.IsSuccess()); + CHECK_EQ(Resp1.AsText(), "hello world"); + + CHECK(Resp2.IsSuccess()); + CHECK_EQ(Resp2.AsText(), "{\"ok\":true}"); + + CHECK(Resp3.IsSuccess()); + CHECK_EQ(Resp3.AsText(), "POST"); + + CHECK(Resp4.IsSuccess()); + CHECK_EQ(Resp4.AsText(), "DELETE"); +} + +TEST_CASE("asynchttpclient.external.io_context") +{ + AsyncTestServerFixture Fixture; + + asio::io_context IoContext; + auto WorkGuard = asio::make_work_guard(IoContext); + std::thread IoThread([&IoContext]() { IoContext.run(); }); + + { + AsyncHttpClient Client = Fixture.MakeClient(IoContext); + + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + WorkGuard.reset(); + IoThread.join(); +} + +TEST_CASE("asynchttpclient.connection.error") +{ + // Connect to a port where nothing is listening + AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + + auto Future = Client.Get("/should-fail"); + auto Resp = Future.get(); + + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(Resp.Error->IsConnectionError()); +} + +TEST_SUITE_END(); + +void +asynchttpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index 209276621..2fa22f2c2 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -132,7 +132,7 @@ public: } } - RefPtr<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); + Ref<OidcClient> Client(new OidcClient(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId})); if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) { @@ -232,10 +232,10 @@ public: private: struct OpenIdProvider { - std::string Name; - std::string Url; - std::string ClientId; - RefPtr<OidcClient> HttpClient; + std::string Name; + std::string Url; + std::string ClientId; + Ref<OidcClient> HttpClient; }; struct OpenIdToken @@ -262,7 +262,7 @@ private: { ZEN_TRACE_CPU("AuthMgr::RefreshOpenIdToken"); - RefPtr<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; + Ref<OidcClient> Client = GetOpenIdProvider(ProviderName).HttpClient; if (!Client) { return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp new file mode 100644 index 000000000..ea88fc783 --- /dev/null +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -0,0 +1,1033 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/asynchttpclient.h> + +#include "httpclientcurlhelpers.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <thread> +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// TransferContext: per-transfer state associated with each CURL easy handle + +struct TransferContext +{ + AsyncHttpCallback Callback; + std::string Body; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + CurlWriteCallbackData WriteData; + CurlHeaderCallbackData HeaderData; + curl_slist* HeaderList = nullptr; + + // For PUT/POST with payload: keep the data alive until transfer completes + IoBuffer PayloadBuffer; + CurlReadCallbackData ReadData; + + TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) + { + WriteData.Body = &Body; + HeaderData.Headers = &ResponseHeaders; + } + + ~TransferContext() + { + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + } + + TransferContext(const TransferContext&) = delete; + TransferContext& operator=(const TransferContext&) = delete; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient::Impl + +struct AsyncHttpClient::Impl +{ + Impl(std::string_view BaseUri, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + m_WorkGuard.emplace(m_IoContext.get_executor()); + m_IoThread = std::thread([this]() { + SetCurrentThreadName("async_http"); + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what()); + } + }); + } + + Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + } + + ~Impl() + { + // Clean up curl state on the strand where all curl_multi operations + // are serialized. Use a promise to block until the cleanup handler + // has actually executed - essential for the external io_context case + // where we don't own the run loop. + std::promise<void> Done; + std::future<void> DoneFuture = Done.get_future(); + + asio::post(m_Strand, [this, &Done]() { + m_ShuttingDown = true; + m_Timer.cancel(); + + // Release all tracked sockets (don't close - curl owns the fds). + for (auto& [Fd, Info] : m_Sockets) + { + if (Info->Socket.is_open()) + { + Info->Socket.cancel(); + Info->Socket.release(); + } + } + m_Sockets.clear(); + + for (auto& [Handle, Ctx] : m_Transfers) + { + curl_multi_remove_handle(m_Multi, Handle); + curl_easy_cleanup(Handle); + } + m_Transfers.clear(); + + for (CURL* Handle : m_HandlePool) + { + curl_easy_cleanup(Handle); + } + m_HandlePool.clear(); + + Done.set_value(); + }); + + // For owned io_context: release work guard so run() can return after + // processing the cleanup handler above. + m_WorkGuard.reset(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + else + { + // External io_context: wait for the cleanup handler to complete. + DoneFuture.wait(); + } + + if (m_Multi) + { + curl_multi_cleanup(m_Multi); + } + } + + LoggerRef Log() { return m_Log; } + + void Init() + { + m_Multi = curl_multi_init(); + if (!m_Multi) + { + throw std::runtime_error("curl_multi_init failed"); + } + + SetupMultiCallbacks(); + + if (m_Settings.SessionId == Oid::Zero) + { + m_SessionId = std::string(GetSessionIdString()); + } + else + { + m_SessionId = m_Settings.SessionId.ToString(); + } + } + + // -- Handle pool ----------------------------------------------------- + + CURL* AllocHandle() + { + if (!m_HandlePool.empty()) + { + CURL* Handle = m_HandlePool.back(); + m_HandlePool.pop_back(); + curl_easy_reset(Handle); + return Handle; + } + CURL* Handle = curl_easy_init(); + if (!Handle) + { + throw std::runtime_error("curl_easy_init failed"); + } + return Handle; + } + + void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); } + + // -- Configure a handle with common settings ------------------------- + // Called only from DoAsync* lambdas running on the strand. + + void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) + { + // Build URL + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Unix domain socket + if (!m_Settings.UnixSocketPath.empty()) + { + m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str()); + } + + // Timeouts + if (m_Settings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_Settings.ConnectTimeout.count())); + } + if (m_Settings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count())); + } + + // HTTP/2 + if (m_Settings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // SSL + if (m_Settings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!m_Settings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str()); + } + + // Verbose/debug + if (m_Settings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + } + + // Thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (m_Settings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + } + + // -- Access token ---------------------------------------------------- + + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()(); + if (!NewToken.IsValid()) + { + ZEN_WARN("AsyncHttpClient: failed to refresh access token, retrying once"); + NewToken = m_Settings.AccessTokenProvider.value()(); + } + if (NewToken.IsValid()) + { + m_CachedAccessToken = NewToken; + return m_CachedAccessToken.GetValue(); + } + ZEN_WARN("AsyncHttpClient: access token provider returned invalid token"); + return {}; + } + + // -- Submit a transfer ----------------------------------------------- + + void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer"); + // Setup write/header callbacks + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData); + + m_Transfers[Handle] = std::move(Ctx); + + CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle); + if (Mc != CURLM_OK) + { + auto Stolen = std::move(m_Transfers[Handle]); + m_Transfers.erase(Handle); + ReleaseHandle(Handle); + + HttpClient::Response ErrorResponse; + ErrorResponse.Error = + HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))}; + asio::post(m_IoContext, + [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); + return; + } + } + + // -- Socket-action integration --------------------------------------- + // + // curl_multi drives I/O via two callbacks: + // - SocketCallback: curl tells us which sockets to watch for read/write + // - TimerCallback: curl tells us when to fire a timeout + // + // On each socket event or timeout we call curl_multi_socket_action(), + // then drain completed transfers via curl_multi_info_read(). + + // Per-socket state: wraps the native fd in an ASIO socket for async_wait. + struct SocketInfo + { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + + explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} + }; + + // Static thunks registered with curl_multi ---------------------------- + + static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) + { + ZEN_UNUSED(Easy); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + return 0; + } + + static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr) + { + ZEN_UNUSED(Multi); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlTimer(TimeoutMs); + return 0; + } + + void SetupMultiCallbacks() + { + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this); + } + + // Called by curl when socket watch state changes --------------------- + + void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) + { + if (Action == CURL_POLL_REMOVE) + { + if (Info) + { + // Cancel pending async_wait ops before releasing the fd. + // curl owns the fd, so we must release() rather than close(). + Info->Socket.cancel(); + if (Info->Socket.is_open()) + { + Info->Socket.release(); + } + m_Sockets.erase(Fd); + } + return; + } + + if (!Info) + { + // New socket - wrap the native fd in an ASIO socket. + auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); + Info = It->second.get(); + + asio::error_code Ec; + // Determine protocol from the fd (v4 vs v6). Default to v4. + Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + // Try v6 as fallback + Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) + { + ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + m_Sockets.erase(Fd); + return; + } + + curl_multi_assign(m_Multi, Fd, Info); + } + + Info->WatchFlags = Action; + SetSocketWatch(Fd, Info); + } + + void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + { + // Cancel any pending wait before issuing a new one. + Info->Socket.cancel(); + + if (Info->WatchFlags & CURL_POLL_IN) + { + Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + })); + } + + if (Info->WatchFlags & CURL_POLL_OUT) + { + Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_OUT); + })); + } + } + + void OnSocketReady(curl_socket_t Fd, int CurlAction) + { + ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); + CheckCompleted(); + + // Re-arm the watch if the socket is still tracked. + auto It = m_Sockets.find(Fd); + if (It != m_Sockets.end()) + { + SetSocketWatch(Fd, It->second.get()); + } + } + + // Called by curl when it wants a timeout ------------------------------ + + void OnCurlTimer(long TimeoutMs) + { + m_Timer.cancel(); + + if (TimeoutMs < 0) + { + // curl says "no timeout needed" + return; + } + + if (TimeoutMs == 0) + { + // curl wants immediate action - run it directly on the strand. + asio::post(m_Strand, [this]() { + if (m_ShuttingDown) + { + return; + } + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + }); + return; + } + + m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); + m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + })); + } + + // Drain completed transfers from curl_multi -------------------------- + + void CheckCompleted() + { + int MsgsLeft = 0; + CURLMsg* Msg = nullptr; + while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr) + { + if (Msg->msg != CURLMSG_DONE) + { + continue; + } + + CURL* Handle = Msg->easy_handle; + CURLcode Result = Msg->data.result; + + curl_multi_remove_handle(m_Multi, Handle); + + auto It = m_Transfers.find(Handle); + if (It == m_Transfers.end()) + { + ReleaseHandle(Handle); + continue; + } + + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + + CompleteTransfer(Handle, Result, std::move(Ctx)); + } + } + + void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer"); + // Extract result info + long StatusCode = 0; + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + + ReleaseHandle(Handle); + + // Build response + HttpClient::Response Response; + Response.StatusCode = HttpResponseCode(StatusCode); + Response.UploadedBytes = static_cast<int64_t>(UpBytes); + Response.DownloadedBytes = static_cast<int64_t>(DownBytes); + Response.ElapsedSeconds = Elapsed; + Response.Header = BuildHeaderMap(Ctx->ResponseHeaders); + + if (CurlResult != CURLE_OK) + { + const char* ErrorMsg = curl_easy_strerror(CurlResult); + + if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg); + } + + if (!Ctx->Body.empty()) + { + Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + } + + Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)}; + } + else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty()) + { + // No payload + } + else + { + IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders); + + const HttpResponseCode Code = HttpResponseCode(StatusCode); + if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound) + { + ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri); + } + + Response.ResponsePayload = std::move(PayloadBuffer); + } + + // Dispatch the user callback off the strand so a slow callback + // cannot starve the curl_multi poll loop. + asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable { + try + { + Cb(std::move(Response)); + } + catch (const std::exception& Ex) + { + ZEN_SCOPED_LOG(LogRef); + ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what()); + } + }); + } + + // -- Async verb implementations -------------------------------------- + + void DoAsyncGet(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Get"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Head"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Delete"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPost(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Post"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPostWithPayload(std::string Url, + IoBuffer Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + ContentType, + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = + BuildHeaderList(AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + // Set up read callback for payload data + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutWithPayload(std::string Url, + IoBuffer Payload, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = BuildHeaderList( + AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + + HttpClient::KeyValueMap ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + // -- Members --------------------------------------------------------- + + std::string m_BaseUri; + HttpClientSettings m_Settings; + LoggerRef m_Log; + std::string m_SessionId; + std::string m_UnixSocketPathUtf8; + + // io_context and strand - all curl_multi operations are serialized on the + // strand, making this safe even when the io_context has multiple threads. + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + asio::strand<asio::io_context::executor_type> m_Strand; + std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; + std::thread m_IoThread; + + // curl_multi and socket-action state + CURLM* m_Multi = nullptr; + std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; + asio::steady_timer m_Timer; + bool m_ShuttingDown = false; + + // Access token cache + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient public API + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, Settings)) +{ +} + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings)) +{ +} + +AsyncHttpClient::~AsyncHttpClient() = default; + +// -- Callback-based API -------------------------------------------------- + +void +AsyncHttpClient::AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), Parameters); +} + +// -- Future-based API ---------------------------------------------------- + +std::future<HttpClient::Response> +AsyncHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncGet( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncHead( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncDelete( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + ContentType, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, + Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + Parameters); + return Future; +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index d150b44c6..eee80c269 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpclientcurl.h" +#include "httpclientcurlhelpers.h" #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -29,153 +30,7 @@ static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; ////////////////////////////////////////////////////////////////////////// -static HttpClientErrorCode -MapCurlError(CURLcode Code) -{ - switch (Code) - { - case CURLE_OK: - return HttpClientErrorCode::kOK; - case CURLE_COULDNT_CONNECT: - return HttpClientErrorCode::kConnectionFailure; - case CURLE_COULDNT_RESOLVE_HOST: - return HttpClientErrorCode::kHostResolutionFailure; - case CURLE_COULDNT_RESOLVE_PROXY: - return HttpClientErrorCode::kProxyResolutionFailure; - case CURLE_RECV_ERROR: - return HttpClientErrorCode::kNetworkReceiveError; - case CURLE_SEND_ERROR: - return HttpClientErrorCode::kNetworkSendFailure; - case CURLE_OPERATION_TIMEDOUT: - return HttpClientErrorCode::kOperationTimedOut; - case CURLE_SSL_CONNECT_ERROR: - return HttpClientErrorCode::kSSLConnectError; - case CURLE_SSL_CERTPROBLEM: - return HttpClientErrorCode::kSSLCertificateError; - case CURLE_PEER_FAILED_VERIFICATION: - return HttpClientErrorCode::kSSLCACertError; - case CURLE_SSL_CIPHER: - case CURLE_SSL_ENGINE_NOTFOUND: - case CURLE_SSL_ENGINE_SETFAILED: - return HttpClientErrorCode::kGenericSSLError; - case CURLE_ABORTED_BY_CALLBACK: - return HttpClientErrorCode::kRequestCancelled; - default: - return HttpClientErrorCode::kOtherError; - } -} - -////////////////////////////////////////////////////////////////////////// -// -// Curl callback helpers - -struct WriteCallbackData -{ - std::string* Body = nullptr; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<WriteCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return 0; // Signal abort to curl - } - - Data->Body->append(Ptr, TotalBytes); - return TotalBytes; -} - -struct HeaderCallbackData -{ - std::vector<std::pair<std::string, std::string>>* Headers = nullptr; -}; - -// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. -// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). -static std::optional<std::pair<std::string_view, std::string_view>> -ParseHeaderLine(std::string_view Line) -{ - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) - { - return std::nullopt; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos == std::string_view::npos) - { - return std::nullopt; - } - - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - - return std::pair{Key, Value}; -} - -static size_t -CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<HeaderCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) - { - auto& [Key, Value] = *Header; - Data->Headers->emplace_back(std::string(Key), std::string(Value)); - } - - return TotalBytes; -} - -struct ReadCallbackData -{ - const uint8_t* DataPtr = nullptr; - size_t DataSize = 0; - size_t Offset = 0; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<ReadCallbackData*>(UserData); - size_t MaxRead = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return CURL_READFUNC_ABORT; - } - - size_t Remaining = Data->DataSize - Data->Offset; - size_t ToRead = std::min(MaxRead, Remaining); - - if (ToRead > 0) - { - memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); - Data->Offset += ToRead; - } - - return ToRead; -} +// Curl callback helpers and shared utilities are in httpclientcurlhelpers.h struct StreamReadCallbackData { @@ -233,7 +88,7 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi { ZEN_UNUSED(Handle); LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); - auto Log = [&]() -> LoggerRef { return LogRef; }; + ZEN_SCOPED_LOG(LogRef); std::string_view DataView(Data, Size); @@ -281,120 +136,6 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi ////////////////////////////////////////////////////////////////////////// -static std::pair<std::string, std::string> -HeaderContentType(ZenContentType ContentType) -{ - return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); -} - -static curl_slist* -BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, - std::string_view SessionId, - const std::optional<std::string>& AccessToken, - const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) -{ - curl_slist* Headers = nullptr; - - for (const auto& [Key, Value] : *AdditionalHeader) - { - ExtendableStringBuilder<64> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - if (!SessionId.empty()) - { - ExtendableStringBuilder<64> SessionHeader; - SessionHeader << "UE-Session: " << SessionId; - Headers = curl_slist_append(Headers, SessionHeader.c_str()); - } - - if (AccessToken.has_value()) - { - ExtendableStringBuilder<128> AuthHeader; - AuthHeader << "Authorization: " << AccessToken.value(); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); - } - - for (const auto& [Key, Value] : ExtraHeaders) - { - ExtendableStringBuilder<128> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - return Headers; -} - -static HttpClient::KeyValueMap -BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) -{ - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HeaderMap; -} - -// Scans response headers for Content-Type and applies it to the buffer. -static void -ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) -{ - for (const auto& [Key, Value] : Headers) - { - if (StrCaseCompare(Key, "Content-Type") == 0) - { - Buffer.SetContentType(ParseContentType(Value)); - break; - } - } -} - -static void -AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) -{ - static constexpr char HexDigits[] = "0123456789ABCDEF"; - static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); - - for (char C : Input) - { - if (Unreserved.Contains(C)) - { - Out.Append(C); - } - else - { - uint8_t Byte = static_cast<uint8_t>(C); - char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; - Out.Append(std::string_view(Encoded, 3)); - } - } -} - -static void -BuildUrlWithParameters(StringBuilderBase& Url, - std::string_view BaseUrl, - std::string_view ResourcePath, - const HttpClient::KeyValueMap& Parameters) -{ - Url.Append(BaseUrl); - Url.Append(ResourcePath); - - if (!Parameters->empty()) - { - char Separator = '?'; - for (const auto& [Key, Value] : *Parameters) - { - Url.Append(Separator); - AppendUrlEncoded(Url, Key); - Url.Append('='); - AppendUrlEncoded(Url, Value); - Separator = '&'; - } - } -} - ////////////////////////////////////////////////////////////////////////// CurlHttpClient::CurlHttpClient(std::string_view BaseUri, @@ -440,9 +181,9 @@ CurlHttpClient::CurlResult CurlHttpClient::Session::PerformWithResponseCallbacks() { std::string Body; - WriteCallbackData WriteData{.Body = &Body, + CurlWriteCallbackData WriteData{.Body = &Body, .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; - HeaderCallbackData HdrData{}; + CurlHeaderCallbackData HdrData{}; std::vector<std::pair<std::string, std::string>> ResponseHeaders; HdrData.Headers = &ResponseHeaders; @@ -487,6 +228,13 @@ CurlHttpClient::Session::Perform() curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + char* EffectiveUrl = nullptr; + curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl); + if (EffectiveUrl) + { + Result.Url = EffectiveUrl; + } + return Result; } @@ -553,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'", SessionId, + Result.Url, static_cast<int>(Result.ErrorCode), Result.ErrorMessage); } @@ -699,9 +448,11 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result) { case CURLE_OK: break; + case CURLE_COULDNT_CONNECT: case CURLE_RECV_ERROR: case CURLE_SEND_ERROR: case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: return true; default: return false; @@ -748,10 +499,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult { if (Result.ErrorCode != CURLE_OK) { - ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", SessionId, static_cast<int>(MapCurlError(Result.ErrorCode)), Result.ErrorMessage, + static_cast<int>(Result.ErrorCode), Attempt, m_ConnectionSettings.RetryCount + 1); } @@ -856,6 +608,12 @@ CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& P // Disable signal handling for thread safety curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + if (m_ConnectionSettings.FollowRedirects) + { + curl_easy_setopt(Handle, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(Handle, CURLOPT_MAXREDIRS, static_cast<long>(m_ConnectionSettings.MaxRedirects)); + } + if (m_ConnectionSettings.ForbidReuseConnection) { curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); @@ -998,9 +756,9 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); @@ -1213,7 +971,7 @@ CurlHttpClient::Post(std::string_view Url, std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1266,7 +1024,7 @@ CurlHttpClient::Post(std::string_view Url, std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1367,9 +1125,9 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV return Sess.PerformWithResponseCallbacks(); } - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); @@ -1532,7 +1290,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", Data->TempFolderPath->string(), Ec.message()); @@ -1618,7 +1376,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); if (Ec) { - auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_SCOPED_LOG(Data->Log); ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", Data->TempFolderPath->string(), Ec.message()); diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index bdeb46633..ea9193e65 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -73,6 +73,7 @@ private: int64_t DownloadedBytes = 0; CURLcode ErrorCode = CURLE_OK; std::string ErrorMessage; + std::string Url; }; struct Session diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h new file mode 100644 index 000000000..cb5f5d9a9 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurlhelpers.h @@ -0,0 +1,298 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Shared helpers for curl-based HTTP client implementations (sync and async). +// This is an internal header, not part of the public API. + +#include <zencore/string.h> + +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Error mapping + +inline HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback data structures and callbacks + +struct CurlWriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +inline size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct CurlHeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +inline std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) +{ + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return std::nullopt; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + return std::nullopt; + } + + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +inline size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct CurlReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +inline size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +////////////////////////////////////////////////////////////////////////// +// +// URL and header construction + +inline void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +inline void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) +{ + Url.Append(BaseUrl); + Url.Append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); + Separator = '&'; + } + } +} + +inline std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +inline curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<std::string>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken.has_value()) + { + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken.value(); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + bool HasContentTypeOverride = AdditionalHeader->contains("Content-Type"); + for (const auto& [Key, Value] : ExtraHeaders) + { + if (HasContentTypeOverride && Key == "Content-Type") + { + continue; + } + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +inline HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +inline void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index ace7a3c7f..9d5846f71 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -305,6 +305,7 @@ HttpClient::Response::ToText() const case ZenContentType::kJavaScript: case ZenContentType::kJSON: case ZenContentType::kText: + case ZenContentType::kXML: case ZenContentType::kYAML: return std::string{AsText()}; @@ -520,7 +521,7 @@ MeasureLatency(HttpClient& Client, std::string_view Url) ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. - // Bail out immediately — retrying will just burn the connect timeout each time. + // Bail out immediately - retrying will just burn the connect timeout each time. if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) { break; diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index af653cbb2..deaeca2a8 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -194,7 +194,7 @@ public: "slow", [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { - Sleep(2000); + Sleep(100); Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); }); }, @@ -414,6 +414,17 @@ TEST_CASE("httpclient.post") CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); } + SUBCASE("POST with content type override via additional header") + { + const char* Payload = "test payload"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON, {{"Content-Type", "text/plain"}}); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload"); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + SUBCASE("POST with CbObject payload round-trip") { CbObjectWriter Writer; @@ -750,7 +761,9 @@ TEST_CASE("httpclient.error-handling") { SUBCASE("Connection refused") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); HttpClient::Response Resp = Client.Get("/api/test/hello"); CHECK(!Resp.IsSuccess()); CHECK(Resp.Error.has_value()); @@ -760,7 +773,7 @@ TEST_CASE("httpclient.error-handling") { TestServerFixture Fixture; HttpClientSettings Settings; - Settings.Timeout = std::chrono::milliseconds(500); + Settings.Timeout = std::chrono::milliseconds(50); HttpClient Client = Fixture.MakeClient(Settings); HttpClient::Response Resp = Client.Get("/api/test/slow"); @@ -970,7 +983,9 @@ TEST_CASE("httpclient.measurelatency") SUBCASE("Failed measurement against unreachable port") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); CHECK(!Result.Success); CHECK(!Result.FailureReason.empty()); @@ -1144,7 +1159,7 @@ struct FaultTcpServer ~FaultTcpServer() { // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this - // thread — ASIO I/O objects are not safe for concurrent access and the io_context + // thread - ASIO I/O objects are not safe for concurrent access and the io_context // thread may be touching the acceptor in StartAccept(). m_IoContext.stop(); if (m_Thread.joinable()) @@ -1498,7 +1513,7 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) std::atomic<bool> StallActive{true}; FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { DrainHttpRequest(Socket); - // Stop reading body — TCP window will fill and client send will stall + // Stop reading body - TCP window will fill and client send will stall while (StallActive.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(50)); @@ -1735,21 +1750,21 @@ TEST_CASE("httpclient.uri_decoding") TestServerFixture Fixture; HttpClient Client = Fixture.MakeClient(); - // URI without encoding — should pass through unchanged + // URI without encoding - should pass through unchanged { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt"); } - // Percent-encoded space — server should see decoded path + // Percent-encoded space - server should see decoded path { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt"); } - // Percent-encoded slash (%2F) — should be decoded to / + // Percent-encoded slash (%2F) - should be decoded to / { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt"); REQUIRE(Resp.IsSuccess()); @@ -1763,21 +1778,21 @@ TEST_CASE("httpclient.uri_decoding") CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt"); } - // No capture — echo/uri route returns just RelativeUri + // No capture - echo/uri route returns just RelativeUri { HttpClient::Response Resp = Client.Get("/api/test/echo/uri"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "echo/uri"); } - // Literal percent that is not an escape (%ZZ) — should be kept as-is + // Literal percent that is not an escape (%ZZ) - should be kept as-is { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt"); } - // Query params — raw values are returned as-is from GetQueryParams + // Query params - raw values are returned as-is from GetQueryParams { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test"); REQUIRE(Resp.IsSuccess()); @@ -1788,7 +1803,7 @@ TEST_CASE("httpclient.uri_decoding") { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3"); REQUIRE(Resp.IsSuccess()); - // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly + // GetQueryParams returns raw (still-encoded) values - callers must Decode() explicitly CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3"); } diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index c42841922..26a7298b3 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -50,8 +50,6 @@ namespace zen { namespace httpclientauth { IoBuffer Payload{IoBuffer::Wrap, Body.data(), Body.size()}; - // TODO: ensure this gets the right Content-Type passed along - HttpClient::Response Response = Http.Post("", Payload, {{"Content-Type", "application/x-www-form-urlencoded"}}); if (!Response || Response.StatusCode != HttpResponseCode::OK) @@ -94,7 +92,8 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Unattended, bool Quiet, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { Stopwatch Timer; @@ -117,8 +116,9 @@ namespace zen { namespace httpclientauth { } }); - const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}", OidcExecutablePath, + IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl", CloudHost, AuthTokenPath, Unattended ? "true"sv : "false"sv); @@ -193,7 +193,7 @@ namespace zen { namespace httpclientauth { } else { - ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode); } return HttpClientAccessToken{}; } @@ -202,9 +202,10 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { - HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); if (InitialToken.IsValid()) { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), @@ -212,12 +213,13 @@ namespace zen { namespace httpclientauth { Token = InitialToken, Quiet, Unattended, - Hidden]() mutable { + Hidden, + IsHordeUrl]() mutable { if (!Token.NeedsRefresh()) { return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); }; } return {}; diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 38021be16..03117ee6c 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -266,10 +266,10 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return false; } - const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); - const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); + const auto Start = ParseInt<uint64_t>(Token.substr(0, Delim)); + const auto End = ParseInt<uint64_t>(Token.substr(Delim + 1)); - if (Start.has_value() && End.has_value() && End.value() > Start.value()) + if (Start.has_value() && End.has_value() && End.value() >= Start.value()) { Ranges.push_back({.Start = Start.value(), .End = End.value()}); } @@ -286,6 +286,45 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return Count != Ranges.size(); } +MultipartByteRangesResult +BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges) +{ + Oid::String_t BoundaryStr; + Oid::NewOid().ToString(BoundaryStr); + std::string_view Boundary(BoundaryStr, Oid::StringLength); + + const uint64_t TotalSize = Data.GetSize(); + + std::vector<IoBuffer> Parts; + Parts.reserve(Ranges.size() * 2 + 1); + + for (const HttpRange& Range : Ranges) + { + uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + return {}; + } + + uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + + std::string PartHeader = fmt::format("\r\n--{}\r\nContent-Type: application/octet-stream\r\nContent-Range: bytes {}-{}/{}\r\n\r\n", + Boundary, + Range.Start, + RangeEnd, + TotalSize); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(PartHeader.data(), PartHeader.size())); + + IoBuffer RangeData(Data, Range.Start, RangeSize); + Parts.push_back(RangeData); + } + + std::string ClosingBoundary = fmt::format("\r\n--{}--", Boundary); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(ClosingBoundary.data(), ClosingBoundary.size())); + + return {.Parts = std::move(Parts), .ContentType = fmt::format("multipart/byteranges; boundary={}", Boundary)}; +} + ////////////////////////////////////////////////////////////////////////// const std::string_view @@ -564,6 +603,56 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType } void +HttpServerRequest::WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges) +{ + if (Ranges.empty()) + { + WriteResponse(HttpResponseCode::OK, ContentType, IoBuffer(Data)); + return; + } + + if (Ranges.size() == 1) + { + const HttpRange& Range = Ranges[0]; + const uint64_t TotalSize = Data.GetSize(); + // ~uint64_t(0) is the sentinel meaning "end of file" (suffix range). + const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + m_ContentRangeHeader = fmt::format("bytes */{}", TotalSize); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + + const uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + IoBuffer RangeBuf(Data, Range.Start, RangeSize); + + m_ContentRangeHeader = fmt::format("bytes {}-{}/{}", Range.Start, RangeEnd, TotalSize); + WriteResponse(HttpResponseCode::PartialContent, ContentType, std::move(RangeBuf)); + return; + } + + // Multi-range + MultipartByteRangesResult MultipartResult = BuildMultipartByteRanges(Data, Ranges); + if (MultipartResult.Parts.empty()) + { + m_ContentRangeHeader = fmt::format("bytes */{}", Data.GetSize()); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + WriteResponse(HttpResponseCode::PartialContent, std::move(MultipartResult.ContentType), std::span<IoBuffer>(MultipartResult.Parts)); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(ParseContentType(CustomContentType) == HttpContentType::kUnknownContentType); + m_ContentTypeOverride = CustomContentType; + WriteResponse(ResponseCode, HttpContentType::kBinary, Blobs); +} + +void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) { std::span<const SharedBuffer> Segments = Payload.GetSegments(); @@ -831,7 +920,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) // Strip the separator slash left over after the service prefix is removed. // When a service has BaseUri "/foo", the prefix length is set to len("/foo") = 4. - // Stripping 4 chars from "/foo/bar" yields "/bar" — the path separator becomes + // Stripping 4 chars from "/foo/bar" yields "/bar" - the path separator becomes // the first character of the relative URI. Remove it so patterns like "bar" or // "{id}" match without needing to account for the leading slash. if (!Uri.empty() && Uri.front() == '/') @@ -1532,7 +1621,7 @@ TEST_CASE("http.common") }, HttpVerb::kGet); - // Single-segment literal with leading slash — simulates real server RelativeUri + // Single-segment literal with leading slash - simulates real server RelativeUri { Reset(); TestHttpServerRequest req{Service, "/activity_counters"sv}; @@ -1552,7 +1641,7 @@ TEST_CASE("http.common") CHECK_EQ(Captures[0], "hello"sv); } - // Two-segment route with leading slash — first literal segment + // Two-segment route with leading slash - first literal segment { Reset(); TestHttpServerRequest req{Service, "/prefix/world"sv}; diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h new file mode 100644 index 000000000..cb41626b9 --- /dev/null +++ b/src/zenhttp/include/zenhttp/asynchttpclient.h @@ -0,0 +1,123 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> + +#include <functional> +#include <future> +#include <memory> + +namespace asio { +class io_context; +} + +namespace zen { + +/// Completion callback for async HTTP operations. +using AsyncHttpCallback = std::function<void(HttpClient::Response)>; + +/** Asynchronous HTTP client backed by curl_multi and ASIO. + * + * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process + * transfers without blocking the caller. All curl_multi operations are + * serialized on an internal strand; callers may issue requests from any + * thread, and the io_context may have multiple threads. + * + * Two construction modes: + * - Owned io_context: creates an internal thread (self-contained). + * - External io_context: caller runs the event loop. + * + * Completion callbacks are dispatched on the io_context (not the internal + * strand), so a slow callback will not block the curl poll loop. Future- + * based wrappers (Get, Post, ...) return a std::future<Response> for + * callers that prefer blocking on a result. + */ +class AsyncHttpClient +{ +public: + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + /// Construct with an internally-owned io_context and thread. + explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {}); + + /// Construct with an externally-managed io_context. The io_context must + /// outlive this client and must be running (via run()) on at least one thread. + AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {}); + + ~AsyncHttpClient(); + + AsyncHttpClient(const AsyncHttpClient&) = delete; + AsyncHttpClient& operator=(const AsyncHttpClient&) = delete; + + // -- Callback-based API ---------------------------------------------- + + void AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}); + + void AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {}); + + // -- Future-based API ------------------------------------------------ + + [[nodiscard]] std::future<Response> Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, + const IoBuffer& Payload, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, const KeyValueMap& Parameters = {}); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void asynchttpclient_test_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index e199b700f..8da94524e 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -120,6 +120,13 @@ struct HttpClientSettings /// the system default CA store. std::string CaBundlePath; + /// Automatically follow HTTP 3xx redirects. When true, curl handles + /// redirects internally (up to MaxRedirects hops). Default: false. + bool FollowRedirects = false; + + /// Maximum number of redirects to follow when FollowRedirects is true. + int MaxRedirects = 5; + /// HTTP status codes that are expected and should not be logged as warnings. /// 404 is always treated as expected regardless of this list. std::vector<HttpResponseCode> ExpectedErrorCodes; diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index ce646ebd7..9220a50b6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -33,7 +33,8 @@ namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden); + bool Hidden, + bool IsHordeUrl = false); } // namespace httpclientauth } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index f9a99f3cc..1d921600d 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -19,8 +19,8 @@ class StringBuilderBase; struct HttpRange { - uint32_t Start = ~uint32_t(0); - uint32_t End = ~uint32_t(0); + uint64_t Start = ~uint64_t(0); + uint64_t End = ~uint64_t(0); }; using HttpRanges = std::vector<HttpRange>; @@ -30,6 +30,16 @@ extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeSt std::string_view ReasonStringForHttpResultCode(int HttpCode); bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); +struct MultipartByteRangesResult +{ + std::vector<IoBuffer> Parts; + std::string ContentType; +}; + +// Build a multipart/byteranges response body from the given data and ranges. +// Generates a unique boundary per call. Returns empty Parts if any range is out of bounds. +MultipartByteRangesResult BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges); + enum class HttpVerb : uint8_t { kGet = 1 << 0, diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 76f219f04..955b8ed15 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -122,11 +122,13 @@ public: virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + void WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs); void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + void WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges); virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; @@ -152,6 +154,8 @@ protected: std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); mutable Oid m_SessionId = Oid::Zero; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; inline void SetIsHandled() { m_Flags |= kIsHandled; } @@ -298,12 +302,12 @@ public: std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } - /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + /** Track active WebSocket connections - called by server implementations on upgrade/close. */ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } - /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + /** Track WebSocket frame and byte counters - called by WS connection implementations per frame. */ void OnWebSocketFrameReceived(uint64_t Bytes) { m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); @@ -325,7 +329,7 @@ private: int m_EffectiveHttpsPort = 0; std::string m_ExternalHost; metrics::Meter m_RequestMeter; - metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error — sufficient for client counting + metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error - sufficient for client counting metrics::HyperLogLog<12> m_ClientSessions; std::string m_DefaultRedirect; std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; @@ -518,7 +522,8 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -void http_forcelink(); // internal -void websocket_forcelink(); // internal +void http_forcelink(); // internal +void httpparser_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index bce771c75..51ab2e06e 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -23,11 +23,11 @@ namespace zen { class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - /// Construct without an io_context — optionally uses a dedicated push thread + /// Construct without an io_context - optionally uses a dedicated push thread /// for WebSocket stats broadcasting. explicit HttpStatsService(bool EnableWebSockets = false); - /// Construct with an external io_context — uses an asio timer instead + /// Construct with an external io_context - uses an asio timer instead /// of a dedicated thread for WebSocket stats broadcasting. /// The caller must ensure the io_context outlives this service and that /// its run loop is active. @@ -43,7 +43,7 @@ public: virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 9c3b909a2..fd2f79171 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -26,7 +26,7 @@ namespace zen { * Callback interface for WebSocket client events * * Separate from the server-side IWebSocketHandler because the caller - * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + * already owns the HttpWsClient - no Ref<WebSocketConnection> needed. */ class IWsClientHandler { @@ -85,9 +85,9 @@ private: /// it is treated as a plain host:port and gets the ws:// prefix. /// /// Examples: -/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws" -/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo" -/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar" +/// HttpToWsUrl("http://host:8080", "/orch/ws") -> "ws://host:8080/orch/ws" +/// HttpToWsUrl("https://host", "/foo") -> "wss://host/foo" +/// HttpToWsUrl("host:8080", "/bar") -> "ws://host:8080/bar" std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 710579faa..2d25515d3 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -59,7 +59,7 @@ class IWebSocketHandler public: virtual ~IWebSocketHandler() = default; - virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0; virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; }; diff --git a/src/zenhttp/include/zenhttp/zipfs.h b/src/zenhttp/include/zenhttp/zipfs.h new file mode 100644 index 000000000..c6acf7334 --- /dev/null +++ b/src/zenhttp/include/zenhttp/zipfs.h @@ -0,0 +1,35 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> +#include <zencore/thread.h> + +#include <unordered_map> + +namespace zen { + +class ZipFs +{ +public: + explicit ZipFs(IoBuffer&& Buffer); + + IoBuffer GetFile(const std::string_view& FileName) const; + +private: + struct FileItem + { + MemoryView View; // Initially points to LFH (size=0); resolved to file data on first access + uint32_t CompressedSize = 0; + uint32_t UncompressedSize = 0; + uint16_t CompressionMethod = 0; + IoBuffer DecompressedData; // Owns decompressed buffer for deflate entries + }; + + using FileMap = std::unordered_map<std::string_view, FileItem>; + mutable RwLock m_FilesLock; + FileMap mutable m_Files; + IoBuffer m_Buffer; +}; + +} // namespace zen diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index 7e6207e56..5ad5ebcc7 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) // void -HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); ZEN_INFO("Stats WebSocket client connected"); diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 7972777b8..b624c3a29 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -625,6 +625,8 @@ public: void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } /** * Initialize the response for sending a payload made up of multiple blobs @@ -768,10 +770,18 @@ public: { ZEN_MEMSCOPE(GetHttpasioTag()); + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -898,7 +908,9 @@ private: bool m_AllowZeroCopyFileSend = true; State m_State = State::kUninitialized; HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::buffer(ResponseStr->data(), ResponseStr->size()), asio::bind_executor( m_Strand, - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()]( + const asio::error_code& Ec, + std::size_t) { if (Ec) { ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); @@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + std::string_view FullUrl = Conn->m_RequestData.Url(); + std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); })); @@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() return; } } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } if (!m_RequestData.IsKeepAlive()) @@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -2590,7 +2618,7 @@ HttpAsioServer::OnRun(bool IsInteractive) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractive) { @@ -2600,7 +2628,7 @@ HttpAsioServer::OnRun(bool IsInteractive) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 584e06cbf..196c0c142 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -88,7 +88,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractiveSession) { @@ -98,7 +98,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index 9bb7ef3bc..d698bcb9d 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -63,7 +63,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) } ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #else if (IsInteractiveSession) { @@ -73,7 +73,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) do { ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); #endif } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 918b55dc6..8b07c7905 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -8,6 +8,13 @@ #include <limits> +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <cstring> +# include <string> +# include <string_view> +#endif + namespace zen { using namespace std::literals; @@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W // HttpRequestParser // -http_parser_settings HttpRequestParser::s_ParserSettings{ - .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; +// clang-format off +llhttp_settings_t HttpRequestParser::s_ParserSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); }; + S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); +// clang-format on HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) { - http_parser_init(&m_Parser, HTTP_REQUEST); + llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings); m_Parser.data = this; ResetState(); @@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser() size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize); + if (Err == HPE_OK) { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; + return DataSize; } - return ConsumedBytes; + if (Err == HPE_PAUSED_UPGRADE) + { + return DataSize; + } + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser)); + return ~0ull; } int @@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_UrlRange.Length == 0) @@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_HeaderEntries.empty()) @@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); @@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete() } } - m_KeepAlive = !!http_should_keep_alive(&m_Parser); + m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser); - switch (m_Parser.method) + switch (llhttp_get_method(&m_Parser)) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; @@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete() break; default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser)))); break; } @@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; + return -1; } memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - return 0; } @@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete() catch (const AssertException& AssertEx) { ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); - return 1; + return -1; } catch (const std::system_error& SystemError) { @@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete() ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); - return 1; + return -1; } catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); - return 1; + return -1; } catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); - return 1; + return -1; } } @@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; } +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace { + + struct MockCallbacks : HttpRequestParserCallbacks + { + int HandleRequestCount = 0; + int TerminateCount = 0; + + HttpRequestParser* Parser = nullptr; + + HttpVerb LastVerb{}; + std::string LastUrl; + std::string LastQueryString; + std::string LastBody; + bool LastKeepAlive = false; + bool LastIsWebSocketUpgrade = false; + std::string LastSecWebSocketKey; + std::string LastUpgradeHeader; + HttpContentType LastContentType{}; + + void HandleRequest() override + { + ++HandleRequestCount; + if (Parser) + { + LastVerb = Parser->RequestVerb(); + LastUrl = std::string(Parser->Url()); + LastQueryString = std::string(Parser->QueryString()); + LastKeepAlive = Parser->IsKeepAlive(); + LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade(); + LastSecWebSocketKey = std::string(Parser->SecWebSocketKey()); + LastUpgradeHeader = std::string(Parser->UpgradeHeader()); + LastContentType = Parser->ContentType(); + + IoBuffer Body = Parser->Body(); + if (Body.Size() > 0) + { + LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size()); + } + else + { + LastBody.clear(); + } + } + } + + void TerminateConnection() override { ++TerminateCount; } + }; + +} // anonymous namespace + +TEST_SUITE_BEGIN("http.httpparser"); + +TEST_CASE("httpparser.basic_get") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kGet); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK(Mock.LastKeepAlive); +} + +TEST_CASE("httpparser.post_with_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 13\r\n" + "Content-Type: application/json\r\n" + "\r\n" + "{\"key\":\"val\"}"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kPost); + CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}"); + CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON); +} + +TEST_CASE("httpparser.pipelined_requests") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n" + "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 2); + CHECK_EQ(Mock.LastUrl, "/second"); +} + +TEST_CASE("httpparser.partial_header") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc"; + std::string Chunk2 = "alhost\r\n\r\n"; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, Chunk2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); +} + +TEST_CASE("httpparser.partial_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Headers = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 10\r\n" + "\r\n"; + std::string BodyPart1 = "hello"; + std::string BodyPart2 = "world"; + + std::string Chunk1 = Headers + BodyPart1; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, BodyPart2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "helloworld"); +} + +TEST_CASE("httpparser.invalid_request") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Garbage = "NOT_HTTP garbage data\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 0); +} + +TEST_CASE("httpparser.body_overflow") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes, + // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY" + // as a new HTTP request which fails. + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 2\r\n" + "\r\n" + "TOO_LONG_BODY"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "TO"); +} + +TEST_CASE("httpparser.websocket_upgrade") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); + CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(Mock.LastUpgradeHeader, "websocket"); +} + +TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string HttpPart = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + // Append fake WebSocket frame bytes after the HTTP message + std::string Request = HttpPart; + Request.push_back('\x81'); + Request.push_back('\x05'); + Request.append("hello"); + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_NE(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); +} + +TEST_CASE("httpparser.keep_alive_detection") +{ + SUBCASE("HTTP/1.1 default keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK(Mock.LastKeepAlive); + } + + SUBCASE("Connection: close disables keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK_FALSE(Mock.LastKeepAlive); + } +} + +TEST_CASE("httpparser.all_verbs") +{ + struct VerbTest + { + const char* Method; + HttpVerb Expected; + }; + + VerbTest Tests[] = { + {"GET", HttpVerb::kGet}, + {"POST", HttpVerb::kPost}, + {"PUT", HttpVerb::kPut}, + {"DELETE", HttpVerb::kDelete}, + {"HEAD", HttpVerb::kHead}, + {"COPY", HttpVerb::kCopy}, + {"OPTIONS", HttpVerb::kOptions}, + }; + + for (const VerbTest& Test : Tests) + { + CAPTURE(Test.Method); + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, Test.Expected); + } +} + +TEST_CASE("httpparser.query_string") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK_EQ(Mock.LastQueryString, "key=val&other=123"); +} + +TEST_SUITE_END(); + +void +httpparser_forcelink() +{ +} + +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 23ad9d8fb..4ff216248 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -8,7 +8,7 @@ #include <EASTL/fixed_vector.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <http_parser.h> +#include <llhttp.h> ZEN_THIRD_PARTY_INCLUDES_END #include <atomic> @@ -100,7 +100,7 @@ private: Oid m_SessionId{}; IoBuffer m_BodyBuffer; uint64_t m_BodyPosition = 0; - http_parser m_Parser; + llhttp_t m_Parser; eastl::fixed_vector<char, 512> m_HeaderData; std::string m_NormalizedUrl; @@ -114,8 +114,8 @@ private: int OnBody(const char* Data, size_t Bytes); int OnMessageComplete(); - static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } - static http_parser_settings s_ParserSettings; + static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } + static llhttp_settings_t s_ParserSettings; }; } // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 31b0315d4..ad7ed259a 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -185,13 +185,17 @@ public: const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } void SuppressPayload() { m_ResponseBuffers.resize(1); } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } std::string_view GetHeaders(); private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; uint64_t m_ContentLength = 0; std::vector<IoBuffer> m_ResponseBuffers; ExtendableStringBuilder<160> m_Headers; @@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders() if (m_Headers.Size() == 0) { + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(ContentType)); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -831,6 +855,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (plugin HTTP). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { if (IsInteractive && _kbhit() != 0) @@ -844,18 +869,19 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested && !IsApplicationExitRequested()); # else if (IsInteractive) { ZEN_CONSOLE("Zen Server running (plugin HTTP). Ctrl-C to quit"); } + bool ShutdownRequested = false; do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested && !IsApplicationExitRequested()); # endif } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 2cad97725..c1b426bea 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -464,6 +464,8 @@ public: inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -473,6 +475,8 @@ private: uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; std::string m_LocationHeader; @@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + std::string_view ContentTypeString = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); @@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); } + // Content-Range header (for 206 Partial Content single-range responses) + + if (!m_ContentRangeHeader.empty()) + { + PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange]; + ContentRangeHeader->pRawValue = m_ContentRangeHeader.data(); + ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) { // Port may be owned by another process's wildcard registration (access denied) - // or actively in use (sharing violation) — retry on a different port + // or actively in use (sharing violation) - retry on a different port ShouldRetryNextPort = true; } else @@ -1713,7 +1727,7 @@ HttpSysServer::OnRun(bool IsInteractive) ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!ShutdownRequested); + } while (!ShutdownRequested && !IsApplicationExitRequested()); } void @@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + if (!m_ContentTypeOverride.empty()) + { + Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT &Transaction().Server())); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + ExtendableStringBuilder<128> UrlUtf8; + WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath, + gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))}, + UrlUtf8); + int PrefixLen = Service->UriPrefixLength(); + std::string_view RelativeUri{UrlUtf8.ToView()}; + RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); return nullptr; @@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - // WebSocket upgrade failed — return nullptr since ServerRequest() + // WebSocket upgrade failed - return nullptr since ServerRequest() // was never populated (no InvokeRequestHandler call) return nullptr; } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } } diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 5ae48f5b3..078c21ea1 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index af320172d..8520e9f60 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown() return; } - // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } @@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: @@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); - // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 59c46a418..a58037fec 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -5,6 +5,7 @@ # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zencore/timer.h> # include <zenhttp/httpserver.h> # include <zenhttp/httpwsclient.h> @@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec") std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); - // Server frames are unmasked — TryParseFrame should handle them + // Server frames are unmasked - TryParseFrame should handle them WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); CHECK(Result.IsValid); @@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec") { std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); - // Pass only 1 byte — not enough for a frame header + // Pass only 1 byte - not enough for a frame header WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); CHECK_FALSE(Result.IsValid); CHECK_EQ(Result.BytesConsumed, 0u); @@ -335,8 +336,9 @@ namespace { } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_OpenCount.fetch_add(1); m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); @@ -463,7 +465,7 @@ namespace { if (!Done.load()) { - // Timeout — cancel the read + // Timeout - cancel the read asio::error_code Ec; Sock.cancel(Ec); } @@ -476,6 +478,23 @@ namespace { return Result; } + static void WaitForServerListening(int Port) + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + asio::io_context IoCtx; + asio::ip::tcp::socket Probe(IoCtx); + asio::error_code Ec; + Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec); + if (!Ec) + { + return; + } + Sleep(10); + } + } + } // anonymous namespace TEST_CASE("websocket.integration") @@ -501,8 +520,8 @@ TEST_CASE("websocket.integration") Server->Close(); }); - // Give server a moment to start accepting - Sleep(100); + // Wait for server to start accepting + WaitForServerListening(Port); SUBCASE("handshake succeeds with 101") { @@ -692,7 +711,7 @@ TEST_CASE("websocket.integration") std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); - // Should NOT get 101 — should fall through to normal request handling + // Should NOT get 101 - should fall through to normal request handling CHECK(Response.find("101") == std::string::npos); Sock.close(); @@ -813,7 +832,7 @@ TEST_CASE("websocket.client") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close") { @@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close over unix socket") { diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 7b050ae35..b2b813036 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -9,7 +9,7 @@ target('zenhttp') add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio") - add_packages("http_parser", "json11", "libcurl") + add_packages("llhttp", "json11", "libcurl", "zlib") add_options("httpsys") if is_plat("linux", "macosx") then diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 3ac8eea8d..1317f0159 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -4,6 +4,7 @@ #if ZEN_WITH_TESTS +# include <zenhttp/asynchttpclient.h> # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> # include <zenhttp/packageformat.h> @@ -11,15 +12,20 @@ namespace zen { +void zipfs_test_forcelink(); + void zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpparser_forcelink(); httpclient_test_forcelink(); + asynchttpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); websocket_forcelink(); + zipfs_test_forcelink(); } } // namespace zen diff --git a/src/zenhttp/zipfs.cpp b/src/zenhttp/zipfs.cpp new file mode 100644 index 000000000..c0ffa2052 --- /dev/null +++ b/src/zenhttp/zipfs.cpp @@ -0,0 +1,228 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/zipfs.h" + +#include <zencore/logging.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <zlib.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +namespace { + +#if ZEN_COMPILER_MSC +# pragma warning(push) +# pragma warning(disable : 4200) +#endif + + using ZipInt16 = uint16_t; + + struct ZipInt32 + { + operator uint32_t() const { return *(uint32_t*)Parts; } + uint16_t Parts[2]; + }; + + struct EocdRecord + { + enum : uint32_t + { + Magic = 0x0605'4b50, + }; + ZipInt32 Signature; + ZipInt16 ThisDiskIndex; + ZipInt16 CdStartDiskIndex; + ZipInt16 CdRecordThisDiskCount; + ZipInt16 CdRecordCount; + ZipInt32 CdSize; + ZipInt32 CdOffset; + ZipInt16 CommentSize; + char Comment[]; + }; + + struct CentralDirectoryRecord + { + enum : uint32_t + { + Magic = 0x0201'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionMadeBy; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + ZipInt16 CommentLength; + ZipInt16 DiskIndex; + ZipInt16 InternalFileAttr; + ZipInt32 ExternalFileAttr; + ZipInt32 Offset; + char FileName[]; + }; + + struct LocalFileHeader + { + enum : uint32_t + { + Magic = 0x0403'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + char FileName[]; + }; + +#if ZEN_COMPILER_MSC +# pragma warning(pop) +#endif + +} // namespace + +////////////////////////////////////////////////////////////////////////// +ZipFs::ZipFs(IoBuffer&& Buffer) +{ + MemoryView View = Buffer.GetView(); + + uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize(); + if (View.GetSize() < sizeof(EocdRecord)) + { + return; + } + + const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord)); + + // It is more correct to search backwards for EocdRecord::Magic as the + // comment can be of a variable length. But here we're not going to support + // zip files with comments. + if (EocdCursor->Signature != EocdRecord::Magic) + { + return; + } + + // Zip64 isn't supported either + if (EocdCursor->ThisDiskIndex == 0xffff) + { + return; + } + + Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize); + + const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset); + for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i) + { + const CentralDirectoryRecord& Cd = *CdCursor; + + bool Acceptable = true; + Acceptable &= (Cd.OriginalSize > 0); // has some content + Acceptable &= (Cd.CompressionMethod == 0 || Cd.CompressionMethod == 8); // stored or deflate + if (Acceptable) + { + const uint8_t* Lfh = Cursor + Cd.Offset; + if (uintptr_t(Lfh - Cursor) < View.GetSize()) + { + std::string_view FileName(Cd.FileName, Cd.FileNameLength); + FileItem Item; + Item.View = MemoryView{Lfh, size_t(0)}; + Item.CompressionMethod = Cd.CompressionMethod; + Item.CompressedSize = Cd.CompressedSize; + Item.UncompressedSize = Cd.OriginalSize; + m_Files.insert(std::make_pair(FileName, std::move(Item))); + } + } + + uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength; + CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes); + } + + m_Buffer = std::move(Buffer); +} + +////////////////////////////////////////////////////////////////////////// +IoBuffer +ZipFs::GetFile(const std::string_view& FileName) const +{ + { + RwLock::SharedLockScope _(m_FilesLock); + + FileMap::const_iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; + } + + const FileItem& Item = Iter->second; + if (Item.View.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.View.GetData(), Item.View.GetSize()); + } + } + + RwLock::ExclusiveLockScope _(m_FilesLock); + + FileItem& Item = m_Files.find(FileName)->second; + if (Item.View.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.View.GetData(), Item.View.GetSize()); + } + + const auto* Lfh = (LocalFileHeader*)(Item.View.GetData()); + const uint8_t* FileData = (const uint8_t*)(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength); + + if (Item.CompressionMethod == 0) + { + // Stored - point directly into the buffer + Item.View = MemoryView(FileData, Item.UncompressedSize); + } + else + { + // Deflate - decompress using zlib + Item.DecompressedData = IoBuffer(Item.UncompressedSize); + + z_stream Stream = {}; + Stream.next_in = const_cast<Bytef*>(FileData); + Stream.avail_in = Item.CompressedSize; + Stream.next_out = (Bytef*)Item.DecompressedData.GetMutableView().GetData(); + Stream.avail_out = Item.UncompressedSize; + + // Use raw inflate (-MAX_WBITS) since zip stores raw deflate streams + if (inflateInit2(&Stream, -MAX_WBITS) != Z_OK) + { + ZEN_WARN("failed to initialize inflate for '{}'", FileName); + return {}; + } + + int Result = inflate(&Stream, Z_FINISH); + inflateEnd(&Stream); + + if (Result != Z_STREAM_END) + { + ZEN_WARN("failed to decompress '{}' (zlib error {})", FileName, Result); + return {}; + } + + Item.View = Item.DecompressedData.GetView(); + } + + return IoBuffer(IoBuffer::Wrap, Item.View.GetData(), Item.View.GetSize()); +} + +} // namespace zen diff --git a/src/zenhttp/zipfs_test.cpp b/src/zenhttp/zipfs_test.cpp new file mode 100644 index 000000000..b3a45c408 --- /dev/null +++ b/src/zenhttp/zipfs_test.cpp @@ -0,0 +1,221 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/zipfs.h" + +#include <zencore/iobuffer.h> + +#if ZEN_WITH_TESTS + +ZEN_THIRD_PARTY_INCLUDES_START +# include <doctest/doctest.h> +# include <zlib.h> +ZEN_THIRD_PARTY_INCLUDES_END + +# include <cstring> +# include <vector> + +namespace zen { +void +zipfs_test_forcelink() +{ +} +} // namespace zen + +TEST_SUITE_BEGIN("http.zipfs"); + +namespace { + +// Helpers to build a minimal zip file in memory +struct ZipBuilder +{ + std::vector<uint8_t> Data; + + struct Entry + { + std::string Name; + uint32_t LocalHeaderOffset; + uint16_t CompressionMethod; + uint32_t CompressedSize; + uint32_t UncompressedSize; + }; + + std::vector<Entry> Entries; + + void Append(const void* Src, size_t Size) + { + const uint8_t* Bytes = (const uint8_t*)Src; + Data.insert(Data.end(), Bytes, Bytes + Size); + } + + void AppendU16(uint16_t V) { Append(&V, 2); } + void AppendU32(uint32_t V) { Append(&V, 4); } + + void AddFile(const std::string& Name, const void* Content, size_t ContentSize, bool Deflate) + { + std::vector<uint8_t> FileData; + uint16_t Method = 0; + + if (Deflate) + { + // Compress with raw deflate (no zlib/gzip header) + uLongf BoundSize = compressBound((uLong)ContentSize); + std::vector<uint8_t> TempBuf(BoundSize); + + z_stream Stream = {}; + Stream.next_in = (Bytef*)Content; + Stream.avail_in = (uInt)ContentSize; + Stream.next_out = TempBuf.data(); + Stream.avail_out = (uInt)TempBuf.size(); + + deflateInit2(&Stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, -MAX_WBITS, 8, Z_DEFAULT_STRATEGY); + deflate(&Stream, Z_FINISH); + deflateEnd(&Stream); + + TempBuf.resize(Stream.total_out); + FileData = std::move(TempBuf); + Method = 8; + } + else + { + FileData.assign((const uint8_t*)Content, (const uint8_t*)Content + ContentSize); + } + + Entry E; + E.Name = Name; + E.LocalHeaderOffset = (uint32_t)Data.size(); + E.CompressionMethod = Method; + E.CompressedSize = (uint32_t)FileData.size(); + E.UncompressedSize = (uint32_t)ContentSize; + Entries.push_back(E); + + // Local file header + AppendU32(0x04034b50); // signature + AppendU16(20); // version needed + AppendU16(0); // flags + AppendU16(Method); // compression method + AppendU16(0); // last mod time + AppendU16(0); // last mod date + AppendU32(0); // crc32 (not validated by ZipFs) + AppendU32(E.CompressedSize); // compressed size + AppendU32(E.UncompressedSize); // uncompressed size + AppendU16((uint16_t)Name.size()); // file name length + AppendU16(0); // extra field length + Append(Name.data(), Name.size()); // file name + Append(FileData.data(), FileData.size()); + } + + zen::IoBuffer Build() + { + uint32_t CdOffset = (uint32_t)Data.size(); + + for (const Entry& E : Entries) + { + // Central directory record + AppendU32(0x02014b50); // signature + AppendU16(20); // version made by + AppendU16(20); // version needed + AppendU16(0); // flags + AppendU16(E.CompressionMethod); // compression method + AppendU16(0); // last mod time + AppendU16(0); // last mod date + AppendU32(0); // crc32 + AppendU32(E.CompressedSize); // compressed size + AppendU32(E.UncompressedSize); // uncompressed size + AppendU16((uint16_t)E.Name.size()); // file name length + AppendU16(0); // extra field length + AppendU16(0); // comment length + AppendU16(0); // disk index + AppendU16(0); // internal file attr + AppendU32(0); // external file attr + AppendU32(E.LocalHeaderOffset); // offset + Append(E.Name.data(), E.Name.size()); + } + + uint32_t CdSize = (uint32_t)Data.size() - CdOffset; + + // End of central directory record + AppendU32(0x06054b50); // signature + AppendU16(0); // this disk + AppendU16(0); // cd start disk + AppendU16((uint16_t)Entries.size()); // cd records this disk + AppendU16((uint16_t)Entries.size()); // cd records total + AppendU32(CdSize); // cd size + AppendU32(CdOffset); // cd offset + AppendU16(0); // comment length + + zen::IoBuffer Buffer(Data.size()); + std::memcpy(Buffer.GetMutableView().GetData(), Data.data(), Data.size()); + return Buffer; + } +}; + +} // namespace + +TEST_CASE("zipfs.stored") +{ + const char* Content = "Hello, World!"; + + ZipBuilder Zip; + Zip.AddFile("test.txt", Content, std::strlen(Content), false); + + zen::ZipFs Fs(Zip.Build()); + + zen::IoBuffer Result = Fs.GetFile("test.txt"); + REQUIRE(Result); + CHECK(Result.GetView().GetSize() == std::strlen(Content)); + CHECK(std::memcmp(Result.GetView().GetData(), Content, std::strlen(Content)) == 0); +} + +TEST_CASE("zipfs.deflate") +{ + const char* Content = "This is some content that will be deflate compressed in the zip file."; + + ZipBuilder Zip; + Zip.AddFile("compressed.txt", Content, std::strlen(Content), true); + + zen::ZipFs Fs(Zip.Build()); + + zen::IoBuffer Result = Fs.GetFile("compressed.txt"); + REQUIRE(Result); + CHECK(Result.GetView().GetSize() == std::strlen(Content)); + CHECK(std::memcmp(Result.GetView().GetData(), Content, std::strlen(Content)) == 0); +} + +TEST_CASE("zipfs.mixed") +{ + const char* StoredContent = "stored content"; + const char* DeflateContent = "deflate content that is compressed"; + + ZipBuilder Zip; + Zip.AddFile("stored.txt", StoredContent, std::strlen(StoredContent), false); + Zip.AddFile("deflated.txt", DeflateContent, std::strlen(DeflateContent), true); + + zen::ZipFs Fs(Zip.Build()); + + zen::IoBuffer Stored = Fs.GetFile("stored.txt"); + REQUIRE(Stored); + CHECK(Stored.GetView().GetSize() == std::strlen(StoredContent)); + CHECK(std::memcmp(Stored.GetView().GetData(), StoredContent, std::strlen(StoredContent)) == 0); + + zen::IoBuffer Deflated = Fs.GetFile("deflated.txt"); + REQUIRE(Deflated); + CHECK(Deflated.GetView().GetSize() == std::strlen(DeflateContent)); + CHECK(std::memcmp(Deflated.GetView().GetData(), DeflateContent, std::strlen(DeflateContent)) == 0); +} + +TEST_CASE("zipfs.not_found") +{ + const char* Content = "data"; + + ZipBuilder Zip; + Zip.AddFile("exists.txt", Content, std::strlen(Content), false); + + zen::ZipFs Fs(Zip.Build()); + + zen::IoBuffer Result = Fs.GetFile("missing.txt"); + CHECK(!Result); +} + +TEST_SUITE_END(); + +#endif // ZEN_WITH_TESTS |