diff options
Diffstat (limited to 'src/zenhttp')
30 files changed, 3044 insertions, 435 deletions
diff --git a/src/zenhttp/asynchttpclient_test.cpp b/src/zenhttp/asynchttpclient_test.cpp new file mode 100644 index 000000000..151863370 --- /dev/null +++ b/src/zenhttp/asynchttpclient_test.cpp @@ -0,0 +1,315 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/asynchttpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Reusable test service for async client tests + +class AsyncHttpClientTestService : public HttpService +{ +public: + AsyncHttpClientTestService() + { + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"ok\":true}"); + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/async-test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + +private: + HttpRequestRouter m_Router; +}; + +////////////////////////////////////////////////////////////////////////// + +struct AsyncTestServerFixture +{ + AsyncHttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + AsyncTestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(0, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~AsyncTestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + AsyncHttpClient MakeClient(HttpClientSettings Settings = {}) { return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), Settings); } + + AsyncHttpClient MakeClient(asio::io_context& IoContext, HttpClientSettings Settings = {}) + { + return AsyncHttpClient(fmt::format("127.0.0.1:{}", Port), IoContext, Settings); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.asynchttpclient"); + +TEST_CASE("asynchttpclient.future.verbs") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + auto Future = Client.Get("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + auto Future = Client.Post("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + auto Future = Client.Put("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + auto Future = Client.Delete("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + auto Future = Client.Head("/api/async-test/echo/method"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("asynchttpclient.future.get") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET returning JSON") + { + auto Future = Client.Get("/api/async-test/json"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"ok\":true}"); + } + + SUBCASE("GET 204 NoContent") + { + auto Future = Client.Get("/api/async-test/nocontent"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("asynchttpclient.future.post.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PayloadStr = "async payload data"; + IoBuffer Payload(IoBuffer::Clone, PayloadStr.data(), PayloadStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Post("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "async payload data"); +} + +TEST_CASE("asynchttpclient.future.put.with.payload") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::string_view PutStr = "put payload"; + IoBuffer Payload(IoBuffer::Clone, PutStr.data(), PutStr.size()); + Payload.SetContentType(ZenContentType::kText); + + auto Future = Client.Put("/api/async-test/echo", Payload); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload"); +} + +TEST_CASE("asynchttpclient.callback") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + std::promise<HttpClient::Response> Promise; + auto Future = Promise.get_future(); + + Client.AsyncGet("/api/async-test/hello", [&Promise](HttpClient::Response Resp) { Promise.set_value(std::move(Resp)); }); + + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); +} + +TEST_CASE("asynchttpclient.concurrent.requests") +{ + AsyncTestServerFixture Fixture; + AsyncHttpClient Client = Fixture.MakeClient(); + + // Fire multiple requests concurrently + auto Future1 = Client.Get("/api/async-test/hello"); + auto Future2 = Client.Get("/api/async-test/json"); + auto Future3 = Client.Post("/api/async-test/echo/method"); + auto Future4 = Client.Delete("/api/async-test/echo/method"); + + auto Resp1 = Future1.get(); + auto Resp2 = Future2.get(); + auto Resp3 = Future3.get(); + auto Resp4 = Future4.get(); + + CHECK(Resp1.IsSuccess()); + CHECK_EQ(Resp1.AsText(), "hello world"); + + CHECK(Resp2.IsSuccess()); + CHECK_EQ(Resp2.AsText(), "{\"ok\":true}"); + + CHECK(Resp3.IsSuccess()); + CHECK_EQ(Resp3.AsText(), "POST"); + + CHECK(Resp4.IsSuccess()); + CHECK_EQ(Resp4.AsText(), "DELETE"); +} + +TEST_CASE("asynchttpclient.external.io_context") +{ + AsyncTestServerFixture Fixture; + + asio::io_context IoContext; + auto WorkGuard = asio::make_work_guard(IoContext); + std::thread IoThread([&IoContext]() { IoContext.run(); }); + + { + AsyncHttpClient Client = Fixture.MakeClient(IoContext); + + auto Future = Client.Get("/api/async-test/hello"); + auto Resp = Future.get(); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + WorkGuard.reset(); + IoThread.join(); +} + +TEST_CASE("asynchttpclient.connection.error") +{ + // Connect to a port where nothing is listening + AsyncHttpClient Client("127.0.0.1:1", HttpClientSettings{.ConnectTimeout = std::chrono::milliseconds(500)}); + + auto Future = Client.Get("/should-fail"); + auto Resp = Future.get(); + + CHECK_FALSE(Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + CHECK(Resp.Error->IsConnectionError()); +} + +TEST_SUITE_END(); + +void +asynchttpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/clients/asynchttpclient.cpp b/src/zenhttp/clients/asynchttpclient.cpp new file mode 100644 index 000000000..4b189af36 --- /dev/null +++ b/src/zenhttp/clients/asynchttpclient.cpp @@ -0,0 +1,1033 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/asynchttpclient.h> + +#include "httpclientcurlhelpers.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <thread> +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// TransferContext: per-transfer state associated with each CURL easy handle + +struct TransferContext +{ + AsyncHttpCallback Callback; + std::string Body; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + CurlWriteCallbackData WriteData; + CurlHeaderCallbackData HeaderData; + curl_slist* HeaderList = nullptr; + + // For PUT/POST with payload: keep the data alive until transfer completes + IoBuffer PayloadBuffer; + CurlReadCallbackData ReadData; + + TransferContext(AsyncHttpCallback&& InCallback) : Callback(std::move(InCallback)) + { + WriteData.Body = &Body; + HeaderData.Headers = &ResponseHeaders; + } + + ~TransferContext() + { + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + } + + TransferContext(const TransferContext&) = delete; + TransferContext& operator=(const TransferContext&) = delete; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient::Impl + +struct AsyncHttpClient::Impl +{ + Impl(std::string_view BaseUri, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + m_WorkGuard.emplace(m_IoContext.get_executor()); + m_IoThread = std::thread([this]() { + SetCurrentThreadName("async_http"); + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("AsyncHttpClient: unhandled exception in io thread: {}", Ex.what()); + } + }); + } + + Impl(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) + : m_BaseUri(BaseUri) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + , m_Strand(asio::make_strand(m_IoContext)) + , m_Timer(m_Strand) + { + Init(); + } + + ~Impl() + { + // Clean up curl state on the strand where all curl_multi operations + // are serialized. Use a promise to block until the cleanup handler + // has actually executed - essential for the external io_context case + // where we don't own the run loop. + std::promise<void> Done; + std::future<void> DoneFuture = Done.get_future(); + + asio::post(m_Strand, [this, &Done]() { + m_ShuttingDown = true; + m_Timer.cancel(); + + // Release all tracked sockets (don't close - curl owns the fds). + for (auto& [Fd, Info] : m_Sockets) + { + if (Info->Socket.is_open()) + { + Info->Socket.cancel(); + Info->Socket.release(); + } + } + m_Sockets.clear(); + + for (auto& [Handle, Ctx] : m_Transfers) + { + curl_multi_remove_handle(m_Multi, Handle); + curl_easy_cleanup(Handle); + } + m_Transfers.clear(); + + for (CURL* Handle : m_HandlePool) + { + curl_easy_cleanup(Handle); + } + m_HandlePool.clear(); + + Done.set_value(); + }); + + // For owned io_context: release work guard so run() can return after + // processing the cleanup handler above. + m_WorkGuard.reset(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + else + { + // External io_context: wait for the cleanup handler to complete. + DoneFuture.wait(); + } + + if (m_Multi) + { + curl_multi_cleanup(m_Multi); + } + } + + LoggerRef Log() { return m_Log; } + + void Init() + { + m_Multi = curl_multi_init(); + if (!m_Multi) + { + throw std::runtime_error("curl_multi_init failed"); + } + + SetupMultiCallbacks(); + + if (m_Settings.SessionId == Oid::Zero) + { + m_SessionId = std::string(GetSessionIdString()); + } + else + { + m_SessionId = m_Settings.SessionId.ToString(); + } + } + + // -- Handle pool ----------------------------------------------------- + + CURL* AllocHandle() + { + if (!m_HandlePool.empty()) + { + CURL* Handle = m_HandlePool.back(); + m_HandlePool.pop_back(); + curl_easy_reset(Handle); + return Handle; + } + CURL* Handle = curl_easy_init(); + if (!Handle) + { + throw std::runtime_error("curl_easy_init failed"); + } + return Handle; + } + + void ReleaseHandle(CURL* Handle) { m_HandlePool.push_back(Handle); } + + // -- Configure a handle with common settings ------------------------- + // Called only from DoAsync* lambdas running on the strand. + + void ConfigureHandle(CURL* Handle, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) + { + // Build URL + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Unix domain socket + if (!m_Settings.UnixSocketPath.empty()) + { + m_UnixSocketPathUtf8 = PathToUtf8(m_Settings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, m_UnixSocketPathUtf8.c_str()); + } + + // Timeouts + if (m_Settings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_Settings.ConnectTimeout.count())); + } + if (m_Settings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_Settings.Timeout.count())); + } + + // HTTP/2 + if (m_Settings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // SSL + if (m_Settings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!m_Settings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_Settings.CaBundlePath.c_str()); + } + + // Verbose/debug + if (m_Settings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + } + + // Thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (m_Settings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + } + + // -- Access token ---------------------------------------------------- + + std::optional<std::string> GetAccessToken() + { + if (!m_Settings.AccessTokenProvider.has_value()) + { + return {}; + } + { + RwLock::SharedLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + } + RwLock::ExclusiveLockScope _(m_AccessTokenLock); + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewToken = m_Settings.AccessTokenProvider.value()(); + if (!NewToken.IsValid()) + { + ZEN_WARN("AsyncHttpClient: failed to refresh access token, retrying once"); + NewToken = m_Settings.AccessTokenProvider.value()(); + } + if (NewToken.IsValid()) + { + m_CachedAccessToken = NewToken; + return m_CachedAccessToken.GetValue(); + } + ZEN_WARN("AsyncHttpClient: access token provider returned invalid token"); + return {}; + } + + // -- Submit a transfer ----------------------------------------------- + + void SubmitTransfer(CURL* Handle, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::SubmitTransfer"); + // Setup write/header callbacks + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &Ctx->WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &Ctx->HeaderData); + + m_Transfers[Handle] = std::move(Ctx); + + CURLMcode Mc = curl_multi_add_handle(m_Multi, Handle); + if (Mc != CURLM_OK) + { + auto Stolen = std::move(m_Transfers[Handle]); + m_Transfers.erase(Handle); + ReleaseHandle(Handle); + + HttpClient::Response ErrorResponse; + ErrorResponse.Error = + HttpClient::ErrorContext{.ErrorCode = HttpClientErrorCode::kInternalError, + .ErrorMessage = fmt::format("curl_multi_add_handle failed: {}", curl_multi_strerror(Mc))}; + asio::post(m_IoContext, + [Cb = std::move(Stolen->Callback), Response = std::move(ErrorResponse)]() mutable { Cb(std::move(Response)); }); + return; + } + } + + // -- Socket-action integration --------------------------------------- + // + // curl_multi drives I/O via two callbacks: + // - SocketCallback: curl tells us which sockets to watch for read/write + // - TimerCallback: curl tells us when to fire a timeout + // + // On each socket event or timeout we call curl_multi_socket_action(), + // then drain completed transfers via curl_multi_info_read(). + + // Per-socket state: wraps the native fd in an ASIO socket for async_wait. + struct SocketInfo + { + asio::ip::tcp::socket Socket; + int WatchFlags = 0; // CURL_POLL_IN, CURL_POLL_OUT, CURL_POLL_INOUT + + explicit SocketInfo(asio::io_context& IoContext) : Socket(IoContext) {} + }; + + // Static thunks registered with curl_multi ---------------------------- + + static int CurlSocketCallback(CURL* Easy, curl_socket_t Fd, int Action, void* UserPtr, void* SocketPtr) + { + ZEN_UNUSED(Easy); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlSocket(Fd, Action, static_cast<SocketInfo*>(SocketPtr)); + return 0; + } + + static int CurlTimerCallback(CURLM* Multi, long TimeoutMs, void* UserPtr) + { + ZEN_UNUSED(Multi); + auto* Self = static_cast<Impl*>(UserPtr); + Self->OnCurlTimer(TimeoutMs); + return 0; + } + + void SetupMultiCallbacks() + { + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETFUNCTION, CurlSocketCallback); + curl_multi_setopt(m_Multi, CURLMOPT_SOCKETDATA, this); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERFUNCTION, CurlTimerCallback); + curl_multi_setopt(m_Multi, CURLMOPT_TIMERDATA, this); + } + + // Called by curl when socket watch state changes --------------------- + + void OnCurlSocket(curl_socket_t Fd, int Action, SocketInfo* Info) + { + if (Action == CURL_POLL_REMOVE) + { + if (Info) + { + // Cancel pending async_wait ops before releasing the fd. + // curl owns the fd, so we must release() rather than close(). + Info->Socket.cancel(); + if (Info->Socket.is_open()) + { + Info->Socket.release(); + } + m_Sockets.erase(Fd); + } + return; + } + + if (!Info) + { + // New socket - wrap the native fd in an ASIO socket. + auto [It, Inserted] = m_Sockets.emplace(Fd, std::make_unique<SocketInfo>(m_IoContext)); + Info = It->second.get(); + + asio::error_code Ec; + // Determine protocol from the fd (v4 vs v6). Default to v4. + Info->Socket.assign(asio::ip::tcp::v4(), Fd, Ec); + if (Ec) + { + // Try v6 as fallback + Info->Socket.assign(asio::ip::tcp::v6(), Fd, Ec); + } + if (Ec) + { + ZEN_WARN("AsyncHttpClient: failed to assign socket fd {}: {}", static_cast<int>(Fd), Ec.message()); + m_Sockets.erase(Fd); + return; + } + + curl_multi_assign(m_Multi, Fd, Info); + } + + Info->WatchFlags = Action; + SetSocketWatch(Fd, Info); + } + + void SetSocketWatch(curl_socket_t Fd, SocketInfo* Info) + { + // Cancel any pending wait before issuing a new one. + Info->Socket.cancel(); + + if (Info->WatchFlags & CURL_POLL_IN) + { + Info->Socket.async_wait(asio::socket_base::wait_read, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_IN); + })); + } + + if (Info->WatchFlags & CURL_POLL_OUT) + { + Info->Socket.async_wait(asio::socket_base::wait_write, asio::bind_executor(m_Strand, [this, Fd](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + OnSocketReady(Fd, CURL_CSELECT_OUT); + })); + } + } + + void OnSocketReady(curl_socket_t Fd, int CurlAction) + { + ZEN_TRACE_CPU("AsyncHttpClient::OnSocketReady"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, Fd, CurlAction, &StillRunning); + CheckCompleted(); + + // Re-arm the watch if the socket is still tracked. + auto It = m_Sockets.find(Fd); + if (It != m_Sockets.end()) + { + SetSocketWatch(Fd, It->second.get()); + } + } + + // Called by curl when it wants a timeout ------------------------------ + + void OnCurlTimer(long TimeoutMs) + { + m_Timer.cancel(); + + if (TimeoutMs < 0) + { + // curl says "no timeout needed" + return; + } + + if (TimeoutMs == 0) + { + // curl wants immediate action - run it directly on the strand. + asio::post(m_Strand, [this]() { + if (m_ShuttingDown) + { + return; + } + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + }); + return; + } + + m_Timer.expires_after(std::chrono::milliseconds(TimeoutMs)); + m_Timer.async_wait(asio::bind_executor(m_Strand, [this](const asio::error_code& Ec) { + if (Ec || m_ShuttingDown) + { + return; + } + ZEN_TRACE_CPU("AsyncHttpClient::OnTimeout"); + int StillRunning = 0; + curl_multi_socket_action(m_Multi, CURL_SOCKET_TIMEOUT, 0, &StillRunning); + CheckCompleted(); + })); + } + + // Drain completed transfers from curl_multi -------------------------- + + void CheckCompleted() + { + int MsgsLeft = 0; + CURLMsg* Msg = nullptr; + while ((Msg = curl_multi_info_read(m_Multi, &MsgsLeft)) != nullptr) + { + if (Msg->msg != CURLMSG_DONE) + { + continue; + } + + CURL* Handle = Msg->easy_handle; + CURLcode Result = Msg->data.result; + + curl_multi_remove_handle(m_Multi, Handle); + + auto It = m_Transfers.find(Handle); + if (It == m_Transfers.end()) + { + ReleaseHandle(Handle); + continue; + } + + std::unique_ptr<TransferContext> Ctx = std::move(It->second); + m_Transfers.erase(It); + + CompleteTransfer(Handle, Result, std::move(Ctx)); + } + } + + void CompleteTransfer(CURL* Handle, CURLcode CurlResult, std::unique_ptr<TransferContext> Ctx) + { + ZEN_TRACE_CPU("AsyncHttpClient::CompleteTransfer"); + // Extract result info + long StatusCode = 0; + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + + ReleaseHandle(Handle); + + // Build response + HttpClient::Response Response; + Response.StatusCode = HttpResponseCode(StatusCode); + Response.UploadedBytes = static_cast<int64_t>(UpBytes); + Response.DownloadedBytes = static_cast<int64_t>(DownBytes); + Response.ElapsedSeconds = Elapsed; + Response.Header = BuildHeaderMap(Ctx->ResponseHeaders); + + if (CurlResult != CURLE_OK) + { + const char* ErrorMsg = curl_easy_strerror(CurlResult); + + if (CurlResult != CURLE_OPERATION_TIMEDOUT && CurlResult != CURLE_COULDNT_CONNECT && CurlResult != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("AsyncHttpClient failure: ({}) '{}'", static_cast<int>(CurlResult), ErrorMsg); + } + + if (!Ctx->Body.empty()) + { + Response.ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + } + + Response.Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(CurlResult), .ErrorMessage = std::string(ErrorMsg)}; + } + else if (StatusCode == static_cast<long>(HttpResponseCode::NoContent) || Ctx->Body.empty()) + { + // No payload + } + else + { + IoBuffer PayloadBuffer = IoBufferBuilder::MakeCloneFromMemory(Ctx->Body.data(), Ctx->Body.size()); + ApplyContentTypeFromHeaders(PayloadBuffer, Ctx->ResponseHeaders); + + const HttpResponseCode Code = HttpResponseCode(StatusCode); + if (!IsHttpSuccessCode(Code) && Code != HttpResponseCode::NotFound) + { + ZEN_WARN("AsyncHttpClient request failed: status={}, base={}", static_cast<int>(Code), m_BaseUri); + } + + Response.ResponsePayload = std::move(PayloadBuffer); + } + + // Dispatch the user callback off the strand so a slow callback + // cannot starve the curl_multi poll loop. + asio::post(m_IoContext, [LogRef = m_Log, Cb = std::move(Ctx->Callback), Response = std::move(Response)]() mutable { + try + { + Cb(std::move(Response)); + } + catch (const std::exception& Ex) + { + auto Log = [&]() -> LoggerRef { return LogRef; }; + ZEN_ERROR("AsyncHttpClient: unhandled exception in completion callback: {}", Ex.what()); + } + }); + } + + // -- Async verb implementations -------------------------------------- + + void DoAsyncGet(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Get"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_HTTPGET, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncHead(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Head"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_NOBODY, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncDelete(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, Url = std::move(Url), Callback = std::move(Callback), AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Delete"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_CUSTOMREQUEST, "DELETE"); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPost(std::string Url, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Post"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE, 0L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPostWithPayload(std::string Url, + IoBuffer Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + ContentType, + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::PostWithPayload"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, {}); + curl_easy_setopt(Handle, CURLOPT_POST, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = + BuildHeaderList(AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + // Set up read callback for payload data + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutWithPayload(std::string Url, + IoBuffer Payload, + AsyncHttpCallback Callback, + HttpClient::KeyValueMap AdditionalHeader, + HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, + [this, + Url = std::move(Url), + Payload = std::move(Payload), + Callback = std::move(Callback), + AdditionalHeader = std::move(AdditionalHeader), + Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + Ctx->PayloadBuffer = std::move(Payload); + Ctx->HeaderList = BuildHeaderList( + AdditionalHeader, + m_SessionId, + GetAccessToken(), + {std::make_pair("Content-Type", std::string(MapContentTypeToString(Ctx->PayloadBuffer.GetContentType())))}); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + Ctx->ReadData.DataPtr = static_cast<const uint8_t*>(Ctx->PayloadBuffer.GetData()); + Ctx->ReadData.DataSize = Ctx->PayloadBuffer.GetSize(); + Ctx->ReadData.Offset = 0; + + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Ctx->PayloadBuffer.GetSize())); + curl_easy_setopt(Handle, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(Handle, CURLOPT_READDATA, &Ctx->ReadData); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + void DoAsyncPutNoPayload(std::string Url, AsyncHttpCallback Callback, HttpClient::KeyValueMap Parameters) + { + asio::post(m_Strand, [this, Url = std::move(Url), Callback = std::move(Callback), Parameters = std::move(Parameters)]() mutable { + ZEN_TRACE_CPU("AsyncHttpClient::Put"); + if (m_ShuttingDown) + { + return; + } + CURL* Handle = AllocHandle(); + ConfigureHandle(Handle, Url, Parameters); + curl_easy_setopt(Handle, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(Handle, CURLOPT_INFILESIZE_LARGE, 0LL); + + auto Ctx = std::make_unique<TransferContext>(std::move(Callback)); + + HttpClient::KeyValueMap ContentLengthHeader{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Ctx->HeaderList = BuildHeaderList(ContentLengthHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, Ctx->HeaderList); + + SubmitTransfer(Handle, std::move(Ctx)); + }); + } + + // -- Members --------------------------------------------------------- + + std::string m_BaseUri; + HttpClientSettings m_Settings; + LoggerRef m_Log; + std::string m_SessionId; + std::string m_UnixSocketPathUtf8; + + // io_context and strand - all curl_multi operations are serialized on the + // strand, making this safe even when the io_context has multiple threads. + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + asio::strand<asio::io_context::executor_type> m_Strand; + std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; + std::thread m_IoThread; + + // curl_multi and socket-action state + CURLM* m_Multi = nullptr; + std::unordered_map<CURL*, std::unique_ptr<TransferContext>> m_Transfers; + std::vector<CURL*> m_HandlePool; + std::unordered_map<curl_socket_t, std::unique_ptr<SocketInfo>> m_Sockets; + asio::steady_timer m_Timer; + bool m_ShuttingDown = false; + + // Access token cache + RwLock m_AccessTokenLock; + HttpClientAccessToken m_CachedAccessToken; +}; + +////////////////////////////////////////////////////////////////////////// +// +// AsyncHttpClient public API + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, Settings)) +{ +} + +AsyncHttpClient::AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(BaseUri, IoContext, Settings)) +{ +} + +AsyncHttpClient::~AsyncHttpClient() = default; + +// -- Callback-based API -------------------------------------------------- + +void +AsyncHttpClient::AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncGet(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncHead(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncDelete(std::string(Url), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPost(std::string(Url), std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, Payload.GetContentType(), std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader) +{ + m_Impl->DoAsyncPostWithPayload(std::string(Url), Payload, ContentType, std::move(Callback), AdditionalHeader); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutWithPayload(std::string(Url), Payload, std::move(Callback), AdditionalHeader, Parameters); +} + +void +AsyncHttpClient::AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters) +{ + m_Impl->DoAsyncPutNoPayload(std::string(Url), std::move(Callback), Parameters); +} + +// -- Future-based API ---------------------------------------------------- + +std::future<HttpClient::Response> +AsyncHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncGet( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncHead( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncDelete( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPost( + Url, + Payload, + ContentType, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, + Payload, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + AdditionalHeader, + Parameters); + return Future; +} + +std::future<HttpClient::Response> +AsyncHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + auto Promise = std::make_shared<std::promise<Response>>(); + auto Future = Promise->get_future(); + AsyncPut( + Url, + [Promise](Response R) { Promise->set_value(std::move(R)); }, + Parameters); + return Future; +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index d150b44c6..56b9c39c5 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpclientcurl.h" +#include "httpclientcurlhelpers.h" #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> @@ -29,153 +30,7 @@ static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; ////////////////////////////////////////////////////////////////////////// -static HttpClientErrorCode -MapCurlError(CURLcode Code) -{ - switch (Code) - { - case CURLE_OK: - return HttpClientErrorCode::kOK; - case CURLE_COULDNT_CONNECT: - return HttpClientErrorCode::kConnectionFailure; - case CURLE_COULDNT_RESOLVE_HOST: - return HttpClientErrorCode::kHostResolutionFailure; - case CURLE_COULDNT_RESOLVE_PROXY: - return HttpClientErrorCode::kProxyResolutionFailure; - case CURLE_RECV_ERROR: - return HttpClientErrorCode::kNetworkReceiveError; - case CURLE_SEND_ERROR: - return HttpClientErrorCode::kNetworkSendFailure; - case CURLE_OPERATION_TIMEDOUT: - return HttpClientErrorCode::kOperationTimedOut; - case CURLE_SSL_CONNECT_ERROR: - return HttpClientErrorCode::kSSLConnectError; - case CURLE_SSL_CERTPROBLEM: - return HttpClientErrorCode::kSSLCertificateError; - case CURLE_PEER_FAILED_VERIFICATION: - return HttpClientErrorCode::kSSLCACertError; - case CURLE_SSL_CIPHER: - case CURLE_SSL_ENGINE_NOTFOUND: - case CURLE_SSL_ENGINE_SETFAILED: - return HttpClientErrorCode::kGenericSSLError; - case CURLE_ABORTED_BY_CALLBACK: - return HttpClientErrorCode::kRequestCancelled; - default: - return HttpClientErrorCode::kOtherError; - } -} - -////////////////////////////////////////////////////////////////////////// -// -// Curl callback helpers - -struct WriteCallbackData -{ - std::string* Body = nullptr; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<WriteCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return 0; // Signal abort to curl - } - - Data->Body->append(Ptr, TotalBytes); - return TotalBytes; -} - -struct HeaderCallbackData -{ - std::vector<std::pair<std::string, std::string>>* Headers = nullptr; -}; - -// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. -// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). -static std::optional<std::pair<std::string_view, std::string_view>> -ParseHeaderLine(std::string_view Line) -{ - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) - { - return std::nullopt; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos == std::string_view::npos) - { - return std::nullopt; - } - - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - - return std::pair{Key, Value}; -} - -static size_t -CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<HeaderCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) - { - auto& [Key, Value] = *Header; - Data->Headers->emplace_back(std::string(Key), std::string(Value)); - } - - return TotalBytes; -} - -struct ReadCallbackData -{ - const uint8_t* DataPtr = nullptr; - size_t DataSize = 0; - size_t Offset = 0; - std::function<bool()>* CheckIfAbortFunction = nullptr; -}; - -static size_t -CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) -{ - auto* Data = static_cast<ReadCallbackData*>(UserData); - size_t MaxRead = Size * Nmemb; - - if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) - { - return CURL_READFUNC_ABORT; - } - - size_t Remaining = Data->DataSize - Data->Offset; - size_t ToRead = std::min(MaxRead, Remaining); - - if (ToRead > 0) - { - memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); - Data->Offset += ToRead; - } - - return ToRead; -} +// Curl callback helpers and shared utilities are in httpclientcurlhelpers.h struct StreamReadCallbackData { @@ -281,120 +136,6 @@ CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, voi ////////////////////////////////////////////////////////////////////////// -static std::pair<std::string, std::string> -HeaderContentType(ZenContentType ContentType) -{ - return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); -} - -static curl_slist* -BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, - std::string_view SessionId, - const std::optional<std::string>& AccessToken, - const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) -{ - curl_slist* Headers = nullptr; - - for (const auto& [Key, Value] : *AdditionalHeader) - { - ExtendableStringBuilder<64> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - if (!SessionId.empty()) - { - ExtendableStringBuilder<64> SessionHeader; - SessionHeader << "UE-Session: " << SessionId; - Headers = curl_slist_append(Headers, SessionHeader.c_str()); - } - - if (AccessToken.has_value()) - { - ExtendableStringBuilder<128> AuthHeader; - AuthHeader << "Authorization: " << AccessToken.value(); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); - } - - for (const auto& [Key, Value] : ExtraHeaders) - { - ExtendableStringBuilder<128> HeaderLine; - HeaderLine << Key << ": " << Value; - Headers = curl_slist_append(Headers, HeaderLine.c_str()); - } - - return Headers; -} - -static HttpClient::KeyValueMap -BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) -{ - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HeaderMap; -} - -// Scans response headers for Content-Type and applies it to the buffer. -static void -ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) -{ - for (const auto& [Key, Value] : Headers) - { - if (StrCaseCompare(Key, "Content-Type") == 0) - { - Buffer.SetContentType(ParseContentType(Value)); - break; - } - } -} - -static void -AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) -{ - static constexpr char HexDigits[] = "0123456789ABCDEF"; - static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); - - for (char C : Input) - { - if (Unreserved.Contains(C)) - { - Out.Append(C); - } - else - { - uint8_t Byte = static_cast<uint8_t>(C); - char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; - Out.Append(std::string_view(Encoded, 3)); - } - } -} - -static void -BuildUrlWithParameters(StringBuilderBase& Url, - std::string_view BaseUrl, - std::string_view ResourcePath, - const HttpClient::KeyValueMap& Parameters) -{ - Url.Append(BaseUrl); - Url.Append(ResourcePath); - - if (!Parameters->empty()) - { - char Separator = '?'; - for (const auto& [Key, Value] : *Parameters) - { - Url.Append(Separator); - AppendUrlEncoded(Url, Key); - Url.Append('='); - AppendUrlEncoded(Url, Value); - Separator = '&'; - } - } -} - ////////////////////////////////////////////////////////////////////////// CurlHttpClient::CurlHttpClient(std::string_view BaseUri, @@ -440,9 +181,9 @@ CurlHttpClient::CurlResult CurlHttpClient::Session::PerformWithResponseCallbacks() { std::string Body; - WriteCallbackData WriteData{.Body = &Body, + CurlWriteCallbackData WriteData{.Body = &Body, .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; - HeaderCallbackData HdrData{}; + CurlHeaderCallbackData HdrData{}; std::vector<std::pair<std::string, std::string>> ResponseHeaders; HdrData.Headers = &ResponseHeaders; @@ -487,6 +228,13 @@ CurlHttpClient::Session::Perform() curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + char* EffectiveUrl = nullptr; + curl_easy_getinfo(Handle, CURLINFO_EFFECTIVE_URL, &EffectiveUrl); + if (EffectiveUrl) + { + Result.Url = EffectiveUrl; + } + return Result; } @@ -553,8 +301,9 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) { - ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + ZEN_WARN("HttpClient client failure (session: {}, url: {}): ({}) '{}'", SessionId, + Result.Url, static_cast<int>(Result.ErrorCode), Result.ErrorMessage); } @@ -702,6 +451,7 @@ CurlHttpClient::ShouldRetry(const CurlResult& Result) case CURLE_RECV_ERROR: case CURLE_SEND_ERROR: case CURLE_OPERATION_TIMEDOUT: + case CURLE_PARTIAL_FILE: return true; default: return false; @@ -748,10 +498,11 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult { if (Result.ErrorCode != CURLE_OK) { - ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' (Curl error: {}) Attempt {}/{}", SessionId, static_cast<int>(MapCurlError(Result.ErrorCode)), Result.ErrorMessage, + static_cast<int>(Result.ErrorCode), Attempt, m_ConnectionSettings.RetryCount + 1); } @@ -998,9 +749,9 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); @@ -1367,9 +1118,9 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV return Sess.PerformWithResponseCallbacks(); } - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + CurlReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index bdeb46633..ea9193e65 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -73,6 +73,7 @@ private: int64_t DownloadedBytes = 0; CURLcode ErrorCode = CURLE_OK; std::string ErrorMessage; + std::string Url; }; struct Session diff --git a/src/zenhttp/clients/httpclientcurlhelpers.h b/src/zenhttp/clients/httpclientcurlhelpers.h new file mode 100644 index 000000000..0605a30f6 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurlhelpers.h @@ -0,0 +1,293 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Shared helpers for curl-based HTTP client implementations (sync and async). +// This is an internal header, not part of the public API. + +#include <zencore/string.h> + +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <optional> +#include <string> +#include <utility> +#include <vector> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Error mapping + +inline HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback data structures and callbacks + +struct CurlWriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +inline size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct CurlHeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +inline std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) +{ + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return std::nullopt; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + return std::nullopt; + } + + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +inline size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct CurlReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +inline size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<CurlReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +////////////////////////////////////////////////////////////////////////// +// +// URL and header construction + +inline void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +inline void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) +{ + Url.Append(BaseUrl); + Url.Append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); + Separator = '&'; + } + } +} + +inline std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +inline curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<std::string>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken.has_value()) + { + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken.value(); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +inline HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +inline void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index ace7a3c7f..3da8a9220 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -520,7 +520,7 @@ MeasureLatency(HttpClient& Client, std::string_view Url) ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. - // Bail out immediately — retrying will just burn the connect timeout each time. + // Bail out immediately - retrying will just burn the connect timeout each time. if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) { break; diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index af653cbb2..67cbaea9b 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -194,7 +194,7 @@ public: "slow", [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { - Sleep(2000); + Sleep(100); Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); }); }, @@ -750,7 +750,9 @@ TEST_CASE("httpclient.error-handling") { SUBCASE("Connection refused") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); HttpClient::Response Resp = Client.Get("/api/test/hello"); CHECK(!Resp.IsSuccess()); CHECK(Resp.Error.has_value()); @@ -760,7 +762,7 @@ TEST_CASE("httpclient.error-handling") { TestServerFixture Fixture; HttpClientSettings Settings; - Settings.Timeout = std::chrono::milliseconds(500); + Settings.Timeout = std::chrono::milliseconds(50); HttpClient Client = Fixture.MakeClient(Settings); HttpClient::Response Resp = Client.Get("/api/test/slow"); @@ -970,7 +972,9 @@ TEST_CASE("httpclient.measurelatency") SUBCASE("Failed measurement against unreachable port") { - HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClientSettings Settings; + Settings.ConnectTimeout = std::chrono::milliseconds(200); + HttpClient Client("127.0.0.1:19999", Settings, /*CheckIfAbortFunction*/ {}); LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); CHECK(!Result.Success); CHECK(!Result.FailureReason.empty()); @@ -1144,7 +1148,7 @@ struct FaultTcpServer ~FaultTcpServer() { // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this - // thread — ASIO I/O objects are not safe for concurrent access and the io_context + // thread - ASIO I/O objects are not safe for concurrent access and the io_context // thread may be touching the acceptor in StartAccept(). m_IoContext.stop(); if (m_Thread.joinable()) @@ -1498,7 +1502,7 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) std::atomic<bool> StallActive{true}; FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { DrainHttpRequest(Socket); - // Stop reading body — TCP window will fill and client send will stall + // Stop reading body - TCP window will fill and client send will stall while (StallActive.load()) { std::this_thread::sleep_for(std::chrono::milliseconds(50)); @@ -1735,21 +1739,21 @@ TEST_CASE("httpclient.uri_decoding") TestServerFixture Fixture; HttpClient Client = Fixture.MakeClient(); - // URI without encoding — should pass through unchanged + // URI without encoding - should pass through unchanged { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt"); } - // Percent-encoded space — server should see decoded path + // Percent-encoded space - server should see decoded path { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt"); } - // Percent-encoded slash (%2F) — should be decoded to / + // Percent-encoded slash (%2F) - should be decoded to / { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt"); REQUIRE(Resp.IsSuccess()); @@ -1763,21 +1767,21 @@ TEST_CASE("httpclient.uri_decoding") CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt"); } - // No capture — echo/uri route returns just RelativeUri + // No capture - echo/uri route returns just RelativeUri { HttpClient::Response Resp = Client.Get("/api/test/echo/uri"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "echo/uri"); } - // Literal percent that is not an escape (%ZZ) — should be kept as-is + // Literal percent that is not an escape (%ZZ) - should be kept as-is { HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt"); REQUIRE(Resp.IsSuccess()); CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt"); } - // Query params — raw values are returned as-is from GetQueryParams + // Query params - raw values are returned as-is from GetQueryParams { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test"); REQUIRE(Resp.IsSuccess()); @@ -1788,7 +1792,7 @@ TEST_CASE("httpclient.uri_decoding") { HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3"); REQUIRE(Resp.IsSuccess()); - // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly + // GetQueryParams returns raw (still-encoded) values - callers must Decode() explicitly CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3"); } diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index c42841922..0432e50ef 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -94,7 +94,8 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Unattended, bool Quiet, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { Stopwatch Timer; @@ -117,8 +118,9 @@ namespace zen { namespace httpclientauth { } }); - const std::string ProcArgs = fmt::format("{} --AuthConfigUrl {} --OutFile {} --Unattended={}", + const std::string ProcArgs = fmt::format("{} {} {} --OutFile {} --Unattended={}", OidcExecutablePath, + IsHordeUrl ? "--HordeUrl" : "--AuthConfigUrl", CloudHost, AuthTokenPath, Unattended ? "true"sv : "false"sv); @@ -193,7 +195,7 @@ namespace zen { namespace httpclientauth { } else { - ZEN_WARN("Failed running {} to get auth token, error code {}", OidcExecutablePath, ExitCode); + ZEN_WARN("Failed running '{}' to get auth token, error code {}", ProcArgs, ExitCode); } return HttpClientAccessToken{}; } @@ -202,9 +204,10 @@ namespace zen { namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden) + bool Hidden, + bool IsHordeUrl) { - HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + HttpClientAccessToken InitialToken = GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); if (InitialToken.IsValid()) { return [OidcExecutablePath = std::filesystem::path(OidcExecutablePath), @@ -212,12 +215,13 @@ namespace zen { namespace httpclientauth { Token = InitialToken, Quiet, Unattended, - Hidden]() mutable { + Hidden, + IsHordeUrl]() mutable { if (!Token.NeedsRefresh()) { return std::move(Token); } - return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden); + return GetOidcTokenFromExe(OidcExecutablePath, CloudHost, Unattended, Quiet, Hidden, IsHordeUrl); }; } return {}; diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index e05c9815f..03117ee6c 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -266,10 +266,10 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return false; } - const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); - const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); + const auto Start = ParseInt<uint64_t>(Token.substr(0, Delim)); + const auto End = ParseInt<uint64_t>(Token.substr(Delim + 1)); - if (Start.has_value() && End.has_value() && End.value() > Start.value()) + if (Start.has_value() && End.has_value() && End.value() >= Start.value()) { Ranges.push_back({.Start = Start.value(), .End = End.value()}); } @@ -286,6 +286,45 @@ TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) return Count != Ranges.size(); } +MultipartByteRangesResult +BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges) +{ + Oid::String_t BoundaryStr; + Oid::NewOid().ToString(BoundaryStr); + std::string_view Boundary(BoundaryStr, Oid::StringLength); + + const uint64_t TotalSize = Data.GetSize(); + + std::vector<IoBuffer> Parts; + Parts.reserve(Ranges.size() * 2 + 1); + + for (const HttpRange& Range : Ranges) + { + uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + return {}; + } + + uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + + std::string PartHeader = fmt::format("\r\n--{}\r\nContent-Type: application/octet-stream\r\nContent-Range: bytes {}-{}/{}\r\n\r\n", + Boundary, + Range.Start, + RangeEnd, + TotalSize); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(PartHeader.data(), PartHeader.size())); + + IoBuffer RangeData(Data, Range.Start, RangeSize); + Parts.push_back(RangeData); + } + + std::string ClosingBoundary = fmt::format("\r\n--{}--", Boundary); + Parts.push_back(IoBufferBuilder::MakeCloneFromMemory(ClosingBoundary.data(), ClosingBoundary.size())); + + return {.Parts = std::move(Parts), .ContentType = fmt::format("multipart/byteranges; boundary={}", Boundary)}; +} + ////////////////////////////////////////////////////////////////////////// const std::string_view @@ -479,6 +518,18 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) return Ref<IHttpPackageHandler>(); } +bool +HttpService::AcceptsLocalFileReferences() const +{ + return false; +} + +const ILocalRefPolicy* +HttpService::GetLocalRefPolicy() const +{ + return nullptr; +} + ////////////////////////////////////////////////////////////////////////// HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) @@ -552,6 +603,56 @@ HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType } void +HttpServerRequest::WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges) +{ + if (Ranges.empty()) + { + WriteResponse(HttpResponseCode::OK, ContentType, IoBuffer(Data)); + return; + } + + if (Ranges.size() == 1) + { + const HttpRange& Range = Ranges[0]; + const uint64_t TotalSize = Data.GetSize(); + // ~uint64_t(0) is the sentinel meaning "end of file" (suffix range). + const uint64_t RangeEnd = (Range.End != ~uint64_t(0)) ? Range.End : TotalSize - 1; + + if (RangeEnd >= TotalSize || Range.Start > RangeEnd) + { + m_ContentRangeHeader = fmt::format("bytes */{}", TotalSize); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + + const uint64_t RangeSize = 1 + (RangeEnd - Range.Start); + IoBuffer RangeBuf(Data, Range.Start, RangeSize); + + m_ContentRangeHeader = fmt::format("bytes {}-{}/{}", Range.Start, RangeEnd, TotalSize); + WriteResponse(HttpResponseCode::PartialContent, ContentType, std::move(RangeBuf)); + return; + } + + // Multi-range + MultipartByteRangesResult MultipartResult = BuildMultipartByteRanges(Data, Ranges); + if (MultipartResult.Parts.empty()) + { + m_ContentRangeHeader = fmt::format("bytes */{}", Data.GetSize()); + WriteResponse(HttpResponseCode::RangeNotSatisfiable); + return; + } + WriteResponse(HttpResponseCode::PartialContent, std::move(MultipartResult.ContentType), std::span<IoBuffer>(MultipartResult.Parts)); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(ParseContentType(CustomContentType) == HttpContentType::kUnknownContentType); + m_ContentTypeOverride = CustomContentType; + WriteResponse(ResponseCode, HttpContentType::kBinary, Blobs); +} + +void HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) { std::span<const SharedBuffer> Segments = Payload.GetSegments(); @@ -705,7 +806,10 @@ HttpServerRequest::ReadPayloadPackage() { if (IoBuffer Payload = ReadPayload()) { - return ParsePackageMessage(std::move(Payload)); + ParseFlags Flags = + (IsLocalMachineRequest() && m_Service.AcceptsLocalFileReferences()) ? ParseFlags::kAllowLocalReferences : ParseFlags::kDefault; + const ILocalRefPolicy* Policy = EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences) ? m_Service.GetLocalRefPolicy() : nullptr; + return ParsePackageMessage(std::move(Payload), {}, Flags, Policy); } return {}; @@ -816,7 +920,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) // Strip the separator slash left over after the service prefix is removed. // When a service has BaseUri "/foo", the prefix length is set to len("/foo") = 4. - // Stripping 4 chars from "/foo/bar" yields "/bar" — the path separator becomes + // Stripping 4 chars from "/foo/bar" yields "/bar" - the path separator becomes // the first character of the relative URI. Remove it so patterns like "bar" or // "{id}" match without needing to account for the leading slash. if (!Uri.empty() && Uri.front() == '/') @@ -1273,7 +1377,12 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP return PackageHandlerRef->CreateTarget(Cid, Size); }; - CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + ParseFlags PkgFlags = (Request.IsLocalMachineRequest() && Service.AcceptsLocalFileReferences()) + ? ParseFlags::kAllowLocalReferences + : ParseFlags::kDefault; + const ILocalRefPolicy* PkgPolicy = + EnumHasAllFlags(PkgFlags, ParseFlags::kAllowLocalReferences) ? Service.GetLocalRefPolicy() : nullptr; + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer, PkgFlags, PkgPolicy); PackageHandlerRef->OnRequestComplete(); } @@ -1512,7 +1621,7 @@ TEST_CASE("http.common") }, HttpVerb::kGet); - // Single-segment literal with leading slash — simulates real server RelativeUri + // Single-segment literal with leading slash - simulates real server RelativeUri { Reset(); TestHttpServerRequest req{Service, "/activity_counters"sv}; @@ -1532,7 +1641,7 @@ TEST_CASE("http.common") CHECK_EQ(Captures[0], "hello"sv); } - // Two-segment route with leading slash — first literal segment + // Two-segment route with leading slash - first literal segment { Reset(); TestHttpServerRequest req{Service, "/prefix/world"sv}; diff --git a/src/zenhttp/include/zenhttp/asynchttpclient.h b/src/zenhttp/include/zenhttp/asynchttpclient.h new file mode 100644 index 000000000..cb41626b9 --- /dev/null +++ b/src/zenhttp/include/zenhttp/asynchttpclient.h @@ -0,0 +1,123 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> + +#include <functional> +#include <future> +#include <memory> + +namespace asio { +class io_context; +} + +namespace zen { + +/// Completion callback for async HTTP operations. +using AsyncHttpCallback = std::function<void(HttpClient::Response)>; + +/** Asynchronous HTTP client backed by curl_multi and ASIO. + * + * Uses curl_multi_socket_action() driven by ASIO socket async_wait to process + * transfers without blocking the caller. All curl_multi operations are + * serialized on an internal strand; callers may issue requests from any + * thread, and the io_context may have multiple threads. + * + * Two construction modes: + * - Owned io_context: creates an internal thread (self-contained). + * - External io_context: caller runs the event loop. + * + * Completion callbacks are dispatched on the io_context (not the internal + * strand), so a slow callback will not block the curl poll loop. Future- + * based wrappers (Get, Post, ...) return a std::future<Response> for + * callers that prefer blocking on a result. + */ +class AsyncHttpClient +{ +public: + using Response = HttpClient::Response; + using KeyValueMap = HttpClient::KeyValueMap; + + /// Construct with an internally-owned io_context and thread. + explicit AsyncHttpClient(std::string_view BaseUri, const HttpClientSettings& Settings = {}); + + /// Construct with an externally-managed io_context. The io_context must + /// outlive this client and must be running (via run()) on at least one thread. + AsyncHttpClient(std::string_view BaseUri, asio::io_context& IoContext, const HttpClientSettings& Settings = {}); + + ~AsyncHttpClient(); + + AsyncHttpClient(const AsyncHttpClient&) = delete; + AsyncHttpClient& operator=(const AsyncHttpClient&) = delete; + + // -- Callback-based API ---------------------------------------------- + + void AsyncGet(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncHead(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncDelete(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPost(std::string_view Url, const IoBuffer& Payload, AsyncHttpCallback Callback, const KeyValueMap& AdditionalHeader = {}); + + void AsyncPost(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}); + + void AsyncPut(std::string_view Url, + const IoBuffer& Payload, + AsyncHttpCallback Callback, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + void AsyncPut(std::string_view Url, AsyncHttpCallback Callback, const KeyValueMap& Parameters = {}); + + // -- Future-based API ------------------------------------------------ + + [[nodiscard]] std::future<Response> Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, + const IoBuffer& Payload, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}); + + [[nodiscard]] std::future<Response> Put(std::string_view Url, const KeyValueMap& Parameters = {}); + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +void asynchttpclient_test_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpclientauth.h b/src/zenhttp/include/zenhttp/httpclientauth.h index ce646ebd7..9220a50b6 100644 --- a/src/zenhttp/include/zenhttp/httpclientauth.h +++ b/src/zenhttp/include/zenhttp/httpclientauth.h @@ -33,7 +33,8 @@ namespace httpclientauth { std::string_view CloudHost, bool Quiet, bool Unattended, - bool Hidden); + bool Hidden, + bool IsHordeUrl = false); } // namespace httpclientauth } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index f9a99f3cc..1d921600d 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -19,8 +19,8 @@ class StringBuilderBase; struct HttpRange { - uint32_t Start = ~uint32_t(0); - uint32_t End = ~uint32_t(0); + uint64_t Start = ~uint64_t(0); + uint64_t End = ~uint64_t(0); }; using HttpRanges = std::vector<HttpRange>; @@ -30,6 +30,16 @@ extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeSt std::string_view ReasonStringForHttpResultCode(int HttpCode); bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); +struct MultipartByteRangesResult +{ + std::vector<IoBuffer> Parts; + std::string ContentType; +}; + +// Build a multipart/byteranges response body from the given data and ranges. +// Generates a unique boundary per call. Returns empty Parts if any range is out of bounds. +MultipartByteRangesResult BuildMultipartByteRanges(const IoBuffer& Data, const HttpRanges& Ranges); + enum class HttpVerb : uint8_t { kGet = 1 << 0, diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 5eaed6004..955b8ed15 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -12,6 +12,7 @@ #include <zencore/string.h> #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zenhttp/localrefpolicy.h> #include <zentelemetry/hyperloglog.h> #include <zentelemetry/stats.h> @@ -121,11 +122,13 @@ public: virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + void WriteResponse(HttpResponseCode ResponseCode, const std::string& CustomContentType, std::span<IoBuffer> Blobs); void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + void WriteResponse(HttpContentType ContentType, const IoBuffer& Data, const HttpRanges& Ranges); virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; @@ -151,6 +154,8 @@ protected: std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); mutable Oid m_SessionId = Oid::Zero; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; inline void SetIsHandled() { m_Flags |= kIsHandled; } @@ -193,9 +198,16 @@ public: HttpService() = default; virtual ~HttpService() = default; - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + [[nodiscard]] virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + /// Whether this service accepts local file references in inbound packages from local clients. + [[nodiscard]] virtual bool AcceptsLocalFileReferences() const; + + /// Returns the local ref policy for validating file paths in inbound local references. + /// Returns nullptr by default, which causes file-path local refs to be rejected (fail-closed). + [[nodiscard]] virtual const ILocalRefPolicy* GetLocalRefPolicy() const; // Internals @@ -290,12 +302,12 @@ public: std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } - /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + /** Track active WebSocket connections - called by server implementations on upgrade/close. */ void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } - /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + /** Track WebSocket frame and byte counters - called by WS connection implementations per frame. */ void OnWebSocketFrameReceived(uint64_t Bytes) { m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); @@ -317,7 +329,7 @@ private: int m_EffectiveHttpsPort = 0; std::string m_ExternalHost; metrics::Meter m_RequestMeter; - metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error — sufficient for client counting + metrics::HyperLogLog<12> m_ClientAddresses; // ~4 KiB, ~1.6% error - sufficient for client counting metrics::HyperLogLog<12> m_ClientSessions; std::string m_DefaultRedirect; std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; @@ -510,7 +522,8 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -void http_forcelink(); // internal -void websocket_forcelink(); // internal +void http_forcelink(); // internal +void httpparser_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index bce771c75..51ab2e06e 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -23,11 +23,11 @@ namespace zen { class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - /// Construct without an io_context — optionally uses a dedicated push thread + /// Construct without an io_context - optionally uses a dedicated push thread /// for WebSocket stats broadcasting. explicit HttpStatsService(bool EnableWebSockets = false); - /// Construct with an external io_context — uses an asio timer instead + /// Construct with an external io_context - uses an asio timer instead /// of a dedicated thread for WebSocket stats broadcasting. /// The caller must ensure the io_context outlives this service and that /// its run loop is active. @@ -43,7 +43,7 @@ public: virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override; void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 9c3b909a2..fd2f79171 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -26,7 +26,7 @@ namespace zen { * Callback interface for WebSocket client events * * Separate from the server-side IWebSocketHandler because the caller - * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + * already owns the HttpWsClient - no Ref<WebSocketConnection> needed. */ class IWsClientHandler { @@ -85,9 +85,9 @@ private: /// it is treated as a plain host:port and gets the ws:// prefix. /// /// Examples: -/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws" -/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo" -/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar" +/// HttpToWsUrl("http://host:8080", "/orch/ws") -> "ws://host:8080/orch/ws" +/// HttpToWsUrl("https://host", "/foo") -> "wss://host/foo" +/// HttpToWsUrl("host:8080", "/bar") -> "ws://host:8080/bar" std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/localrefpolicy.h b/src/zenhttp/include/zenhttp/localrefpolicy.h new file mode 100644 index 000000000..0b37f9dc7 --- /dev/null +++ b/src/zenhttp/include/zenhttp/localrefpolicy.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> + +namespace zen { + +/// Policy interface for validating local file reference paths in inbound CbPackage messages. +/// Implementations should throw std::invalid_argument if the path is not allowed. +class ILocalRefPolicy +{ +public: + virtual ~ILocalRefPolicy() = default; + + /// Validate that a local file reference path is allowed. + /// Throws std::invalid_argument if the path escapes the allowed root. + virtual void ValidatePath(const std::filesystem::path& Path) const = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index 1a5068580..66e3f6e55 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -5,6 +5,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> +#include <zenhttp/localrefpolicy.h> #include <functional> #include <gsl/gsl-lite.hpp> @@ -97,11 +98,22 @@ gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle = nullptr); -CbPackage ParsePackageMessage( - IoBuffer Payload, - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + +enum class ParseFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), // Allow packages containing local file references (local clients only) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(ParseFlags); + +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; - }); + }, + ParseFlags Flags = ParseFlags::kDefault, + const ILocalRefPolicy* Policy = nullptr); bool IsPackageMessage(IoBuffer Payload); bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); @@ -122,10 +134,11 @@ CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, void* Targe class CbPackageReader { public: - CbPackageReader(); + CbPackageReader(ParseFlags Flags = ParseFlags::kDefault); ~CbPackageReader(); void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + void SetLocalRefPolicy(const ILocalRefPolicy* Policy); /** Process compact binary package data stream @@ -149,6 +162,8 @@ private: kReadingBuffers } m_CurrentState = State::kInitialState; + ParseFlags m_Flags; + const ILocalRefPolicy* m_LocalRefPolicy = nullptr; std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; std::vector<IoBuffer> m_PayloadBuffers; std::vector<CbAttachmentEntry> m_AttachmentEntries; diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 710579faa..2d25515d3 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -59,7 +59,7 @@ class IWebSocketHandler public: virtual ~IWebSocketHandler() = default; - virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) = 0; virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; }; diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index 7e6207e56..5ad5ebcc7 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -196,8 +196,9 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) // void -HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) { + ZEN_UNUSED(RelativeUri); ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); ZEN_INFO("Stats WebSocket client connected"); diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 9c62c1f2d..267ce386c 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -36,6 +36,71 @@ const std::string_view HandlePrefix(":?#:"); typedef eastl::fixed_vector<IoBuffer, 16> IoBufferVec_t; +/// Enforce local-ref path policy. Handle-based refs bypass the policy since they use OS handle security. +/// If no policy is set, file-path local refs are rejected (fail-closed). +static void +ApplyLocalRefPolicy(const ILocalRefPolicy* Policy, const std::filesystem::path& Path) +{ + if (Policy) + { + Policy->ValidatePath(Path); + } + else + { + throw std::invalid_argument("local file reference rejected: no validation policy"); + } +} + +// Validates the CbPackageHeader magic and attachment count. Returns the total +// chunk count (AttachmentCount + 1, including the implicit root object). +static uint32_t +ValidatePackageHeader(const CbPackageHeader& Hdr) +{ + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::invalid_argument( + fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr.HeaderMagic)); + } + // ChunkCount is AttachmentCount + 1 (the root object is implicit). Guard against + // UINT32_MAX wrapping to 0, which would bypass subsequent size checks. + if (Hdr.AttachmentCount == UINT32_MAX) + { + throw std::invalid_argument("invalid CbPackage, attachment count overflow"); + } + return Hdr.AttachmentCount + 1; +} + +struct ValidatedLocalRef +{ + bool Valid = false; + const CbAttachmentReferenceHeader* Header = nullptr; + std::string_view Path; + std::string Error; +}; + +// Validates that the attachment buffer contains a well-formed local reference +// header and path. On failure, Valid is false and Error contains the reason. +static ValidatedLocalRef +ValidateLocalRef(const IoBuffer& AttachmentBuffer) +{ + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader)) + { + return {.Error = fmt::format("local ref attachment too small for header (size {})", AttachmentBuffer.Size())}; + } + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + + if (AttachmentBuffer.Size() < sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength) + { + return {.Error = fmt::format("local ref attachment too small for path (need {}, have {})", + sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength, + AttachmentBuffer.Size())}; + } + + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + return {.Valid = true, .Header = AttachRefHdr, .Path = std::string_view(PathPointer, AttachRefHdr->AbsolutePathLength)}; +} + IoBufferVec_t FormatPackageMessageInternal(const CbPackage& Data, FormatFlags Flags, void* TargetProcessHandle); std::vector<IoBuffer> @@ -361,7 +426,10 @@ IsPackageMessage(IoBuffer Payload) } CbPackage -ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +ParsePackageMessage(IoBuffer Payload, + std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer, + ParseFlags Flags, + const ILocalRefPolicy* Policy) { ZEN_TRACE_CPU("ParsePackageMessage"); @@ -372,17 +440,13 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint BinaryReader Reader(Payload); - const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); - if (Hdr->HeaderMagic != kCbPkgMagic) - { - throw std::invalid_argument( - fmt::format("invalid CbPackage header magic, expected {0:x}, got {1:x}", static_cast<uint32_t>(kCbPkgMagic), Hdr->HeaderMagic)); - } + const CbPackageHeader* Hdr = reinterpret_cast<const CbPackageHeader*>(Reader.GetView(sizeof(CbPackageHeader)).GetData()); + const uint32_t ChunkCount = ValidatePackageHeader(*Hdr); Reader.Skip(sizeof(CbPackageHeader)); - const uint32_t ChunkCount = Hdr->AttachmentCount + 1; - - if (Reader.Remaining() < sizeof(CbAttachmentEntry) * ChunkCount) + // Widen to uint64_t so the multiplication cannot wrap on 32-bit. + const uint64_t AttachmentTableSize = uint64_t(sizeof(CbAttachmentEntry)) * ChunkCount; + if (Reader.Remaining() < AttachmentTableSize) { throw std::invalid_argument(fmt::format("invalid CbPackage, missing attachment entry data (need {} bytes, have {} bytes)", sizeof(CbAttachmentEntry) * ChunkCount, @@ -417,15 +481,22 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) { - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + if (!EnumHasAllFlags(Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", i)); + } - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + MalformedAttachments.push_back(std::make_pair(i, fmt::format("{} for {}", LocalRef.Error, Entry.AttachmentHash))); + continue; + } + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::string_view PathView = LocalRef.Path; IoBuffer FullFileBuffer; @@ -461,13 +532,29 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint } else { + ApplyLocalRefPolicy(Policy, Path); FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; } } if (FullFileBuffer) { - IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + // Guard against offset+size overflow or exceeding the file bounds. + const uint64_t FileSize = FullFileBuffer.GetSize(); + if (AttachRefHdr->PayloadByteOffset > FileSize || + AttachRefHdr->PayloadByteSize > FileSize - AttachRefHdr->PayloadByteOffset) + { + MalformedAttachments.push_back( + std::make_pair(i, + fmt::format("Local ref offset/size out of bounds (offset {}, size {}, file size {}) for {}", + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + FileSize, + Entry.AttachmentHash))); + continue; + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FileSize ? FullFileBuffer : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -630,7 +717,9 @@ ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPa return OutPackage.TryLoad(Response); } -CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +CbPackageReader::CbPackageReader(ParseFlags Flags) +: m_Flags(Flags) +, m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) { } @@ -644,6 +733,12 @@ CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Ci m_CreateBuffer = CreateBuffer; } +void +CbPackageReader::SetLocalRefPolicy(const ILocalRefPolicy* Policy) +{ + m_LocalRefPolicy = Policy; +} + uint64_t CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) { @@ -657,12 +752,14 @@ CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) return sizeof m_PackageHeader; case State::kReadingHeader: - ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); - memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); - ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); - m_CurrentState = State::kReadingAttachmentEntries; - m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); - return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + { + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + const uint32_t ChunkCount = ValidatePackageHeader(m_PackageHeader); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(ChunkCount); + return uint64_t(ChunkCount) * sizeof(CbAttachmentEntry); + } case State::kReadingAttachmentEntries: ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); @@ -691,16 +788,19 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) { // Marshal local reference - a "pointer" to the chunk backing file - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + ValidatedLocalRef LocalRef = ValidateLocalRef(AttachmentBuffer); + if (!LocalRef.Valid) + { + throw std::invalid_argument(LocalRef.Error); + } - std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + const CbAttachmentReferenceHeader* AttachRefHdr = LocalRef.Header; + std::filesystem::path Path(Utf8ToWide(LocalRef.Path)); - std::filesystem::path Path{PathView}; + if (!LocalRef.Path.starts_with(HandlePrefix)) + { + ApplyLocalRefPolicy(m_LocalRefPolicy, Path); + } IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); @@ -714,6 +814,17 @@ CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) AttachRefHdr->PayloadByteSize)); } + // MakeFromFile silently clamps offset+size to the file size. Detect this + // to avoid returning a short buffer that could cause subtle downstream issues. + if (ChunkReference.GetSize() != AttachRefHdr->PayloadByteSize) + { + throw std::invalid_argument(fmt::format("local ref offset/size out of bounds for '{}' (requested offset {}, size {}, got size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize, + ChunkReference.GetSize())); + } + return ChunkReference; }; @@ -732,6 +843,13 @@ CbPackageReader::Finalize() { IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + if ((Entry.Flags & CbAttachmentEntry::kIsLocalRef) && !EnumHasAllFlags(m_Flags, ParseFlags::kAllowLocalReferences)) + { + throw std::invalid_argument( + fmt::format("package contains local reference (attachment #{}) but local references are not allowed", + CurrentAttachmentIndex)); + } + if (CurrentAttachmentIndex == 0) { // Root object @@ -815,6 +933,13 @@ CbPackageReader::Finalize() TEST_SUITE_BEGIN("http.packageformat"); +/// Permissive policy that allows any path, for use in tests that exercise local ref +/// functionality but are not testing path validation. +struct PermissiveLocalRefPolicy : public ILocalRefPolicy +{ + void ValidatePath(const std::filesystem::path&) const override {} +}; + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -922,6 +1047,169 @@ TEST_CASE("CbPackage.LocalRef") RemainingBytes -= ByteCount; }; + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackageReader Reader(ParseFlags::kAllowLocalReferences); + Reader.SetLocalRefPolicy(&AllowAllPolicy); + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.Validation.TruncatedHeader") +{ + // Payload too small for a CbPackageHeader + uint8_t Bytes[] = {0xcc, 0xaa, 0x77, 0xaa}; + IoBuffer Payload(IoBuffer::Wrap, Bytes, sizeof(Bytes)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagic") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xDEADBEEF; + Hdr.AttachmentCount = 0; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflow") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentTable") +{ + // Valid header but not enough data for the attachment entries + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = 10; + IoBuffer Payload(IoBuffer::Wrap, &Hdr, sizeof(Hdr)); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.TruncatedAttachmentData") +{ + // Valid header + one attachment entry claiming more data than available + std::vector<uint8_t> Data(sizeof(CbPackageHeader) + sizeof(CbAttachmentEntry)); + + CbPackageHeader* Hdr = reinterpret_cast<CbPackageHeader*>(Data.data()); + Hdr->HeaderMagic = kCbPkgMagic; + Hdr->AttachmentCount = 0; // ChunkCount = 1 (root object) + + CbAttachmentEntry* Entry = reinterpret_cast<CbAttachmentEntry*>(Data.data() + sizeof(CbPackageHeader)); + Entry->PayloadSize = 9999; // way more than available + Entry->Flags = CbAttachmentEntry::kIsObject; + Entry->AttachmentHash = IoHash(); + + IoBuffer Payload(IoBuffer::Wrap, Data.data(), Data.size()); + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByDefault") +{ + // Build a valid package with local refs backed by compressed-format files, + // then verify it's rejected with default ParseFlags and accepted when allowed. + ScopedTemporaryDirectory TempDir; + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + // Compress data and write to disk, then create file-backed compressed attachments. + // The files must contain compressed-format data because ParsePackageMessage expects it + // when resolving local refs. + CompressedBuffer Comp1 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("abcd")), OodleCompressor::NotSet, OodleCompressionLevel::None); + CompressedBuffer Comp2 = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("efgh")), OodleCompressor::NotSet, OodleCompressionLevel::None); + + IoHash Hash1 = Comp1.DecodeRawHash(); + IoHash Hash2 = Comp2.DecodeRawHash(); + + { + IoBuffer Buf1 = Comp1.GetCompressed().Flatten().AsIoBuffer(); + IoBuffer Buf2 = Comp2.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(Path1, Buf1); + WriteFile(Path2, Buf2); + } + + // Create attachments from file-backed buffers so FormatPackageMessage uses local refs + CbAttachment Attach1{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path1)), Hash1}; + CbAttachment Attach2{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(Path2)), Hash2}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Default flags should reject local refs + CHECK_THROWS_AS(ParsePackageMessage(Payload), std::invalid_argument); + + // With kAllowLocalReferences + a permissive policy, the local-ref gate is passed (the full round-trip + // for local refs through ParsePackageMessage is covered by CbPackage.LocalRef via CbPackageReader) + PermissiveLocalRefPolicy AllowAllPolicy; + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &AllowAllPolicy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 2); +} + +TEST_CASE("CbPackage.Validation.LocalRefRejectedByReader") +{ + // Same test but via CbPackageReader + ScopedTemporaryDirectory TempDir; + auto FilePath = TempDir.Path() / "testdata"; + + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("testdata")); + WriteFile(FilePath, Buf); + } + + IoBuffer FileBuffer = IoBufferBuilder::MakeFromFile(FilePath); + CbAttachment Attach{SharedBuffer(FileBuffer)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + // Default flags should reject CbPackageReader Reader; uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); @@ -933,7 +1221,199 @@ TEST_CASE("CbPackage.LocalRef") CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); } - Reader.Finalize(); + CHECK_THROWS_AS(Reader.Finalize(), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.BadMagicViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = 0xBADCAFE; + Hdr.AttachmentCount = 0; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.Validation.AttachmentCountOverflowViaReader") +{ + CbPackageHeader Hdr{}; + Hdr.HeaderMagic = kCbPkgMagic; + Hdr.AttachmentCount = UINT32_MAX; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + CHECK_THROWS_AS(Reader.ProcessPackageHeaderData(&Hdr, InitialRead), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathOutsideRoot") +{ + // A file outside the allowed root should be rejected by the policy + ScopedTemporaryDirectory AllowedRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "secret.dat"; + { + IoBuffer Buf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("secret")); + WriteFile(OutsidePath, Buf); + } + + // Create file-backed compressed attachment from outside root + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("secret")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // Policy rooted at AllowedRoot should reject the file in OutsideDir + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(AllowedRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathInsideRoot") +{ + // A file inside the allowed root should be accepted by the policy + ScopedTemporaryDirectory TempRoot; + + auto FilePath = TempRoot.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("hello")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CbPackage Result = ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy); + CHECK(Result.GetObject()); + CHECK(Result.GetAttachments().size() == 1); +} + +TEST_CASE("CbPackage.LocalRefPolicy.PathTraversal") +{ + // A file path containing ".." that resolves outside root should be rejected + ScopedTemporaryDirectory TempRoot; + ScopedTemporaryDirectory OutsideDir; + + auto OutsidePath = OutsideDir.Path() / "evil.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("evil")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(OutsidePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(OutsidePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + struct TestPolicy : public ILocalRefPolicy + { + std::string Root; + void ValidatePath(const std::filesystem::path& Path) const override + { + std::string CanonicalFile = std::filesystem::weakly_canonical(Path).string(); + if (CanonicalFile.size() < Root.size() || CanonicalFile.compare(0, Root.size(), Root) != 0) + { + throw std::invalid_argument("path outside root"); + } + } + } Policy; + // Root is TempRoot, but the file lives in OutsideDir + Policy.Root = std::filesystem::weakly_canonical(TempRoot.Path()).string(); + + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, &Policy), std::invalid_argument); +} + +TEST_CASE("CbPackage.LocalRefPolicy.NoPolicyFailClosed") +{ + // When local refs are allowed but no policy is provided, file-path refs should be rejected + ScopedTemporaryDirectory TempDir; + + auto FilePath = TempDir.Path() / "data.dat"; + + CompressedBuffer Comp = + CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView("data")), OodleCompressor::NotSet, OodleCompressionLevel::None); + IoHash Hash = Comp.DecodeRawHash(); + { + IoBuffer Buf = Comp.GetCompressed().Flatten().AsIoBuffer(); + WriteFile(FilePath, Buf); + } + + CbAttachment Attach{CompressedBuffer::FromCompressedNoValidate(IoBufferBuilder::MakeFromFile(FilePath)), Hash}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("data", Attach); + + CbPackage Pkg; + Pkg.AddAttachment(Attach); + Pkg.SetObject(Cbo.Save()); + + IoBuffer Payload = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten().AsIoBuffer(); + + // kAllowLocalReferences but nullptr policy => fail-closed + CHECK_THROWS_AS(ParsePackageMessage(Payload, {}, ParseFlags::kAllowLocalReferences, nullptr), std::invalid_argument); } TEST_SUITE_END(); diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 7972777b8..a1a775ba3 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -625,6 +625,8 @@ public: void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } /** * Initialize the response for sending a payload made up of multiple blobs @@ -768,10 +770,18 @@ public: { ZEN_MEMSCOPE(GetHttpasioTag()); + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -898,7 +908,9 @@ private: bool m_AllowZeroCopyFileSend = true; State m_State = State::kUninitialized; HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -1275,7 +1287,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::buffer(ResponseStr->data(), ResponseStr->size()), asio::bind_executor( m_Strand, - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr, PrefixLen = Service->UriPrefixLength()]( + const asio::error_code& Ec, + std::size_t) { if (Ec) { ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); @@ -1287,7 +1301,9 @@ HttpServerConnectionT<SocketType>::HandleRequest() Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + std::string_view FullUrl = Conn->m_RequestData.Url(); + std::string_view RelativeUri = FullUrl.substr(std::min(PrefixLen, static_cast<int>(FullUrl.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); })); @@ -1295,7 +1311,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() return; } } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } if (!m_RequestData.IsKeepAlive()) @@ -2127,6 +2143,10 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2142,6 +2162,14 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->SetKeepAlive(m_Request.IsKeepAlive()); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 918b55dc6..8b07c7905 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -8,6 +8,13 @@ #include <limits> +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <cstring> +# include <string> +# include <string_view> +#endif + namespace zen { using namespace std::literals; @@ -29,25 +36,25 @@ static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-W // HttpRequestParser // -http_parser_settings HttpRequestParser::s_ParserSettings{ - .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; +// clang-format off +llhttp_settings_t HttpRequestParser::s_ParserSettings = []() { + llhttp_settings_t S; + llhttp_settings_init(&S); + S.on_message_begin = [](llhttp_t* p) { return GetThis(p)->OnMessageBegin(); }; + S.on_url = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }; + S.on_status = [](llhttp_t*, const char*, size_t) { return 0; }; + S.on_header_field = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }; + S.on_header_value = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }; + S.on_headers_complete = [](llhttp_t* p) { return GetThis(p)->OnHeadersComplete(); }; + S.on_body = [](llhttp_t* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }; + S.on_message_complete = [](llhttp_t* p) { return GetThis(p)->OnMessageComplete(); }; + return S; +}(); +// clang-format on HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) { - http_parser_init(&m_Parser, HTTP_REQUEST); + llhttp_init(&m_Parser, HTTP_REQUEST, &s_ParserSettings); m_Parser.data = this; ResetState(); @@ -60,16 +67,17 @@ HttpRequestParser::~HttpRequestParser() size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + llhttp_errno_t Err = llhttp_execute(&m_Parser, InputData, DataSize); + if (Err == HPE_OK) { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; + return DataSize; } - return ConsumedBytes; + if (Err == HPE_PAUSED_UPGRADE) + { + return DataSize; + } + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", llhttp_errno_name(Err), llhttp_get_error_reason(&m_Parser)); + return ~0ull; } int @@ -79,7 +87,7 @@ HttpRequestParser::OnUrl(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_UrlRange.Length == 0) @@ -101,7 +109,7 @@ HttpRequestParser::OnHeader(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } if (m_HeaderEntries.empty()) @@ -212,7 +220,7 @@ HttpRequestParser::OnHeaderValue(const char* Data, size_t Bytes) if (RemainingBufferSpace < Bytes) { ZEN_WARN("HTTP parser does not have enough space for incoming request headers, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; + return -1; } ZEN_ASSERT_SLOW(!m_HeaderEntries.empty()); @@ -269,9 +277,9 @@ HttpRequestParser::OnHeadersComplete() } } - m_KeepAlive = !!http_should_keep_alive(&m_Parser); + m_KeepAlive = !!llhttp_should_keep_alive(&m_Parser); - switch (m_Parser.method) + switch (llhttp_get_method(&m_Parser)) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; @@ -302,7 +310,7 @@ HttpRequestParser::OnHeadersComplete() break; default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + ZEN_WARN("invalid HTTP method: '{}'", llhttp_method_name(static_cast<llhttp_method_t>(llhttp_get_method(&m_Parser)))); break; } @@ -349,20 +357,11 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; + return -1; } memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - return 0; } @@ -409,7 +408,7 @@ HttpRequestParser::OnMessageComplete() catch (const AssertException& AssertEx) { ZEN_WARN("Assert caught when processing http request: {}", AssertEx.FullDescription()); - return 1; + return -1; } catch (const std::system_error& SystemError) { @@ -426,19 +425,19 @@ HttpRequestParser::OnMessageComplete() ZEN_ERROR("failed processing http request: '{}' ({})", SystemError.what(), SystemError.code().value()); } ResetState(); - return 1; + return -1; } catch (const std::bad_alloc& BadAlloc) { ZEN_WARN("out of memory when processing http request: '{}'", BadAlloc.what()); ResetState(); - return 1; + return -1; } catch (const std::exception& Ex) { ZEN_ERROR("failed processing http request: '{}'", Ex.what()); ResetState(); - return 1; + return -1; } } @@ -459,4 +458,331 @@ HttpRequestParser::IsWebSocketUpgrade() const return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; } +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace { + + struct MockCallbacks : HttpRequestParserCallbacks + { + int HandleRequestCount = 0; + int TerminateCount = 0; + + HttpRequestParser* Parser = nullptr; + + HttpVerb LastVerb{}; + std::string LastUrl; + std::string LastQueryString; + std::string LastBody; + bool LastKeepAlive = false; + bool LastIsWebSocketUpgrade = false; + std::string LastSecWebSocketKey; + std::string LastUpgradeHeader; + HttpContentType LastContentType{}; + + void HandleRequest() override + { + ++HandleRequestCount; + if (Parser) + { + LastVerb = Parser->RequestVerb(); + LastUrl = std::string(Parser->Url()); + LastQueryString = std::string(Parser->QueryString()); + LastKeepAlive = Parser->IsKeepAlive(); + LastIsWebSocketUpgrade = Parser->IsWebSocketUpgrade(); + LastSecWebSocketKey = std::string(Parser->SecWebSocketKey()); + LastUpgradeHeader = std::string(Parser->UpgradeHeader()); + LastContentType = Parser->ContentType(); + + IoBuffer Body = Parser->Body(); + if (Body.Size() > 0) + { + LastBody.assign(reinterpret_cast<const char*>(Body.Data()), Body.Size()); + } + else + { + LastBody.clear(); + } + } + } + + void TerminateConnection() override { ++TerminateCount; } + }; + +} // anonymous namespace + +TEST_SUITE_BEGIN("http.httpparser"); + +TEST_CASE("httpparser.basic_get") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kGet); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK(Mock.LastKeepAlive); +} + +TEST_CASE("httpparser.post_with_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 13\r\n" + "Content-Type: application/json\r\n" + "\r\n" + "{\"key\":\"val\"}"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, HttpVerb::kPost); + CHECK_EQ(Mock.LastBody, "{\"key\":\"val\"}"); + CHECK_EQ(Mock.LastContentType, HttpContentType::kJSON); +} + +TEST_CASE("httpparser.pipelined_requests") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /first HTTP/1.1\r\nHost: localhost\r\n\r\n" + "GET /second HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 2); + CHECK_EQ(Mock.LastUrl, "/second"); +} + +TEST_CASE("httpparser.partial_header") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Chunk1 = "GET /path HTTP/1.1\r\nHost: loc"; + std::string Chunk2 = "alhost\r\n\r\n"; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(Chunk2.data(), Chunk2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, Chunk2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); +} + +TEST_CASE("httpparser.partial_body") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Headers = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 10\r\n" + "\r\n"; + std::string BodyPart1 = "hello"; + std::string BodyPart2 = "world"; + + std::string Chunk1 = Headers + BodyPart1; + + size_t Consumed1 = Parser.ConsumeData(Chunk1.data(), Chunk1.size()); + CHECK_NE(Consumed1, ~0ull); + CHECK_EQ(Consumed1, Chunk1.size()); + CHECK_EQ(Mock.HandleRequestCount, 0); + + size_t Consumed2 = Parser.ConsumeData(BodyPart2.data(), BodyPart2.size()); + CHECK_NE(Consumed2, ~0ull); + CHECK_EQ(Consumed2, BodyPart2.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "helloworld"); +} + +TEST_CASE("httpparser.invalid_request") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Garbage = "NOT_HTTP garbage data\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Garbage.data(), Garbage.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 0); +} + +TEST_CASE("httpparser.body_overflow") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + // llhttp enforces Content-Length strictly: it delivers exactly 2 body bytes, + // fires on_message_complete, then tries to parse the remaining "O_LONG_BODY" + // as a new HTTP request which fails. + std::string Request = + "POST /api HTTP/1.1\r\n" + "Host: localhost\r\n" + "Content-Length: 2\r\n" + "\r\n" + "TOO_LONG_BODY"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastBody, "TO"); +} + +TEST_CASE("httpparser.websocket_upgrade") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); + CHECK_EQ(Mock.LastSecWebSocketKey, "dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(Mock.LastUpgradeHeader, "websocket"); +} + +TEST_CASE("httpparser.websocket_upgrade_with_trailing_bytes") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string HttpPart = + "GET /ws HTTP/1.1\r\n" + "Host: localhost\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + // Append fake WebSocket frame bytes after the HTTP message + std::string Request = HttpPart; + Request.push_back('\x81'); + Request.push_back('\x05'); + Request.append("hello"); + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_NE(Consumed, ~0ull); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK(Mock.LastIsWebSocketUpgrade); +} + +TEST_CASE("httpparser.keep_alive_detection") +{ + SUBCASE("HTTP/1.1 default keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK(Mock.LastKeepAlive); + } + + SUBCASE("Connection: close disables keep-alive") + { + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"; + Parser.ConsumeData(Request.data(), Request.size()); + CHECK_FALSE(Mock.LastKeepAlive); + } +} + +TEST_CASE("httpparser.all_verbs") +{ + struct VerbTest + { + const char* Method; + HttpVerb Expected; + }; + + VerbTest Tests[] = { + {"GET", HttpVerb::kGet}, + {"POST", HttpVerb::kPost}, + {"PUT", HttpVerb::kPut}, + {"DELETE", HttpVerb::kDelete}, + {"HEAD", HttpVerb::kHead}, + {"COPY", HttpVerb::kCopy}, + {"OPTIONS", HttpVerb::kOptions}, + }; + + for (const VerbTest& Test : Tests) + { + CAPTURE(Test.Method); + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = std::string(Test.Method) + " /path HTTP/1.1\r\nHost: localhost\r\n\r\n"; + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastVerb, Test.Expected); + } +} + +TEST_CASE("httpparser.query_string") +{ + MockCallbacks Mock; + HttpRequestParser Parser(Mock); + Mock.Parser = &Parser; + + std::string Request = "GET /path?key=val&other=123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + + size_t Consumed = Parser.ConsumeData(Request.data(), Request.size()); + CHECK_EQ(Consumed, Request.size()); + CHECK_EQ(Mock.HandleRequestCount, 1); + CHECK_EQ(Mock.LastUrl, "/path"); + CHECK_EQ(Mock.LastQueryString, "key=val&other=123"); +} + +TEST_SUITE_END(); + +void +httpparser_forcelink() +{ +} + +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 23ad9d8fb..4ff216248 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -8,7 +8,7 @@ #include <EASTL/fixed_vector.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <http_parser.h> +#include <llhttp.h> ZEN_THIRD_PARTY_INCLUDES_END #include <atomic> @@ -100,7 +100,7 @@ private: Oid m_SessionId{}; IoBuffer m_BodyBuffer; uint64_t m_BodyPosition = 0; - http_parser m_Parser; + llhttp_t m_Parser; eastl::fixed_vector<char, 512> m_HeaderData; std::string m_NormalizedUrl; @@ -114,8 +114,8 @@ private: int OnBody(const char* Data, size_t Bytes); int OnMessageComplete(); - static HttpRequestParser* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } - static http_parser_settings s_ParserSettings; + static HttpRequestParser* GetThis(llhttp_t* Parser) { return reinterpret_cast<HttpRequestParser*>(Parser->data); } + static llhttp_settings_t s_ParserSettings; }; } // namespace zen diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 31b0315d4..b0fb020e0 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -185,13 +185,17 @@ public: const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } void SuppressPayload() { m_ResponseBuffers.resize(1); } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } std::string_view GetHeaders(); private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; uint64_t m_ContentLength = 0; std::vector<IoBuffer> m_ResponseBuffers; ExtendableStringBuilder<160> m_Headers; @@ -246,10 +250,18 @@ HttpPluginResponse::GetHeaders() if (m_Headers.Size() == 0) { + std::string_view ContentTypeStr = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Type: " << ContentTypeStr << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; + if (!m_ContentRangeHeader.empty()) + { + m_Headers << "Content-Range: " << m_ContentRangeHeader << "\r\n"sv; + } + if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; @@ -669,6 +681,10 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -681,6 +697,14 @@ HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpConten ZEN_MEMSCOPE(GetHttppluginTag()); m_Response.reset(new HttpPluginResponse(ContentType)); + if (!m_ContentTypeOverride.empty()) + { + m_Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + m_Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 2cad97725..67b1230a0 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -464,6 +464,8 @@ public: inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + void SetContentTypeOverride(std::string Override) { m_ContentTypeOverride = std::move(Override); } + void SetContentRangeHeader(std::string V) { m_ContentRangeHeader = std::move(V); } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -473,6 +475,8 @@ private: uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; + std::string m_ContentTypeOverride; + std::string m_ContentRangeHeader; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; std::string m_LocationHeader; @@ -725,7 +729,8 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + std::string_view ContentTypeString = + m_ContentTypeOverride.empty() ? MapContentTypeToString(m_ContentType) : std::string_view(m_ContentTypeOverride); ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); @@ -739,6 +744,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); } + // Content-Range header (for 206 Partial Content single-range responses) + + if (!m_ContentRangeHeader.empty()) + { + PHTTP_KNOWN_HEADER ContentRangeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentRange]; + ContentRangeHeader->pRawValue = m_ContentRangeHeader.data(); + ContentRangeHeader->RawValueLength = (USHORT)m_ContentRangeHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -1258,7 +1272,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) { // Port may be owned by another process's wildcard registration (access denied) - // or actively in use (sharing violation) — retry on a different port + // or actively in use (sharing violation) - retry on a different port ShouldRetryNextPort = true; } else @@ -2279,6 +2293,11 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2307,6 +2326,15 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + if (!m_ContentTypeOverride.empty()) + { + Response->SetContentTypeOverride(std::move(m_ContentTypeOverride)); + } + if (!m_ContentRangeHeader.empty()) + { + Response->SetContentRangeHeader(std::move(m_ContentRangeHeader)); + } + if (SuppressBody()) { Response->SuppressResponseBody(); @@ -2595,7 +2623,14 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT &Transaction().Server())); Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + ExtendableStringBuilder<128> UrlUtf8; + WideToUtf8({(wchar_t*)HttpReq->CookedUrl.pAbsPath, + gsl::narrow<size_t>(HttpReq->CookedUrl.AbsPathLength / sizeof(wchar_t))}, + UrlUtf8); + int PrefixLen = Service->UriPrefixLength(); + std::string_view RelativeUri{UrlUtf8.ToView()}; + RelativeUri.remove_prefix(std::min(PrefixLen, static_cast<int>(RelativeUri.size()))); + WsHandler->OnWebSocketOpen(std::move(WsConnRef), RelativeUri); WsConn->Start(); return nullptr; @@ -2603,11 +2638,11 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - // WebSocket upgrade failed — return nullptr since ServerRequest() + // WebSocket upgrade failed - return nullptr since ServerRequest() // was never populated (no InvokeRequestHandler call) return nullptr; } - // Service doesn't support WebSocket or missing key — fall through to normal handling + // Service doesn't support WebSocket or missing key - fall through to normal handling } } diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 5ae48f5b3..078c21ea1 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -141,7 +141,7 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index af320172d..8520e9f60 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -70,7 +70,7 @@ WsHttpSysConnection::Shutdown() return; } - // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + // Cancel pending I/O - completions will fire with ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } @@ -211,7 +211,7 @@ WsHttpSysConnection::ProcessReceivedData() } case WebSocketOpcode::kPong: - // Unsolicited pong — ignore per RFC 6455 + // Unsolicited pong - ignore per RFC 6455 break; case WebSocketOpcode::kClose: @@ -446,7 +446,7 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); - // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + // Cancel pending read I/O - completions drain via ERROR_OPERATION_ABORTED HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); } diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 59c46a418..a58037fec 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -5,6 +5,7 @@ # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zencore/timer.h> # include <zenhttp/httpserver.h> # include <zenhttp/httpwsclient.h> @@ -59,7 +60,7 @@ TEST_CASE("websocket.framecodec") std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); - // Server frames are unmasked — TryParseFrame should handle them + // Server frames are unmasked - TryParseFrame should handle them WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); CHECK(Result.IsValid); @@ -129,7 +130,7 @@ TEST_CASE("websocket.framecodec") { std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); - // Pass only 1 byte — not enough for a frame header + // Pass only 1 byte - not enough for a frame header WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); CHECK_FALSE(Result.IsValid); CHECK_EQ(Result.BytesConsumed, 0u); @@ -335,8 +336,9 @@ namespace { } // IWebSocketHandler - void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + void OnWebSocketOpen(Ref<WebSocketConnection> Connection, std::string_view RelativeUri) override { + ZEN_UNUSED(RelativeUri); m_OpenCount.fetch_add(1); m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); @@ -463,7 +465,7 @@ namespace { if (!Done.load()) { - // Timeout — cancel the read + // Timeout - cancel the read asio::error_code Ec; Sock.cancel(Ec); } @@ -476,6 +478,23 @@ namespace { return Result; } + static void WaitForServerListening(int Port) + { + Stopwatch Timer; + while (Timer.GetElapsedTimeMs() < 5'000) + { + asio::io_context IoCtx; + asio::ip::tcp::socket Probe(IoCtx); + asio::error_code Ec; + Probe.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)), Ec); + if (!Ec) + { + return; + } + Sleep(10); + } + } + } // anonymous namespace TEST_CASE("websocket.integration") @@ -501,8 +520,8 @@ TEST_CASE("websocket.integration") Server->Close(); }); - // Give server a moment to start accepting - Sleep(100); + // Wait for server to start accepting + WaitForServerListening(Port); SUBCASE("handshake succeeds with 101") { @@ -692,7 +711,7 @@ TEST_CASE("websocket.integration") std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); - // Should NOT get 101 — should fall through to normal request handling + // Should NOT get 101 - should fall through to normal request handling CHECK(Response.find("101") == std::string::npos); Sock.close(); @@ -813,7 +832,7 @@ TEST_CASE("websocket.client") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close") { @@ -937,7 +956,7 @@ TEST_CASE("websocket.client.unixsocket") Server->Close(); }); - Sleep(100); + WaitForServerListening(Port); SUBCASE("connect, echo, close over unix socket") { diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 7b050ae35..67a01403d 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -9,7 +9,7 @@ target('zenhttp') add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio") - add_packages("http_parser", "json11", "libcurl") + add_packages("llhttp", "json11", "libcurl") add_options("httpsys") if is_plat("linux", "macosx") then diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 3ac8eea8d..e15aa4d30 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -4,6 +4,7 @@ #if ZEN_WITH_TESTS +# include <zenhttp/asynchttpclient.h> # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> # include <zenhttp/packageformat.h> @@ -16,7 +17,9 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpparser_forcelink(); httpclient_test_forcelink(); + asynchttpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); websocket_forcelink(); |