From c37421a3b4493c0b0f9afef15a4ea7b74d152067 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Thu, 12 Feb 2026 10:58:41 +0100 Subject: add simple http client tests (#751) * add simple http client tests and fix run loop of http server to not rely on application quit --- src/zenhttp/httpclient.cpp | 164 +++++++++++++++++++++++++++- src/zenhttp/include/zenhttp/httpserver.h | 4 +- src/zenhttp/servers/httpasio.cpp | 17 +-- src/zenhttp/servers/httpasio.h | 2 + src/zenhttp/servers/httpmulti.cpp | 11 +- src/zenhttp/servers/httpnull.cpp | 11 +- src/zenhttp/servers/httpsys.cpp | 18 +-- src/zenhttp/transports/winsocktransport.cpp | 2 +- 8 files changed, 202 insertions(+), 27 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 43e9fb468..0544bf5c8 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -22,8 +22,13 @@ #include "clients/httpclientcommon.h" #if ZEN_WITH_TESTS +# include # include # include +# include "servers/httpasio.h" +# include "servers/httpsys.h" + +# include #endif // ZEN_WITH_TESTS namespace zen { @@ -388,7 +393,164 @@ TEST_CASE("httpclient") { using namespace std::literals; - SUBCASE("client") {} + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey man"); + } + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK); + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + SUBCASE("asio") + { + Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread SeverThread([&]() { 
AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (SeverThread.joinable()) + { + SeverThread.join(); + } + AsioServer->Close(); + }); + + { + HttpClient Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } +# endif // 0 + AsioServer->RequestExit(); + } + } + +# if ZEN_PLATFORM_WINDOWS + SUBCASE("httpsys") + { + Ref HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = true}); + + int Port = HttpSysServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + HttpSysServer->RegisterService(TestService); + + std::thread SeverThread([&]() { HttpSysServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (SeverThread.joinable()) + { + SeverThread.join(); + } + HttpSysServer->Close(); + }); + + if (true) + { + HttpClient Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + 
HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test"); + CHECK(TestResponse.IsSuccess()); + } +# endif // 0 + HttpSysServer->RequestExit(); + } + } +# endif // ZEN_PLATFORM_WINDOWS } void diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 3438a1471..6660bebf9 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -30,8 +30,10 @@ class HttpService; */ class HttpServerRequest { -public: +protected: explicit HttpServerRequest(HttpService& Service); + +public: ~HttpServerRequest(); // Synchronous operations diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 18a0f6a40..76fea65b3 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -97,7 +97,11 @@ IsIPv6AvailableSysctl(void) return val == 0; } +#endif // ZEN_PLATFORM_LINUX +namespace zen { + +#if ZEN_PLATFORM_LINUX bool IsIPv6Capable() { @@ -121,8 +125,6 @@ IsIPv6Capable() } #endif -namespace zen { - const FLLMTag& GetHttpasioTag() { @@ -1992,7 +1994,8 @@ HttpAsioServer::OnInitialize(int BasePort, 
std::filesystem::path DataDir) void HttpAsioServer::OnRun(bool IsInteractive) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractive) @@ -2012,8 +2015,8 @@ HttpAsioServer::OnRun(bool IsInteractive) } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractive) { @@ -2022,8 +2025,8 @@ HttpAsioServer::OnRun(bool IsInteractive) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index c483dfc28..3ec1141a7 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -15,4 +15,6 @@ struct AsioConfig Ref CreateHttpAsioServer(const AsioConfig& Config); +bool IsIPv6Capable(); + } // namespace zen diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 31cb04be5..95624245f 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -56,7 +56,8 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) void HttpMultiServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -76,8 +77,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -86,8 +87,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = 
m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index 0ec1cb3c4..b770b97db 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -34,7 +34,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) void HttpNullServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -54,8 +55,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession) } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -64,8 +65,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 54cc0c22d..0d2bb8fbd 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -1128,11 +1128,14 @@ HttpSysServer::InitializeServer(int BasePort) // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user= - ZEN_WARN( - "Unable to register handler using '{}' - falling back to local-only. " - "Please ensure the appropriate netsh URL reservation configuration " - "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", - WideToUtf8(WildcardUrlPath)); + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register handler using '{}' - falling back to local-only. 
" + "Please ensure the appropriate netsh URL reservation configuration " + "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", + WideToUtf8(WildcardUrlPath)); + } const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; @@ -1337,6 +1340,7 @@ HttpSysServer::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { // int WaitTimeout = -1; @@ -1357,9 +1361,9 @@ HttpSysServer::OnRun(bool IsInteractive) } } - m_ShutdownEvent.Wait(WaitTimeout); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); + } while (!ShutdownRequested); } void diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp index c06a50c95..0217ed44e 100644 --- a/src/zenhttp/transports/winsocktransport.cpp +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -322,7 +322,7 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface) else { } - } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + } while (m_KeepRunning.test()); ZEN_INFO("HTTP plugin server accept thread exit"); }); -- cgit v1.2.3 From 3a563f5e8fcabffe686e1deb5862bdf39078ebdf Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Thu, 12 Feb 2026 15:25:05 +0100 Subject: add IsLocalMachineRequest to HttpServerRequest (#749) * add IsLocalMachineRequest to HttpServerRequest --- src/zenhttp/httpclient.cpp | 23 ++++++++++++++++++++--- src/zenhttp/httpserver.cpp | 5 ++++- src/zenhttp/include/zenhttp/httpserver.h | 2 ++ src/zenhttp/servers/httpasio.cpp | 23 ++++++++++++++++++++--- src/zenhttp/servers/httpplugin.cpp | 3 +++ src/zenhttp/servers/httpsys.cpp | 29 +++++++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 0544bf5c8..c77be8624 100644 
--- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -402,7 +402,14 @@ TEST_CASE("httpclient") { if (HttpServiceRequest.RelativeUri() == "yo") { - return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey man"); + if (HttpServiceRequest.IsLocalMachineRequest()) + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + else + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger"); + } } return HttpServiceRequest.WriteResponse(HttpResponseCode::OK); } @@ -440,6 +447,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } if (IsIPv6Capable()) @@ -452,6 +460,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } { @@ -463,6 +472,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } # if 0 { @@ -474,7 +484,9 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } + Sleep(20000); # endif // 0 AsioServer->RequestExit(); } @@ -483,7 +495,7 @@ TEST_CASE("httpclient") # if ZEN_PLATFORM_WINDOWS SUBCASE("httpsys") { - Ref HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = true}); + Ref HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false}); int Port = HttpSysServer->Initialize(7575, TmpDir.Path()); REQUIRE(Port != -1); @@ -511,6 +523,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } if (IsIPv6Capable()) @@ 
-523,6 +536,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } { @@ -534,6 +548,7 @@ TEST_CASE("httpclient") HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } # if 0 { @@ -543,9 +558,11 @@ TEST_CASE("httpclient") ZEN_INFO("Request using {}", Client.GetBaseUri()); - HttpClient::Response TestResponse = Client.Get("/test"); + HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); } + Sleep(20000); # endif // 0 HttpSysServer->RequestExit(); } diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index c4e67d4ed..8985120b0 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -1310,7 +1310,10 @@ TEST_CASE("http.common") { TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override + + virtual bool IsLocalMachineRequest() const override { return false; } + + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override { ZEN_UNUSED(ResponseCode, ContentType, Blobs); } diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 6660bebf9..f0a667686 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -89,6 +89,8 @@ public: CbObject ReadPayloadObject(); CbPackage ReadPayloadPackage(); + virtual bool IsLocalMachineRequest() const = 0; + /** Respond with payload No data will have been sent when any of these functions return. 
Instead, the response will be transmitted diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 76fea65b3..8bfbd8b37 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -529,12 +529,18 @@ public: class HttpAsioServerRequest : public HttpServerRequest { public: - HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber); + HttpAsioServerRequest(HttpRequestParser& Request, + HttpService& Service, + IoBuffer PayloadBuffer, + uint32_t RequestNumber, + bool IsLocalMachineRequest); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; @@ -550,6 +556,7 @@ public: HttpRequestParser& m_Request; uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; + bool m_IsLocalMachineRequest; std::unique_ptr m_Response; }; @@ -1168,7 +1175,9 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber); + bool IsLocalConnection = m_Socket->local_endpoint().address() == m_Socket->remote_endpoint().address(); + + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection); ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); @@ -1634,11 +1643,13 @@ private: HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, - uint32_t RequestNumber) + uint32_t RequestNumber, + bool 
IsLocalMachineRequest) : HttpServerRequest(Service) , m_Request(Request) , m_RequestNumber(RequestNumber) , m_PayloadBuffer(std::move(PayloadBuffer)) +, m_IsLocalMachineRequest(IsLocalMachineRequest) { const int PrefixLength = Service.UriPrefixLength(); @@ -1710,6 +1721,12 @@ HttpAsioServerRequest::ParseRequestId() const return m_Request.RequestId(); } +bool +HttpAsioServerRequest::IsLocalMachineRequest() const +{ + return m_IsLocalMachineRequest; +} + IoBuffer HttpAsioServerRequest::ReadPayload() { diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index b9217ed87..1a630c16f 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -143,6 +143,9 @@ public: HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; + // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection + virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 0d2bb8fbd..01c4559a1 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -313,6 +313,8 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; @@ -1814,6 +1816,33 @@ HttpSysServerRequest::ParseRequestId() const return 0; } +bool +HttpSysServerRequest::IsLocalMachineRequest() const +{ + const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress; + const PSOCKADDR 
RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + if (LocalAddress->sa_family != RemoteAddress->sa_family) + { + return false; + } + if (LocalAddress->sa_family == AF_INET) + { + const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress); + const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress); + return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr; + } + else if (LocalAddress->sa_family == AF_INET6) + { + const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress); + const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress); + return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0; + } + else + { + return false; + } +} + IoBuffer HttpSysServerRequest::ReadPayload() { -- cgit v1.2.3 From 3b5b777900d9f59ff32eb7cea79e3a72a08c67a6 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 13 Feb 2026 13:27:08 +0100 Subject: add IHttpRequestFilter to allow server implementation to filter/reject requests (#753) * add IHttpRequestFilter to allow server implementation to filter/reject requests --- src/zenhttp/httpclient.cpp | 104 ++++++++++++++++++++++-- src/zenhttp/httpserver.cpp | 12 ++- src/zenhttp/include/zenhttp/httpserver.h | 22 +++++- src/zenhttp/servers/httpasio.cpp | 131 +++++++++++++++++++++---------- src/zenhttp/servers/httpmulti.cpp | 9 +++ src/zenhttp/servers/httpmulti.h | 1 + src/zenhttp/servers/httpnull.cpp | 6 ++ src/zenhttp/servers/httpnull.h | 1 + src/zenhttp/servers/httpplugin.cpp | 117 ++++++++++++++++++--------- src/zenhttp/servers/httpsys.cpp | 46 ++++++++++- 10 files changed, 353 insertions(+), 96 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index c77be8624..16729ce38 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -427,13 +427,13 @@ TEST_CASE("httpclient") AsioServer->RegisterService(TestService); - std::thread 
SeverThread([&]() { AsioServer->Run(false); }); + std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } AsioServer->Close(); }); @@ -502,13 +502,13 @@ TEST_CASE("httpclient") HttpSysServer->RegisterService(TestService); - std::thread SeverThread([&]() { HttpSysServer->Run(false); }); + std::thread ServerThread([&]() { HttpSysServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } HttpSysServer->Close(); }); @@ -570,6 +570,98 @@ TEST_CASE("httpclient") # endif // ZEN_PLATFORM_WINDOWS } +TEST_CASE("httpclient.requestfilter") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + class MyFilterImpl : public IHttpRequestFilter + { + public: + virtual Result FilterRequest(HttpServerRequest& Request) + { + if (Request.RelativeUri() == "should_filter") + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); + return Result::ResponseSent; + } + else if (Request.RelativeUri() == "should_forbid") + { + return Result::Forbidden; + 
} + return Result::Accepted; + } + }; + + MyFilterImpl MyFilter; + + Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response YoResponse = Client.Get("/test/yo"); + CHECK(YoResponse.IsSuccess()); + CHECK_EQ(YoResponse.AsText(), "hey family"); + + HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); + CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); + CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); + + HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); + CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); + + AsioServer->RequestExit(); + } +} + void httpclient_forcelink() { diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 8985120b0..d8367fcb2 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -463,7 +463,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) ////////////////////////////////////////////////////////////////////////// -HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri()) +HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) { } @@ -970,7 +970,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + 
RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -994,7 +994,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -1052,6 +1052,12 @@ HttpServer::EnumerateServices(std::function&& Callba } } +void +HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + OnSetHttpRequestFilter(RequestFilter); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index f0a667686..60f6bc9f2 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -41,7 +41,7 @@ public: [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix + [[nodiscard]] inline HttpService& Service() const { return m_Service; } struct QueryParams { @@ -121,13 +121,14 @@ protected: kHaveSessionId = 1 << 3, }; - mutable uint32_t m_Flags = 0; + mutable uint32_t m_Flags = 0; + + HttpService& m_Service; // Service handling this request HttpVerb m_Verb = HttpVerb::kGet; HttpContentType m_ContentType = HttpContentType::kBinary; HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; uint64_t m_ContentLength = ~0ull; - std::string_view m_BaseUri; // Base URI path of 
the service handling this request - std::string_view m_Uri; // URI without service prefix + std::string_view m_Uri; // URI without service prefix std::string_view m_UriWithExtension; std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); @@ -148,6 +149,17 @@ public: virtual void OnRequestComplete() = 0; }; +struct IHttpRequestFilter +{ + enum class Result + { + Forbidden, + ResponseSent, + Accepted + }; + virtual Result FilterRequest(HttpServerRequest& Request) = 0; +}; + /** * Base class for implementing an HTTP "service" * @@ -184,6 +196,7 @@ class HttpServer : public RefCounted public: void RegisterService(HttpService& Service); void EnumerateServices(std::function&& Callback); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); int Initialize(int BasePort, std::filesystem::path DataDir); void Run(bool IsInteractiveSession); @@ -195,6 +208,7 @@ private: virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0; virtual void OnRun(bool IsInteractiveSession) = 0; virtual void OnRequestExit() = 0; virtual void OnClose() = 0; diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 8bfbd8b37..230aac6a8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -498,16 +498,19 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); - void Initialize(std::filesystem::path DataDir); - int Start(uint16_t Port, const AsioConfig& Config); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); + void Initialize(std::filesystem::path DataDir); + int Start(uint16_t Port, const AsioConfig& Config); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); + HttpService* 
RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr m_Acceptor; std::vector m_ThreadPool; + std::atomic m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -1199,53 +1202,65 @@ HttpServerConnection::HandleRequest() std::vector{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. 
({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. 
({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr Response = std::move(Request.m_Response)) { @@ -1923,6 +1938,31 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +void +HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + } // namespace zen::asio_http 
////////////////////////////////////////////////////////////////////////// @@ -1937,6 +1977,7 @@ public: virtual void OnRegisterService(HttpService& Service) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; virtual void OnClose() override; @@ -1984,6 +2025,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service) m_Impl->RegisterService(Service.BaseUri(), Service); } +void +HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + m_Impl->SetHttpRequestFilter(RequestFilter); +} + int HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 95624245f..850d7d6b9 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -53,6 +53,15 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) return EffectivePort; } +void +HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + for (auto& Server : m_Servers) + { + Server->SetHttpRequestFilter(RequestFilter); + } +} + void HttpMultiServer::OnRun(bool IsInteractiveSession) { diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index ae0ed74cf..1897587a9 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -16,6 +16,7 @@ public: ~HttpMultiServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp 
index b770b97db..db360c5fb 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service) ZEN_UNUSED(Service); } +void +HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_UNUSED(RequestFilter); +} + int HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h index ce7230938..52838f012 100644 --- a/src/zenhttp/servers/httpnull.h +++ b/src/zenhttp/servers/httpnull.h @@ -18,6 +18,7 @@ public: ~HttpNullServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 1a630c16f..4219dc292 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer // HttpPluginServer virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; @@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer virtual void AddPlugin(Ref Plugin) override; virtual void RemovePlugin(Ref Plugin) override; - HttpService* RouteRequest(std::string_view Url); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); struct ServiceEntry { @@ -112,7 +114,8 @@ struct 
HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* Service; }; - bool m_IsInitialized = false; + std::atomic m_HttpRequestFilter = nullptr; + bool m_IsInitialized = false; RwLock m_Lock; std::vector m_UriHandlers; std::vector> m_Plugins; @@ -395,53 +398,65 @@ HttpPluginConnectionHandler::HandleRequest() std::vector{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. 
({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + 
Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr Response = std::move(Request.m_Response)) { @@ -752,6 +767,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir return m_BasePort; } +void +HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + void HttpPluginServerImpl::OnClose() { @@ -897,6 +919,23 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +IHttpRequestFilter::Result +HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 01c4559a1..4df4cd079 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -94,6 +94,7 @@ public: virtual void OnRun(bool TestMode) override; virtual void OnRequestExit() override; virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnClose() override; WorkerThreadPool& WorkPool(); @@ -101,6 +102,8 @@ public: inline bool IsOk() const { return m_IsOk; } inline 
bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -137,6 +140,9 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic m_HttpRequestFilter = nullptr; }; } // namespace zen @@ -1672,9 +1678,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) { - Service.HandleRequest(ThisRequest); + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -2244,6 +2262,30 @@ HttpSysServer::OnRegisterService(HttpService& Service) RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + Ref 
CreateHttpSysServer(HttpSysConfig Config) { -- cgit v1.2.3 From b0a3de5fec8f4da8f9513b02bc2326aa6a0e7bd5 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 13 Feb 2026 13:47:51 +0100 Subject: logging config move to zenutil (#754) made logging config options from zenserver available in zen CLI --- src/zen/zen.cpp | 23 +- src/zen/zen.h | 5 +- src/zenserver-test/zenserver-test.cpp | 2 +- src/zenserver/config/config.cpp | 73 ++----- src/zenserver/config/config.h | 31 ++- src/zenserver/config/luaconfig.h | 2 +- src/zenserver/diag/logging.cpp | 12 +- src/zenserver/main.cpp | 2 +- src/zenserver/storage/zenstorageserver.cpp | 2 +- src/zenserver/zenserver.cpp | 12 +- src/zenutil/commandlineoptions.cpp | 239 --------------------- src/zenutil/config/commandlineoptions.cpp | 239 +++++++++++++++++++++ src/zenutil/config/environmentoptions.cpp | 84 ++++++++ src/zenutil/config/loggingconfig.cpp | 77 +++++++ src/zenutil/environmentoptions.cpp | 84 -------- src/zenutil/include/zenutil/commandlineoptions.h | 40 ---- .../include/zenutil/config/commandlineoptions.h | 40 ++++ .../include/zenutil/config/environmentoptions.h | 92 ++++++++ src/zenutil/include/zenutil/config/loggingconfig.h | 37 ++++ src/zenutil/include/zenutil/environmentoptions.h | 92 -------- src/zenutil/zenutil.cpp | 2 +- 21 files changed, 639 insertions(+), 551 deletions(-) delete mode 100644 src/zenutil/commandlineoptions.cpp create mode 100644 src/zenutil/config/commandlineoptions.cpp create mode 100644 src/zenutil/config/environmentoptions.cpp create mode 100644 src/zenutil/config/loggingconfig.cpp delete mode 100644 src/zenutil/environmentoptions.cpp delete mode 100644 src/zenutil/include/zenutil/commandlineoptions.h create mode 100644 src/zenutil/include/zenutil/config/commandlineoptions.h create mode 100644 src/zenutil/include/zenutil/config/environmentoptions.h create mode 100644 src/zenutil/include/zenutil/config/loggingconfig.h delete mode 100644 src/zenutil/include/zenutil/environmentoptions.h (limited 
to 'src') diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 09a2e4f91..25245c3d2 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -39,7 +39,7 @@ #include #include #include -#include +#include #include #include #include @@ -538,6 +538,9 @@ main(int argc, char** argv) Options.add_options()("corelimit", "Limit concurrency", cxxopts::value(CoreLimit)); + ZenLoggingCmdLineOptions LoggingCmdLineOptions; + LoggingCmdLineOptions.AddCliOptions(Options, GlobalOptions.LoggingConfig); + #if ZEN_WITH_TRACE // We only have this in options for command line help purposes - we parse these argument separately earlier using // GetTraceOptionsFromCommandline() @@ -624,8 +627,8 @@ main(int argc, char** argv) } LimitHardwareConcurrency(CoreLimit); -#if ZEN_USE_SENTRY +#if ZEN_USE_SENTRY { EnvironmentOptions EnvOptions; @@ -671,12 +674,20 @@ main(int argc, char** argv) } #endif - zen::LoggingOptions LogOptions; - LogOptions.IsDebug = GlobalOptions.IsDebug; - LogOptions.IsVerbose = GlobalOptions.IsVerbose; - LogOptions.AllowAsync = false; + LoggingCmdLineOptions.ApplyOptions(GlobalOptions.LoggingConfig); + + const LoggingOptions LogOptions = {.IsDebug = GlobalOptions.IsDebug, + .IsVerbose = GlobalOptions.IsVerbose, + .IsTest = false, + .AllowAsync = false, + .NoConsoleOutput = GlobalOptions.LoggingConfig.NoConsoleOutput, + .QuietConsole = GlobalOptions.LoggingConfig.QuietConsole, + .AbsLogFile = GlobalOptions.LoggingConfig.AbsLogFile, + .LogId = GlobalOptions.LoggingConfig.LogId}; zen::InitializeLogging(LogOptions); + ApplyLoggingOptions(Options, GlobalOptions.LoggingConfig); + std::set_terminate([]() { void* Frames[8]; uint32_t FrameCount = GetCallstack(2, 8, Frames); diff --git a/src/zen/zen.h b/src/zen/zen.h index 05d1e4ec8..e3481beea 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -5,7 +5,8 @@ #include #include #include -#include +#include +#include namespace zen { @@ -14,6 +15,8 @@ struct ZenCliOptions bool IsDebug = false; bool IsVerbose = false; + ZenLoggingConfig 
LoggingConfig; + // Arguments after " -- " on command line are passed through and not parsed std::string PassthroughCommandLine; std::string PassthroughArgs; diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index 9a42bb73d..4120dec1a 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -17,7 +17,7 @@ # include # include # include -# include +# include # include # include diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index 07913e891..2b77df642 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -16,8 +16,8 @@ #include #include #include -#include -#include +#include +#include ZEN_THIRD_PARTY_INCLUDES_START #include @@ -119,10 +119,17 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions ZenServerConfig& ServerOptions = m_ServerOptions; + // logging + + LuaOptions.AddOption("server.logid"sv, ServerOptions.LoggingConfig.LogId, "log-id"sv); + LuaOptions.AddOption("server.abslog"sv, ServerOptions.LoggingConfig.AbsLogFile, "abslog"sv); + LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.LoggingConfig.OtelEndpointUri, "otlp-endpoint"sv); + LuaOptions.AddOption("server.quiet"sv, ServerOptions.LoggingConfig.QuietConsole, "quiet"sv); + LuaOptions.AddOption("server.noconsole"sv, ServerOptions.LoggingConfig.NoConsoleOutput, "noconsole"sv); + // server LuaOptions.AddOption("server.dedicated"sv, ServerOptions.IsDedicated, "dedicated"sv); - LuaOptions.AddOption("server.logid"sv, ServerOptions.LogId, "log-id"sv); LuaOptions.AddOption("server.sentry.disable"sv, ServerOptions.SentryConfig.Disable, "no-sentry"sv); LuaOptions.AddOption("server.sentry.allowpersonalinfo"sv, ServerOptions.SentryConfig.AllowPII, "sentry-allow-personal-info"sv); LuaOptions.AddOption("server.sentry.dsn"sv, ServerOptions.SentryConfig.Dsn, "sentry-dsn"sv); @@ -131,12 +138,8 @@ 
ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions LuaOptions.AddOption("server.systemrootdir"sv, ServerOptions.SystemRootDir, "system-dir"sv); LuaOptions.AddOption("server.datadir"sv, ServerOptions.DataDir, "data-dir"sv); LuaOptions.AddOption("server.contentdir"sv, ServerOptions.ContentDir, "content-dir"sv); - LuaOptions.AddOption("server.abslog"sv, ServerOptions.AbsLogFile, "abslog"sv); - LuaOptions.AddOption("server.otlpendpoint"sv, ServerOptions.OtelEndpointUri, "otlp-endpoint"sv); LuaOptions.AddOption("server.debug"sv, ServerOptions.IsDebug, "debug"sv); LuaOptions.AddOption("server.clean"sv, ServerOptions.IsCleanStart, "clean"sv); - LuaOptions.AddOption("server.quiet"sv, ServerOptions.QuietConsole, "quiet"sv); - LuaOptions.AddOption("server.noconsole"sv, ServerOptions.NoConsoleOutput, "noconsole"sv); ////// network @@ -182,9 +185,10 @@ struct ZenServerCmdLineOptions std::string SystemRootDir; std::string ContentDir; std::string DataDir; - std::string AbsLogFile; std::string BaseSnapshotDir; + ZenLoggingCmdLineOptions LoggingOptions; + void AddCliOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions); void ApplyOptions(cxxopts::Options& options, ZenServerConfig& ServerOptions); }; @@ -249,22 +253,7 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi cxxopts::value(ServerOptions.ShouldCrash)->default_value("false"), ""); - // clang-format off - options.add_options("logging") - ("abslog", "Path to log file", cxxopts::value(AbsLogFile)) - ("log-id", "Specify id for adding context to log output", cxxopts::value(ServerOptions.LogId)) - ("quiet", "Configure console logger output to level WARN", cxxopts::value(ServerOptions.QuietConsole)->default_value("false")) - ("noconsole", "Disable console logging", cxxopts::value(ServerOptions.NoConsoleOutput)->default_value("false")) - ("log-trace", "Change selected loggers to level TRACE", cxxopts::value(ServerOptions.Loggers[logging::level::Trace])) - 
("log-debug", "Change selected loggers to level DEBUG", cxxopts::value(ServerOptions.Loggers[logging::level::Debug])) - ("log-info", "Change selected loggers to level INFO", cxxopts::value(ServerOptions.Loggers[logging::level::Info])) - ("log-warn", "Change selected loggers to level WARN", cxxopts::value(ServerOptions.Loggers[logging::level::Warn])) - ("log-error", "Change selected loggers to level ERROR", cxxopts::value(ServerOptions.Loggers[logging::level::Err])) - ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value(ServerOptions.Loggers[logging::level::Critical])) - ("log-off", "Change selected loggers to level OFF", cxxopts::value(ServerOptions.Loggers[logging::level::Off])) - ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value(ServerOptions.OtelEndpointUri)) - ; - // clang-format on + LoggingOptions.AddCliOptions(options, ServerOptions.LoggingConfig); options .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value(ServerOptions.OwnerPid), ""); @@ -394,9 +383,10 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.AbsLogFile = MakeSafeAbsolutePath(AbsLogFile); ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + + LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig); } ////////////////////////////////////////////////////////////////////////// @@ -466,34 +456,7 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[]) } #endif - if (m_ServerOptions.QuietConsole) - { - bool HasExplicitConsoleLevel = false; - for (int i = 0; i < logging::level::LogLevelCount; ++i) - { - if (m_ServerOptions.Loggers[i].find("console") != std::string::npos) - { - 
HasExplicitConsoleLevel = true; - break; - } - } - - if (!HasExplicitConsoleLevel) - { - std::string& WarnLoggers = m_ServerOptions.Loggers[logging::level::Warn]; - if (!WarnLoggers.empty()) - { - WarnLoggers += ","; - } - WarnLoggers += "console"; - } - } - - for (int i = 0; i < logging::level::LogLevelCount; ++i) - { - logging::ConfigureLogLevels(logging::level::LogLevel(i), m_ServerOptions.Loggers[i]); - } - logging::RefreshLogLevels(); + ApplyLoggingOptions(options, m_ServerOptions.LoggingConfig); BaseOptions.ApplyOptions(options, m_ServerOptions); ApplyOptions(options); @@ -532,9 +495,9 @@ ZenServerConfiguratorBase::Configure(int argc, char* argv[]) m_ServerOptions.DataDir = PickDefaultStateDirectory(m_ServerOptions.SystemRootDir); } - if (m_ServerOptions.AbsLogFile.empty()) + if (m_ServerOptions.LoggingConfig.AbsLogFile.empty()) { - m_ServerOptions.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log"; + m_ServerOptions.LoggingConfig.AbsLogFile = m_ServerOptions.DataDir / "logs" / "zenserver.log"; } m_ServerOptions.HttpConfig.IsDedicatedServer = m_ServerOptions.IsDedicated; diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h index 7c3192a1f..32c22cb05 100644 --- a/src/zenserver/config/config.h +++ b/src/zenserver/config/config.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -42,29 +43,25 @@ struct ZenServerConfig HttpServerConfig HttpConfig; ZenSentryConfig SentryConfig; ZenStatsConfig StatsConfig; - int BasePort = 8558; // Service listen port (used for both UDP and TCP) - int OwnerPid = 0; // Parent process id (zero for standalone) - bool IsDebug = false; - bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not - bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization - bool IsTest = false; - bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux) - bool NoConsoleOutput = 
false; // Control default use of stdout for diagnostics - bool QuietConsole = false; // Configure console logger output to level WARN - int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number - bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements - bool ShouldCrash = false; // Option for testing crash handling - bool IsFirstRun = false; + ZenLoggingConfig LoggingConfig; + int BasePort = 8558; // Service listen port (used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) + bool IsDebug = false; + bool IsCleanStart = false; // Indicates whether all state should be wiped on startup or not + bool IsPowerCycle = false; // When true, the process shuts down immediately after initialization + bool IsTest = false; + bool Detach = true; // Whether zenserver should detach from existing process group (Mac/Linux) + int CoreLimit = 0; // If set, hardware concurrency queries are capped at this number + int LieCpu = 0; + bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements + bool ShouldCrash = false; // Option for testing crash handling + bool IsFirstRun = false; std::filesystem::path ConfigFile; // Path to Lua config file std::filesystem::path SystemRootDir; // System root directory (used for machine level config) std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) std::filesystem::path DataDir; // Root directory for state (used for testing) - std::filesystem::path AbsLogFile; // Absolute path to main log file std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start) std::string ChildId; // Id assigned by parent process (used for lifetime management) - std::string LogId; // Id for tagging log output - std::string Loggers[zen::logging::level::LogLevelCount]; - std::string OtelEndpointUri; // OpenTelemetry endpoint URI #if 
ZEN_WITH_TRACE bool HasTraceCommandlineOptions = false; diff --git a/src/zenserver/config/luaconfig.h b/src/zenserver/config/luaconfig.h index ce7013a9a..e3ac3b343 100644 --- a/src/zenserver/config/luaconfig.h +++ b/src/zenserver/config/luaconfig.h @@ -4,7 +4,7 @@ #include #include -#include +#include ZEN_THIRD_PARTY_INCLUDES_START #include diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp index 4962b9006..75a8efc09 100644 --- a/src/zenserver/diag/logging.cpp +++ b/src/zenserver/diag/logging.cpp @@ -28,10 +28,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) const LoggingOptions LogOptions = {.IsDebug = InOptions.IsDebug, .IsVerbose = false, .IsTest = InOptions.IsTest, - .NoConsoleOutput = InOptions.NoConsoleOutput, - .QuietConsole = InOptions.QuietConsole, - .AbsLogFile = InOptions.AbsLogFile, - .LogId = InOptions.LogId}; + .NoConsoleOutput = InOptions.LoggingConfig.NoConsoleOutput, + .QuietConsole = InOptions.LoggingConfig.QuietConsole, + .AbsLogFile = InOptions.LoggingConfig.AbsLogFile, + .LogId = InOptions.LoggingConfig.LogId}; BeginInitializeLogging(LogOptions); @@ -79,10 +79,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) } #if ZEN_WITH_OTEL - if (!InOptions.OtelEndpointUri.empty()) + if (!InOptions.LoggingConfig.OtelEndpointUri.empty()) { // TODO: Should sanity check that endpoint is reachable? Also, a valid URI? 
- auto OtelSink = std::make_shared(InOptions.OtelEndpointUri); + auto OtelSink = std::make_shared(InOptions.LoggingConfig.OtelEndpointUri); zen::logging::Default().SpdLogger->sinks().push_back(std::move(OtelSink)); } #endif diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 3a58d1f4a..1a929b026 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include "diag/logging.h" diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index b2cae6482..2b74395c3 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -305,7 +305,7 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions *m_JobQueue, m_CacheStore.Get(), [this]() { Flush(); }, - HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.AbsLogFile, + HttpAdminService::LogPaths{.AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile, .HttpLogPath = ServerOptions.DataDir / "logs" / "http.log", .CacheLogPath = ServerOptions.DataDir / "logs" / "z$.log"}, ServerOptions); diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 2bafeeaa1..d54357368 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -152,7 +152,7 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: } m_HealthService.SetHealthInfo({.DataRoot = ServerOptions.DataDir, - .AbsLogPath = ServerOptions.AbsLogFile, + .AbsLogPath = ServerOptions.LoggingConfig.AbsLogFile, .HttpServerClass = std::string(ServerOptions.HttpConfig.ServerClass), .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)}); @@ -387,7 +387,7 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) // clang-format off std::list> Settings = { {"DataDir"sv, ServerConfig.DataDir.string()}, - {"AbsLogFile"sv, ServerConfig.AbsLogFile.string()}, + {"AbsLogFile"sv, 
ServerConfig.LoggingConfig.AbsLogFile.string()}, {"SystemRootDir"sv, ServerConfig.SystemRootDir.string()}, {"ContentDir"sv, ServerConfig.ContentDir.string()}, {"BasePort"sv, fmt::to_string(ServerConfig.BasePort)}, @@ -396,13 +396,13 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) {"IsPowerCycle"sv, fmt::to_string(ServerConfig.IsPowerCycle)}, {"IsTest"sv, fmt::to_string(ServerConfig.IsTest)}, {"Detach"sv, fmt::to_string(ServerConfig.Detach)}, - {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.NoConsoleOutput)}, - {"QuietConsole"sv, fmt::to_string(ServerConfig.QuietConsole)}, + {"NoConsoleOutput"sv, fmt::to_string(ServerConfig.LoggingConfig.NoConsoleOutput)}, + {"QuietConsole"sv, fmt::to_string(ServerConfig.LoggingConfig.QuietConsole)}, {"CoreLimit"sv, fmt::to_string(ServerConfig.CoreLimit)}, {"IsDedicated"sv, fmt::to_string(ServerConfig.IsDedicated)}, {"ShouldCrash"sv, fmt::to_string(ServerConfig.ShouldCrash)}, {"ChildId"sv, ServerConfig.ChildId}, - {"LogId"sv, ServerConfig.LogId}, + {"LogId"sv, ServerConfig.LoggingConfig.LogId}, {"Sentry DSN"sv, ServerConfig.SentryConfig.Dsn.empty() ? "not set" : ServerConfig.SentryConfig.Dsn}, {"Sentry Environment"sv, ServerConfig.SentryConfig.Environment}, {"Statsd Enabled"sv, fmt::to_string(ServerConfig.StatsConfig.Enabled)}, @@ -467,7 +467,7 @@ ZenServerMain::Run() ZEN_OTEL_SPAN("SentryInit"); std::string SentryDatabasePath = (m_ServerOptions.DataDir / ".sentry-native").string(); - std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string(); + std::string SentryAttachmentPath = m_ServerOptions.LoggingConfig.AbsLogFile.string(); Sentry.Initialize({.DatabasePath = SentryDatabasePath, .AttachmentsPath = SentryAttachmentPath, diff --git a/src/zenutil/commandlineoptions.cpp b/src/zenutil/commandlineoptions.cpp deleted file mode 100644 index d94564843..000000000 --- a/src/zenutil/commandlineoptions.cpp +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include - -#include -#include - -#include - -#if ZEN_WITH_TESTS -# include -#endif // ZEN_WITH_TESTS - -#ifndef CXXOPTS_HAS_FILESYSTEM -void -cxxopts::values::parse_value(const std::string& text, std::filesystem::path& value) -{ - value = zen::StringToPath(text); -} -#endif - -namespace zen { - -std::vector -ParseCommandLine(std::string_view CommandLine) -{ - auto IsWhitespaceOrEnd = [](std::string_view CommandLine, std::string::size_type Pos) { - if (Pos == CommandLine.length()) - { - return true; - } - if (CommandLine[Pos] == ' ') - { - return true; - } - return false; - }; - - bool IsParsingArg = false; - bool IsInQuote = false; - - std::string::size_type Pos = 0; - std::string::size_type ArgStart = 0; - std::vector Args; - while (Pos < CommandLine.length()) - { - if (IsInQuote) - { - if (CommandLine[Pos] == '"' && IsWhitespaceOrEnd(CommandLine, Pos + 1)) - { - Args.push_back(std::string(CommandLine.substr(ArgStart, Pos - ArgStart + 1))); - Pos++; - IsInQuote = false; - IsParsingArg = false; - } - else - { - Pos++; - } - } - else if (IsParsingArg) - { - ZEN_ASSERT(Pos > ArgStart); - if (CommandLine[Pos] == ' ') - { - Args.push_back(std::string(CommandLine.substr(ArgStart, Pos - ArgStart))); - Pos++; - IsParsingArg = false; - } - else if (CommandLine[Pos] == '"') - { - IsInQuote = true; - Pos++; - } - else - { - Pos++; - } - } - else if (CommandLine[Pos] == '"') - { - IsInQuote = true; - IsParsingArg = true; - ArgStart = Pos; - Pos++; - } - else if (CommandLine[Pos] != ' ') - { - IsParsingArg = true; - ArgStart = Pos; - Pos++; - } - else - { - Pos++; - } - } - if (IsParsingArg) - { - ZEN_ASSERT(Pos > ArgStart); - Args.push_back(std::string(CommandLine.substr(ArgStart))); - } - - return Args; -} - -std::vector -StripCommandlineQuotes(std::vector& InOutArgs) -{ - std::vector RawArgs; - RawArgs.reserve(InOutArgs.size()); - for (std::string& Arg : InOutArgs) - { - std::string::size_type EscapedQuotePos = Arg.find("\\\"", 1); - while (EscapedQuotePos != 
std::string::npos && Arg.rfind('\"', EscapedQuotePos - 1) != std::string::npos) - { - Arg.erase(EscapedQuotePos, 1); - EscapedQuotePos = Arg.find("\\\"", EscapedQuotePos); - } - - if (Arg.starts_with("\"")) - { - if (Arg.find('"', 1) == Arg.length() - 1) - { - Arg = Arg.substr(1, Arg.length() - 2); - } - } - else if (Arg.ends_with("\"")) - { - std::string::size_type EqualSign = Arg.find("=", 1); - if (EqualSign != std::string::npos && Arg[EqualSign + 1] == '\"') - { - Arg = Arg.substr(0, EqualSign + 1) + Arg.substr(EqualSign + 2, Arg.length() - (EqualSign + 2) - 1); - } - } - RawArgs.push_back(const_cast(Arg.c_str())); - } - return RawArgs; -} - -std::filesystem::path -StringToPath(const std::string_view& Path) -{ - std::string_view UnquotedPath = RemoveQuotes(Path); - - if (UnquotedPath.ends_with('/') || UnquotedPath.ends_with('\\') || UnquotedPath.ends_with(std::filesystem::path::preferred_separator)) - { - UnquotedPath = UnquotedPath.substr(0, UnquotedPath.length() - 1); - } - - return std::filesystem::path(UnquotedPath).make_preferred(); -} - -std::string_view -RemoveQuotes(const std::string_view& Arg) -{ - if (Arg.length() > 2) - { - if (Arg[0] == '"' && Arg[Arg.length() - 1] == '"') - { - return Arg.substr(1, Arg.length() - 2); - } - } - return Arg; -} - -CommandLineConverter::CommandLineConverter(int& argc, char**& argv) -{ -#if ZEN_PLATFORM_WINDOWS - LPWSTR RawCommandLine = GetCommandLineW(); - std::string CommandLine = WideToUtf8(RawCommandLine); - Args = ParseCommandLine(CommandLine); -#else - Args.reserve(argc); - for (int I = 0; I < argc; I++) - { - std::string Arg(argv[I]); - if ((!Arg.empty()) && (Arg != " ")) - { - Args.emplace_back(std::move(Arg)); - } - } -#endif - RawArgs = StripCommandlineQuotes(Args); - - argc = static_cast(RawArgs.size()); - argv = RawArgs.data(); -} - -#if ZEN_WITH_TESTS - -void -commandlineoptions_forcelink() -{ -} - -TEST_CASE("CommandLine") -{ - std::vector v1 = ParseCommandLine("c:\\my\\exe.exe \"quoted arg\" 
\"one\",two,\"three\\\""); - CHECK_EQ(v1[0], "c:\\my\\exe.exe"); - CHECK_EQ(v1[1], "\"quoted arg\""); - CHECK_EQ(v1[2], "\"one\",two,\"three\\\""); - - std::vector v2 = ParseCommandLine( - "--tracehost 127.0.0.1 builds download --url=https://jupiter.devtools.epicgames.com --namespace=ue.oplog " - "--bucket=citysample.packaged-build.fortnite-main.windows \"c:\\just\\a\\path\" " - "--access-token-path=\"C:\\Users\\dan.engelbrecht\\jupiter-token.json\" \"D:\\Dev\\Spaced Folder\\Target\\\" " - "--alt-path=\"D:\\Dev\\Spaced Folder2\\Target\\\" 07dn23ifiwesnvoasjncasab --build-part-name win64,linux,ps5"); - - std::vector v2Stripped = StripCommandlineQuotes(v2); - CHECK_EQ(v2Stripped[0], std::string("--tracehost")); - CHECK_EQ(v2Stripped[1], std::string("127.0.0.1")); - CHECK_EQ(v2Stripped[2], std::string("builds")); - CHECK_EQ(v2Stripped[3], std::string("download")); - CHECK_EQ(v2Stripped[4], std::string("--url=https://jupiter.devtools.epicgames.com")); - CHECK_EQ(v2Stripped[5], std::string("--namespace=ue.oplog")); - CHECK_EQ(v2Stripped[6], std::string("--bucket=citysample.packaged-build.fortnite-main.windows")); - CHECK_EQ(v2Stripped[7], std::string("c:\\just\\a\\path")); - CHECK_EQ(v2Stripped[8], std::string("--access-token-path=C:\\Users\\dan.engelbrecht\\jupiter-token.json")); - CHECK_EQ(v2Stripped[9], std::string("D:\\Dev\\Spaced Folder\\Target")); - CHECK_EQ(v2Stripped[10], std::string("--alt-path=D:\\Dev\\Spaced Folder2\\Target")); - CHECK_EQ(v2Stripped[11], std::string("07dn23ifiwesnvoasjncasab")); - CHECK_EQ(v2Stripped[12], std::string("--build-part-name")); - CHECK_EQ(v2Stripped[13], std::string("win64,linux,ps5")); - - std::vector v3 = ParseCommandLine( - "--tracehost \"127.0.0.1\" builds download --url=\"https://jupiter.devtools.epicgames.com\" --build-part-name=\"win64\""); - std::vector v3Stripped = StripCommandlineQuotes(v3); - - CHECK_EQ(v3Stripped[0], std::string("--tracehost")); - CHECK_EQ(v3Stripped[1], std::string("127.0.0.1")); - 
CHECK_EQ(v3Stripped[2], std::string("builds")); - CHECK_EQ(v3Stripped[3], std::string("download")); - CHECK_EQ(v3Stripped[4], std::string("--url=https://jupiter.devtools.epicgames.com")); - CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64")); -} - -#endif -} // namespace zen
diff --git a/src/zenutil/config/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp
new file mode 100644
index 000000000..84c718ecc
--- /dev/null
+++ b/src/zenutil/config/commandlineoptions.cpp
@@ -0,0 +1,239 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include
+
+#include
+#include
+
+#include
+
+#if ZEN_WITH_TESTS
+# include
+#endif // ZEN_WITH_TESTS
+
+// Compiled only when CXXOPTS_HAS_FILESYSTEM is not defined: converts the text
+// of a std::filesystem::path option with zen::StringToPath.
+#ifndef CXXOPTS_HAS_FILESYSTEM
+void
+cxxopts::values::parse_value(const std::string& text, std::filesystem::path& value)
+{
+    value = zen::StringToPath(text);
+}
+#endif
+
+namespace zen {
+
+// Splits a raw command line into arguments.
+//
+// Only the space character separates arguments. A double quote opens a quoted
+// section in which spaces are kept; the section closes only at a double quote
+// that is followed by a space or by the end of the input. Quote characters are
+// kept in the returned arguments (StripCommandlineQuotes removes them).
+// If a quote is never closed, the rest of the input becomes the last argument.
+std::vector
+ParseCommandLine(std::string_view CommandLine)
+{
+    // True when Pos is one past the last character or addresses a space.
+    auto IsWhitespaceOrEnd = [](std::string_view CommandLine, std::string::size_type Pos) {
+        if (Pos == CommandLine.length())
+        {
+            return true;
+        }
+        if (CommandLine[Pos] == ' ')
+        {
+            return true;
+        }
+        return false;
+    };
+
+    bool IsParsingArg = false;  // an argument has started at ArgStart
+    bool IsInQuote = false;     // currently inside a quoted section
+
+    std::string::size_type Pos = 0;
+    std::string::size_type ArgStart = 0;
+    std::vector Args;
+    while (Pos < CommandLine.length())
+    {
+        if (IsInQuote)
+        {
+            // A closing quote ends the whole argument; the quote itself is included.
+            if (CommandLine[Pos] == '"' && IsWhitespaceOrEnd(CommandLine, Pos + 1))
+            {
+                Args.push_back(std::string(CommandLine.substr(ArgStart, Pos - ArgStart + 1)));
+                Pos++;
+                IsInQuote = false;
+                IsParsingArg = false;
+            }
+            else
+            {
+                Pos++;
+            }
+        }
+        else if (IsParsingArg)
+        {
+            ZEN_ASSERT(Pos > ArgStart);
+            if (CommandLine[Pos] == ' ')
+            {
+                // Unquoted argument ends at the space (space not included).
+                Args.push_back(std::string(CommandLine.substr(ArgStart, Pos - ArgStart)));
+                Pos++;
+                IsParsingArg = false;
+            }
+            else if (CommandLine[Pos] == '"')
+            {
+                // Quote in the middle of an argument, e.g. --name="some value".
+                IsInQuote = true;
+                Pos++;
+            }
+            else
+            {
+                Pos++;
+            }
+        }
+        else if (CommandLine[Pos] == '"')
+        {
+            // Argument that starts with a quote.
+            IsInQuote = true;
+            IsParsingArg = true;
+            ArgStart = Pos;
+            Pos++;
+        }
+        else if (CommandLine[Pos] != ' ')
+        {
+            // First character of an unquoted argument.
+            IsParsingArg = true;
+            ArgStart = Pos;
+            Pos++;
+        }
+        else
+        {
+            // Separator between arguments.
+            Pos++;
+        }
+    }
+    // Input ended inside an argument (or inside an unterminated quote).
+    if (IsParsingArg)
+    {
+        ZEN_ASSERT(Pos > ArgStart);
+        Args.push_back(std::string(CommandLine.substr(ArgStart)));
+    }
+
+    return Args;
+}
+
+// Removes command-line quoting from every argument in place and returns one
+// raw character pointer per argument, in the same order.
+//
+// The pointers are taken from Arg.c_str(), so they point into the strings
+// owned by InOutArgs and are only usable while those strings are alive and
+// unmodified.
+std::vector
+StripCommandlineQuotes(std::vector& InOutArgs)
+{
+    std::vector RawArgs;
+    RawArgs.reserve(InOutArgs.size());
+    for (std::string& Arg : InOutArgs)
+    {
+        // Drop the backslash of each backslash-quote pair (searched from index 1)
+        // as long as a double quote occurs somewhere before it. The test below
+        // relies on this for quoted Windows paths that end in a backslash.
+        std::string::size_type EscapedQuotePos = Arg.find("\\\"", 1);
+        while (EscapedQuotePos != std::string::npos && Arg.rfind('\"', EscapedQuotePos - 1) != std::string::npos)
+        {
+            Arg.erase(EscapedQuotePos, 1);
+            EscapedQuotePos = Arg.find("\\\"", EscapedQuotePos);
+        }
+
+        if (Arg.starts_with("\""))
+        {
+            // Fully quoted argument: strip the outer quotes only when the next
+            // quote after index 0 is the last character.
+            if (Arg.find('"', 1) == Arg.length() - 1)
+            {
+                Arg = Arg.substr(1, Arg.length() - 2);
+            }
+        }
+        else if (Arg.ends_with("\""))
+        {
+            // name="value" form: remove the quote right after the first '='
+            // (searched from index 1) and the trailing quote.
+            std::string::size_type EqualSign = Arg.find("=", 1);
+            if (EqualSign != std::string::npos && Arg[EqualSign + 1] == '\"')
+            {
+                Arg = Arg.substr(0, EqualSign + 1) + Arg.substr(EqualSign + 2, Arg.length() - (EqualSign + 2) - 1);
+            }
+        }
+        RawArgs.push_back(const_cast(Arg.c_str()));
+    }
+    return RawArgs;
+}
+
+// Builds a filesystem path from a command-line value: surrounding quotes are
+// removed (RemoveQuotes), at most one trailing '/', backslash or preferred
+// separator is dropped, and make_preferred() is applied to the result.
+std::filesystem::path
+StringToPath(const std::string_view& Path)
+{
+    std::string_view UnquotedPath = RemoveQuotes(Path);
+
+    if (UnquotedPath.ends_with('/') || UnquotedPath.ends_with('\\') || UnquotedPath.ends_with(std::filesystem::path::preferred_separator))
+    {
+        UnquotedPath = UnquotedPath.substr(0, UnquotedPath.length() - 1);
+    }
+
+    return std::filesystem::path(UnquotedPath).make_preferred();
+}
+
+// Returns Arg without its first and last character when both are double
+// quotes; otherwise returns Arg unchanged. The result is a view into Arg.
+// NOTE(review): the length test is "> 2", so a string consisting of just two
+// quotes is returned unchanged - confirm that this is intended.
+std::string_view
+RemoveQuotes(const std::string_view& Arg)
+{
+    if (Arg.length() > 2)
+    {
+        if (Arg[0] == '"' && Arg[Arg.length() - 1] == '"')
+        {
+            return Arg.substr(1, Arg.length() - 2);
+        }
+    }
+    return Arg;
+}
+
+// Rebuilds the argument list and redirects the caller's argc/argv to it.
+//
+// On Windows the arguments are re-parsed from GetCommandLineW() (converted
+// with WideToUtf8); elsewhere argv is copied, skipping empty strings and
+// strings that are a single space. Quotes are then stripped and argc/argv are
+// overwritten to refer to RawArgs, which is a member of this object.
+CommandLineConverter::CommandLineConverter(int& argc, char**& argv)
+{
+#if ZEN_PLATFORM_WINDOWS
+    LPWSTR RawCommandLine = GetCommandLineW();
+    std::string CommandLine = WideToUtf8(RawCommandLine);
+    Args = ParseCommandLine(CommandLine);
+#else
+    Args.reserve(argc);
+    for (int I = 0; I < argc; I++)
+    {
+        std::string Arg(argv[I]);
+        if ((!Arg.empty()) && (Arg != " "))
+        {
+            Args.emplace_back(std::move(Arg));
+        }
+    }
+#endif
+    RawArgs = StripCommandlineQuotes(Args);
+
+    argc = static_cast(RawArgs.size());
+    argv = RawArgs.data();
+}
+
+#if ZEN_WITH_TESTS
+
+// Empty function; presumably referenced from elsewhere so that this
+// translation unit (and its tests) gets linked - TODO confirm.
+void
+commandlineoptions_forcelink()
+{
+}
+
+TEST_CASE("CommandLine")
+{
+    // Quotes are preserved by ParseCommandLine, including a quoted section
+    // that is glued to unquoted text.
+    std::vector v1 = ParseCommandLine("c:\\my\\exe.exe \"quoted arg\" \"one\",two,\"three\\\"");
+    CHECK_EQ(v1[0], "c:\\my\\exe.exe");
+    CHECK_EQ(v1[1], "\"quoted arg\"");
+    CHECK_EQ(v1[2], "\"one\",two,\"three\\\"");
+
+    // Mix of plain, quoted and name="value" arguments, including quoted paths
+    // that end in a backslash.
+    std::vector v2 = ParseCommandLine(
+        "--tracehost 127.0.0.1 builds download --url=https://jupiter.devtools.epicgames.com --namespace=ue.oplog "
+        "--bucket=citysample.packaged-build.fortnite-main.windows \"c:\\just\\a\\path\" "
+        "--access-token-path=\"C:\\Users\\dan.engelbrecht\\jupiter-token.json\" \"D:\\Dev\\Spaced Folder\\Target\\\" "
+        "--alt-path=\"D:\\Dev\\Spaced Folder2\\Target\\\" 07dn23ifiwesnvoasjncasab --build-part-name win64,linux,ps5");
+
+    std::vector v2Stripped = StripCommandlineQuotes(v2);
+    CHECK_EQ(v2Stripped[0], std::string("--tracehost"));
+    CHECK_EQ(v2Stripped[1], std::string("127.0.0.1"));
+    CHECK_EQ(v2Stripped[2], std::string("builds"));
+    CHECK_EQ(v2Stripped[3], std::string("download"));
+    CHECK_EQ(v2Stripped[4], std::string("--url=https://jupiter.devtools.epicgames.com"));
+    CHECK_EQ(v2Stripped[5], std::string("--namespace=ue.oplog"));
+    CHECK_EQ(v2Stripped[6], std::string("--bucket=citysample.packaged-build.fortnite-main.windows"));
+    CHECK_EQ(v2Stripped[7], std::string("c:\\just\\a\\path"));
+    CHECK_EQ(v2Stripped[8], std::string("--access-token-path=C:\\Users\\dan.engelbrecht\\jupiter-token.json"));
+    CHECK_EQ(v2Stripped[9], std::string("D:\\Dev\\Spaced Folder\\Target"));
+    CHECK_EQ(v2Stripped[10], std::string("--alt-path=D:\\Dev\\Spaced Folder2\\Target"));
+    CHECK_EQ(v2Stripped[11], std::string("07dn23ifiwesnvoasjncasab"));
+    CHECK_EQ(v2Stripped[12], std::string("--build-part-name"));
+    CHECK_EQ(v2Stripped[13], std::string("win64,linux,ps5"));
+
+    // Quoted values without spaces are unquoted as well.
+    std::vector v3 = ParseCommandLine(
+        "--tracehost \"127.0.0.1\" builds download --url=\"https://jupiter.devtools.epicgames.com\" --build-part-name=\"win64\"");
+    std::vector v3Stripped = StripCommandlineQuotes(v3);
+
+    CHECK_EQ(v3Stripped[0], std::string("--tracehost"));
+    CHECK_EQ(v3Stripped[1], std::string("127.0.0.1"));
+    CHECK_EQ(v3Stripped[2], std::string("builds"));
+    CHECK_EQ(v3Stripped[3], std::string("download"));
+    CHECK_EQ(v3Stripped[4], std::string("--url=https://jupiter.devtools.epicgames.com"));
+    CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64"));
+}
+
+#endif
+} // namespace zen
diff --git a/src/zenutil/config/environmentoptions.cpp b/src/zenutil/config/environmentoptions.cpp
new file mode 100644
index 000000000..fb7f71706
--- /dev/null
+++ b/src/zenutil/config/environmentoptions.cpp
@@ -0,0 +1,84 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include
+
+#include
+
+namespace zen {
+
+// Each option type below stores a reference to the caller's variable; Parse()
+// writes the converted text through that reference.
+
+EnvironmentOptions::StringOption::StringOption(std::string& Value) : RefValue(Value)
+{
+}
+// Stores the text unchanged.
+void
+EnvironmentOptions::StringOption::Parse(std::string_view Value)
+{
+    RefValue = std::string(Value);
+}
+
+EnvironmentOptions::FilePathOption::FilePathOption(std::filesystem::path& Value) : RefValue(Value)
+{
+}
+// Stores the result of MakeSafeAbsolutePath(Value).
+void
+EnvironmentOptions::FilePathOption::Parse(std::string_view Value)
+{
+    RefValue = MakeSafeAbsolutePath(Value);
+}
+
+EnvironmentOptions::BoolOption::BoolOption(bool& Value) : RefValue(Value)
+{
+}
+// Compares the ToLower()-ed text against "true"/"y"/"yes" and
+// "false"/"n"/"no". Any other text leaves the bound variable unchanged.
+void
+EnvironmentOptions::BoolOption::Parse(std::string_view Value)
+{
+    const std::string Lower = ToLower(Value);
+    if (Lower == "true" || Lower == "y" || Lower == "yes")
+    {
+        RefValue = true;
+    }
+    else if (Lower == "false" || Lower == "n" || Lower == "no")
+    {
+        RefValue = false;
+    }
+}
+
+// MakeOption overloads: wrap the variable in the option type matching its
+// C++ type (string, filesystem path, bool).
+std::shared_ptr
+EnvironmentOptions::MakeOption(std::string& Value)
+{
+    return std::make_shared(Value);
+}
+
+std::shared_ptr
+EnvironmentOptions::MakeOption(std::filesystem::path& Value)
+{
+    return std::make_shared(Value);
+}
+
+std::shared_ptr
+EnvironmentOptions::MakeOption(bool& Value)
+{
+    return std::make_shared(Value);
+}
+
+EnvironmentOptions::EnvironmentOptions()
+{
+}
+
+// Applies environment variables to every registered option that was not given
+// on the command line.
+//
+// For each entry of OptionMap: if CmdLineResult reports zero occurrences of
+// the entry's command-line option name, the environment variable named by the
+// map key is read with GetEnvVariable and, when the value is non-empty, handed
+// to the option's Parse(). Empty values are ignored.
+void
+EnvironmentOptions::Parse(const cxxopts::ParseResult& CmdLineResult)
+{
+    for (auto& It : OptionMap)
+    {
+        std::string_view EnvName = It.first;
+        const Option& Opt = It.second;
+        if (CmdLineResult.count(Opt.CommandLineOptionName) == 0)
+        {
+            std::string EnvValue = GetEnvVariable(EnvName);
+            if (!EnvValue.empty())
+            {
+                Opt.Value->Parse(EnvValue);
+            }
+        }
+    }
+}
+
+} // namespace zen
diff --git a/src/zenutil/config/loggingconfig.cpp b/src/zenutil/config/loggingconfig.cpp
new file mode 100644
index 000000000..9ec816b1b
--- /dev/null
+++ b/src/zenutil/config/loggingconfig.cpp
@@ -0,0 +1,77 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/config/loggingconfig.h"
+
+#include
+#include
+#include
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+// Registers the "logging" option group on `options`.
+//
+// All options except "abslog" are bound directly to fields of LoggingConfig.
+// "abslog" is bound to the member m_AbsLogFile and only copied into
+// LoggingConfig by ApplyOptions().
+void
+ZenLoggingCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig)
+{
+    // clang-format off
+    options.add_options("logging")
+        ("abslog", "Path to log file", cxxopts::value(m_AbsLogFile))
+        ("log-id", "Specify id for adding context to log output", cxxopts::value(LoggingConfig.LogId))
+        ("quiet", "Configure console logger output to level WARN", cxxopts::value(LoggingConfig.QuietConsole)->default_value("false"))
+        ("noconsole", "Disable console logging", cxxopts::value(LoggingConfig.NoConsoleOutput)->default_value("false"))
+        ("log-trace", "Change selected loggers to level TRACE", cxxopts::value(LoggingConfig.Loggers[logging::level::Trace]))
+        ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value(LoggingConfig.Loggers[logging::level::Debug]))
+        ("log-info", "Change selected loggers to level INFO", cxxopts::value(LoggingConfig.Loggers[logging::level::Info]))
+        ("log-warn", "Change selected loggers to level WARN", cxxopts::value(LoggingConfig.Loggers[logging::level::Warn]))
+        ("log-error", "Change selected loggers to level ERROR", cxxopts::value(LoggingConfig.Loggers[logging::level::Err]))
+        ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value(LoggingConfig.Loggers[logging::level::Critical]))
+        ("log-off", "Change selected loggers to level OFF", cxxopts::value(LoggingConfig.Loggers[logging::level::Off]))
+        ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value(LoggingConfig.OtelEndpointUri))
+    ;
+    // clang-format on
+}
+
+// Copies the "abslog" text into LoggingConfig.AbsLogFile after passing it
+// through MakeSafeAbsolutePath. Presumably meant to run after the command line
+// has been parsed - verify against callers.
+void
+ZenLoggingCmdLineOptions::ApplyOptions(ZenLoggingConfig& LoggingConfig)
+{
+    LoggingConfig.AbsLogFile = MakeSafeAbsolutePath(m_AbsLogFile);
+}
+
+// Pushes the per-level logger lists of LoggingConfig to the logging system.
+//
+// When QuietConsole is set and no level's list contains the text "console",
+// "console" is appended to the WARN list (comma-separated). May therefore
+// modify LoggingConfig. `options` is unused.
+void
+ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig)
+{
+    ZEN_UNUSED(options);
+
+    if (LoggingConfig.QuietConsole)
+    {
+        // An explicit level for "console" in any list takes precedence over quiet.
+        bool HasExplicitConsoleLevel = false;
+        for (int i = 0; i < logging::level::LogLevelCount; ++i)
+        {
+            if (LoggingConfig.Loggers[i].find("console") != std::string::npos)
+            {
+                HasExplicitConsoleLevel = true;
+                break;
+            }
+        }
+
+        if (!HasExplicitConsoleLevel)
+        {
+            std::string& WarnLoggers = LoggingConfig.Loggers[logging::level::Warn];
+            if (!WarnLoggers.empty())
+            {
+                WarnLoggers += ",";
+            }
+            WarnLoggers += "console";
+        }
+    }
+
+    // Configure every level from its list, then refresh the effective levels.
+    for (int i = 0; i < logging::level::LogLevelCount; ++i)
+    {
+        logging::ConfigureLogLevels(logging::level::LogLevel(i), LoggingConfig.Loggers[i]);
+    }
+    logging::RefreshLogLevels();
+}
+
+} // namespace zen
diff --git a/src/zenutil/environmentoptions.cpp b/src/zenutil/environmentoptions.cpp
deleted file mode 100644
index ee40086c1..000000000
--- a/src/zenutil/environmentoptions.cpp
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved. - -#include - -#include - -namespace zen { - -EnvironmentOptions::StringOption::StringOption(std::string& Value) : RefValue(Value) -{ -} -void -EnvironmentOptions::StringOption::Parse(std::string_view Value) -{ - RefValue = std::string(Value); -} - -EnvironmentOptions::FilePathOption::FilePathOption(std::filesystem::path& Value) : RefValue(Value) -{ -} -void -EnvironmentOptions::FilePathOption::Parse(std::string_view Value) -{ - RefValue = MakeSafeAbsolutePath(Value); -} - -EnvironmentOptions::BoolOption::BoolOption(bool& Value) : RefValue(Value) -{ -} -void -EnvironmentOptions::BoolOption::Parse(std::string_view Value) -{ - const std::string Lower = ToLower(Value); - if (Lower == "true" || Lower == "y" || Lower == "yes") - { - RefValue = true; - } - else if (Lower == "false" || Lower == "n" || Lower == "no") - { - RefValue = false; - } -} - -std::shared_ptr -EnvironmentOptions::MakeOption(std::string& Value) -{ - return std::make_shared(Value); -} - -std::shared_ptr -EnvironmentOptions::MakeOption(std::filesystem::path& Value) -{ - return
std::make_shared(Value); -} - -std::shared_ptr -EnvironmentOptions::MakeOption(bool& Value) -{ - return std::make_shared(Value); -} - -EnvironmentOptions::EnvironmentOptions() -{ -} - -void -EnvironmentOptions::Parse(const cxxopts::ParseResult& CmdLineResult) -{ - for (auto& It : OptionMap) - { - std::string_view EnvName = It.first; - const Option& Opt = It.second; - if (CmdLineResult.count(Opt.CommandLineOptionName) == 0) - { - std::string EnvValue = GetEnvVariable(EnvName); - if (!EnvValue.empty()) - { - Opt.Value->Parse(EnvValue); - } - } - } -} - -} // namespace zen diff --git a/src/zenutil/include/zenutil/commandlineoptions.h b/src/zenutil/include/zenutil/commandlineoptions.h deleted file mode 100644 index 01cceedb1..000000000 --- a/src/zenutil/include/zenutil/commandlineoptions.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include -#include - -ZEN_THIRD_PARTY_INCLUDES_START - -namespace cxxopts::values { -// We declare this specialization before including cxxopts to make it stick -void parse_value(const std::string& text, std::filesystem::path& value); -} // namespace cxxopts::values - -#include -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen { - -std::vector ParseCommandLine(std::string_view CommandLine); -std::vector StripCommandlineQuotes(std::vector& InOutArgs); -std::filesystem::path StringToPath(const std::string_view& Path); -std::string_view RemoveQuotes(const std::string_view& Arg); - -class CommandLineConverter -{ -public: - CommandLineConverter(int& argc, char**& argv); - - int ArgC = 0; - char** ArgV = nullptr; - -private: - std::vector Args; - std::vector RawArgs; -}; - -void commandlineoptions_forcelink(); // internal - -} // namespace zen diff --git a/src/zenutil/include/zenutil/config/commandlineoptions.h b/src/zenutil/include/zenutil/config/commandlineoptions.h new file mode 100644 index 000000000..01cceedb1 --- /dev/null +++ b/src/zenutil/include/zenutil/config/commandlineoptions.h @@ 
-0,0 +1,40 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+#include
+
+ZEN_THIRD_PARTY_INCLUDES_START
+
+namespace cxxopts::values {
+// We declare this specialization before including cxxopts to make it stick
+void parse_value(const std::string& text, std::filesystem::path& value);
+} // namespace cxxopts::values
+
+#include
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+// Splits a command line on spaces; double-quoted sections keep their spaces
+// and the quote characters stay in the returned arguments.
+std::vector ParseCommandLine(std::string_view CommandLine);
+// Strips quoting from each argument in place and returns pointers into the
+// strings of InOutArgs (valid only while InOutArgs is alive and unmodified).
+std::vector StripCommandlineQuotes(std::vector& InOutArgs);
+// Unquotes Path, drops one trailing separator and applies make_preferred().
+std::filesystem::path StringToPath(const std::string_view& Path);
+// Returns a view of Arg without its enclosing double quotes, if it has them.
+std::string_view RemoveQuotes(const std::string_view& Arg);
+
+// Owns a cleaned-up copy of the process arguments. The constructor overwrites
+// the argc/argv passed to it so that they refer to storage owned by this
+// object; the object therefore has to outlive every use of that argv.
+class CommandLineConverter
+{
+public:
+    CommandLineConverter(int& argc, char**& argv);
+
+    // NOTE(review): the constructor in commandlineoptions.cpp does not assign
+    // these two members, so they keep their defaults - confirm whether they
+    // are still needed.
+    int ArgC = 0;
+    char** ArgV = nullptr;
+
+private:
+    std::vector Args;     // argument strings (storage)
+    std::vector RawArgs;  // pointers into Args, handed out as argv
+};
+
+void commandlineoptions_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/config/environmentoptions.h b/src/zenutil/include/zenutil/config/environmentoptions.h
new file mode 100644
index 000000000..1ecdf591a
--- /dev/null
+++ b/src/zenutil/include/zenutil/config/environmentoptions.h
@@ -0,0 +1,92 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+#include
+
+namespace zen {
+
+// Lets configuration variables be filled from environment variables.
+//
+// Variables are registered with AddOption() under an environment variable
+// name; Parse() then assigns each one from the environment unless its
+// associated command-line option is present in the cxxopts parse result.
+class EnvironmentOptions
+{
+public:
+    // Interface for converting environment text into a bound variable.
+    class OptionValue
+    {
+    public:
+        virtual void Parse(std::string_view Value) = 0;
+
+        virtual ~OptionValue() {}
+    };
+
+    // Each concrete option holds a reference to the caller's variable
+    // (RefValue), so that variable has to outlive the option.
+
+    class StringOption : public OptionValue
+    {
+    public:
+        explicit StringOption(std::string& Value);
+        virtual void Parse(std::string_view Value) override;
+        std::string& RefValue;
+    };
+
+    class FilePathOption : public OptionValue
+    {
+    public:
+        explicit FilePathOption(std::filesystem::path& Value);
+        virtual void Parse(std::string_view Value) override;
+        std::filesystem::path& RefValue;
+    };
+
+    class BoolOption : public OptionValue
+    {
+    public:
+        explicit BoolOption(bool& Value);
+        // NOTE(review): unlike the sibling classes this declaration has no
+        // `override` specifier - consider adding it for consistency.
+        virtual void Parse(std::string_view Value);
+        bool& RefValue;
+    };
+
+    // Numeric option: the variable is assigned only when ParseInt yields a
+    // value; otherwise it is left unchanged.
+    template
+    class NumberOption : public OptionValue
+    {
+    public:
+        explicit NumberOption(T& Value) : RefValue(Value) {}
+        virtual void Parse(std::string_view Value) override
+        {
+            if (std::optional OptionalValue = ParseInt(Value); OptionalValue.has_value())
+            {
+                RefValue = OptionalValue.value();
+            }
+        }
+        T& RefValue;
+    };
+
+    // One registered option: the command-line option name checked by Parse()
+    // and the converter writing to the bound variable.
+    struct Option
+    {
+        std::string CommandLineOptionName;
+        std::shared_ptr Value;
+    };
+
+    // MakeOption overloads pick the OptionValue type matching the variable.
+    std::shared_ptr MakeOption(std::string& Value);
+    std::shared_ptr MakeOption(std::filesystem::path& Value);
+
+    template
+    std::shared_ptr MakeOption(T& Value)
+    {
+        return std::make_shared>(Value);
+    };
+
+    std::shared_ptr MakeOption(bool& Value);
+
+    // Registers Value under environment variable EnvName. insert_or_assign is
+    // used, so registering the same EnvName again replaces the earlier entry.
+    // CommandLineOptionName defaults to the empty string.
+    template
+    void AddOption(std::string_view EnvName, T& Value, std::string_view CommandLineOptionName = "")
+    {
+        OptionMap.insert_or_assign(std::string(EnvName),
+                                   Option{.CommandLineOptionName = std::string(CommandLineOptionName), .Value = MakeOption(Value)});
+    };
+
+    EnvironmentOptions();
+
+    // Assigns registered variables from the environment; see the definition in
+    // environmentoptions.cpp for the exact rules.
+    void Parse(const cxxopts::ParseResult& CmdLineResult);
+
+private:
+    std::unordered_map OptionMap;  // keyed by environment variable name
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/config/loggingconfig.h b/src/zenutil/include/zenutil/config/loggingconfig.h
new
file mode 100644
index 000000000..6d6f64b30
--- /dev/null
+++ b/src/zenutil/include/zenutil/config/loggingconfig.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace cxxopts {
+class Options;
+}
+
+namespace zen {
+
+// Logging-related settings shared by the zen executables.
+struct ZenLoggingConfig
+{
+    bool NoConsoleOutput = false; // Control default use of stdout for diagnostics
+    bool QuietConsole = false; // Configure console logger output to level WARN
+    std::filesystem::path AbsLogFile; // Absolute path to main log file
+    // One logger list per log level, indexed by logging::level value; filled
+    // by the --log-<level> options and consumed by ApplyLoggingOptions().
+    std::string Loggers[logging::level::LogLevelCount];
+    std::string LogId; // Id for tagging log output
+    std::string OtelEndpointUri; // OpenTelemetry endpoint URI
+};
+
+// Applies LoggingConfig.Loggers to the logging system. May append "console"
+// to the WARN list when QuietConsole is set (see loggingconfig.cpp).
+void ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig);
+
+// Registers the logging command-line options and transfers the values that
+// cannot be bound directly to ZenLoggingConfig.
+class ZenLoggingCmdLineOptions
+{
+public:
+    // Adds the "logging" option group; options write into LoggingConfig,
+    // except "abslog", which is captured in m_AbsLogFile.
+    void AddCliOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig);
+    // Converts m_AbsLogFile to a path and stores it in LoggingConfig.AbsLogFile.
+    void ApplyOptions(ZenLoggingConfig& LoggingConfig);
+
+private:
+    std::string m_AbsLogFile;  // raw text of the "abslog" option
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/environmentoptions.h b/src/zenutil/include/zenutil/environmentoptions.h
deleted file mode 100644
index 7418608e4..000000000
--- a/src/zenutil/include/zenutil/environmentoptions.h
+++ /dev/null
@@ -1,92 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
- -#pragma once - -#include -#include - -namespace zen { - -class EnvironmentOptions -{ -public: - class OptionValue - { - public: - virtual void Parse(std::string_view Value) = 0; - - virtual ~OptionValue() {} - }; - - class StringOption : public OptionValue - { - public: - explicit StringOption(std::string& Value); - virtual void Parse(std::string_view Value) override; - std::string& RefValue; - }; - - class FilePathOption : public OptionValue - { - public: - explicit FilePathOption(std::filesystem::path& Value); - virtual void Parse(std::string_view Value) override; - std::filesystem::path& RefValue; - }; - - class BoolOption : public OptionValue - { - public: - explicit BoolOption(bool& Value); - virtual void Parse(std::string_view Value); - bool& RefValue; - }; - - template - class NumberOption : public OptionValue - { - public: - explicit NumberOption(T& Value) : RefValue(Value) {} - virtual void Parse(std::string_view Value) override - { - if (std::optional OptionalValue = ParseInt(Value); OptionalValue.has_value()) - { - RefValue = OptionalValue.value(); - } - } - T& RefValue; - }; - - struct Option - { - std::string CommandLineOptionName; - std::shared_ptr Value; - }; - - std::shared_ptr MakeOption(std::string& Value); - std::shared_ptr MakeOption(std::filesystem::path& Value); - - template - std::shared_ptr MakeOption(T& Value) - { - return std::make_shared>(Value); - }; - - std::shared_ptr MakeOption(bool& Value); - - template - void AddOption(std::string_view EnvName, T& Value, std::string_view CommandLineOptionName = "") - { - OptionMap.insert_or_assign(std::string(EnvName), - Option{.CommandLineOptionName = std::string(CommandLineOptionName), .Value = MakeOption(Value)}); - }; - - EnvironmentOptions(); - - void Parse(const cxxopts::ParseResult& CmdLineResult); - -private: - std::unordered_map OptionMap; -}; - -} // namespace zen diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 51c1ee72e..291dbeadd 100644 --- 
a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -5,7 +5,7 @@ #if ZEN_WITH_TESTS # include -# include +# include # include namespace zen { -- cgit v1.2.3 From 58e1e1ef2deedc49b3e88db57c110b88a39e21da Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 13 Feb 2026 14:47:24 +0100 Subject: spelling fixes (#755) --- src/zen/cmds/builds_cmd.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index f4edb65ab..59b209384 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -3016,7 +3016,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_BlobHash.length() != IoHash::StringLength) { throw OptionParseException( - fmt::format("'--blob-hash' ('{}') is malfomed, it must be {} characters long", m_BlobHash, IoHash::StringLength), + fmt::format("'--blob-hash' ('{}') is malformed, it must be {} characters long", m_BlobHash, IoHash::StringLength), SubOption->help()); } @@ -3033,7 +3033,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_BuildId.length() != Oid::StringLength) { throw OptionParseException( - fmt::format("'--build-id' ('{}') is malfomed, it must be {} characters long", m_BuildId, Oid::StringLength), + fmt::format("'--build-id' ('{}') is malformed, it must be {} characters long", m_BuildId, Oid::StringLength), SubOption->help()); } else if (Oid BuildId = Oid::FromHexString(m_BuildId); BuildId == Oid::Zero) -- cgit v1.2.3 From df97b6b2abcc8ce13b1d63e3d2cf27c3bd841768 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 13 Feb 2026 15:19:51 +0100 Subject: add foundation for http password protection (#756) --- .../include/zenhttp/security/passwordsecurity.h | 52 +++++ src/zenhttp/security/passwordsecurity.cpp | 221 +++++++++++++++++++++ src/zenhttp/zenhttp.cpp | 2 + 3 files changed, 275 insertions(+) create mode 100644 
src/zenhttp/include/zenhttp/security/passwordsecurity.h
create mode 100644 src/zenhttp/security/passwordsecurity.cpp
(limited to 'src')
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
new file mode 100644
index 000000000..026c2865b
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
@@ -0,0 +1,52 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+// Settings for password protection of HTTP requests. The quoted name after
+// each field is the key read by ReadPasswordSecurityConfiguration().
+struct PasswordSecurityConfiguration
+{
+    std::string Password; // "password"
+    bool ProtectMachineLocalRequests = false; // "protect-machine-local-requests"
+    std::vector UnprotectedUris; // "unprotected-urls"
+};
+
+// Decides whether a request may proceed, based on a PasswordSecurityConfiguration.
+class PasswordSecurity
+{
+public:
+    // Copies Config and indexes the unprotected URIs by hash. Throws
+    // std::runtime_error when two entries produce the same hash.
+    PasswordSecurity(const PasswordSecurityConfiguration& Config);
+
+    [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; }
+    [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; }
+    // True only when Uri is exactly equal to one of the configured unprotected URIs.
+    [[nodiscard]] bool IsUnprotectedUri(std::string_view Uri) const;
+
+    // Returns true when any of the following holds, checked in this order:
+    // the URI is unprotected; the request is machine-local and local requests
+    // are not protected; no password is configured; the given password equals
+    // the configured one.
+    bool IsAllowed(std::string_view Password, std::string_view Uri, bool IsMachineLocalRequest);
+
+private:
+    const PasswordSecurityConfiguration m_Config;
+    tsl::robin_map m_UnprotectedUrlHashes;  // URI hash -> index into m_Config.UnprotectedUris
+};
+
+/**
+ * Expected format (Json)
+ * {
+ *   "password": "1234",
+ *   "protect-machine-local-requests": false,
+ *   "unprotected-urls": [
+ *     "/health",
+ *     "/health/info",
+ *     "/health/version"
+ *   ]
+ * }
+ */
+PasswordSecurityConfiguration ReadPasswordSecurityConfiguration(CbObjectView ConfigObject);
+
+void passwordsecurity_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
new file mode 100644
index 000000000..37be9a018
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -0,0 +1,221 @@
+// Copyright Epic
Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurity.h"
+#include
+#include
+#include
+
+#if ZEN_WITH_TESTS
+# include
+# include
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+using namespace std::literals;
+
+// Copies the configuration and builds a lookup from HashStringDjb2(uri) to the
+// uri's index in m_Config.UnprotectedUris.
+//
+// Throws std::runtime_error when an insert finds the hash already present.
+// NOTE(review): a URI listed twice hashes identically and is therefore also
+// reported as a "collision" - confirm this is the desired behaviour.
+PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) : m_Config(Config)
+{
+    m_UnprotectedUrlHashes.reserve(m_Config.UnprotectedUris.size());
+    for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++)
+    {
+        const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index];
+        if (auto Result = m_UnprotectedUrlHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
+        {
+            // The message uses 1-based positions for both entries.
+            throw std::runtime_error(fmt::format(
+                "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')",
+                Index + 1,
+                UnprotectedUri,
+                Result.first->second + 1,
+                m_Config.UnprotectedUris[Result.first->second]));
+        }
+    }
+}
+
+// Returns true when Uri exactly equals one of the configured unprotected URIs.
+// The hash lookup selects a candidate; the string comparison afterwards guards
+// against a different URI that happens to share the hash.
+bool
+PasswordSecurity::IsUnprotectedUri(std::string_view Uri) const
+{
+    if (!m_Config.UnprotectedUris.empty())
+    {
+        uint32_t UriHash = HashStringDjb2(Uri);
+        if (auto It = m_UnprotectedUrlHashes.find(UriHash); It != m_UnprotectedUrlHashes.end())
+        {
+            if (m_Config.UnprotectedUris[It->second] == Uri)
+            {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+// Builds a configuration from the fields "password",
+// "protect-machine-local-requests" and "unprotected-urls" of ConfigObject.
+// The first test below shows that an empty object yields an empty password,
+// false and an empty list.
+PasswordSecurityConfiguration
+ReadPasswordSecurityConfiguration(CbObjectView ConfigObject)
+{
+    return PasswordSecurityConfiguration{
+        .Password = std::string(ConfigObject["password"sv].AsString()),
+        .ProtectMachineLocalRequests = ConfigObject["protect-machine-local-requests"sv].AsBool(),
+        .UnprotectedUris = compactbinary_helpers::ReadArray("unprotected-urls"sv, ConfigObject)};
+}
+
+// Returns true when the request may proceed. Checks, in order: unprotected
+// URI; machine-local request while local requests are not protected; no
+// password configured; supplied password equal to the configured one.
+// NOTE(review): the password check is a plain string_view comparison, which is
+// presumably not constant-time - consider whether that matters here.
+bool
+PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, bool IsMachineLocalRequest)
+{
+    if (IsUnprotectedUri(Uri))
+    {
+        return true;
+    }
+    if (!ProtectMachineLocalRequests() && IsMachineLocalRequest)
+    {
+        return true;
+    }
+    if (Password().empty())
+    {
+        return true;
+    }
+    if (Password() == InPassword)
+    {
+        return true;
+    }
+    return false;
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("passwordsecurity.readconfig")
+{
+    // Helper: parse JSON text into a compact binary object, requiring success.
+    auto ReadConfigJson = [](std::string_view Json) {
+        std::string JsonError;
+        CbObject Config = LoadCompactBinaryFromJson(Json, JsonError).AsObject();
+        REQUIRE(JsonError.empty());
+        return Config;
+    };
+
+    // Empty object: every field takes its default.
+    {
+        PasswordSecurityConfiguration EmptyConfig = ReadPasswordSecurityConfiguration(CbObject());
+        CHECK(EmptyConfig.Password.empty());
+        CHECK(!EmptyConfig.ProtectMachineLocalRequests);
+        CHECK(EmptyConfig.UnprotectedUris.empty());
+    }
+
+    // Password only.
+    {
+        const std::string_view SimpleConfigJson =
+            "{\n"
+            " \"password\": \"1234\"\n"
+            "}";
+        PasswordSecurityConfiguration SimpleConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(SimpleConfigJson));
+        CHECK(SimpleConfig.Password == "1234");
+        CHECK(!SimpleConfig.ProtectMachineLocalRequests);
+        CHECK(SimpleConfig.UnprotectedUris.empty());
+    }
+
+    // All three fields.
+    {
+        const std::string_view ComplexConfigJson =
+            "{\n"
+            " \"password\": \"1234\",\n"
+            " \"protect-machine-local-requests\": true,\n"
+            " \"unprotected-urls\": [\n"
+            " \"/health\",\n"
+            " \"/health/info\",\n"
+            " \"/health/version\"\n"
+            " ]\n"
+            "}";
+        PasswordSecurityConfiguration ComplexConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(ComplexConfigJson));
+        CHECK(ComplexConfig.Password == "1234");
+        CHECK(ComplexConfig.ProtectMachineLocalRequests);
+        CHECK(ComplexConfig.UnprotectedUris == std::vector({"/health", "/health/info", "/health/version"}));
+    }
+}
+
+// Default configuration (no password): every request is allowed.
+TEST_CASE("passwordsecurity.allowanything")
+{
+    PasswordSecurity Anything({});
+    CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
+    CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
+    CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
+    CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
+}
+
+TEST_CASE("passwordsecurity.allowalllocal") +{ + PasswordSecurity AllLocal({.Password = "123456"}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); +} + +TEST_CASE("passwordsecurity.allowonlypassword") +{ + PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomeexternaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector({"/free/access", "/ok"})}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + 
CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomelocaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector({"/free/access", "/ok"})}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); + 
CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.conflictingunprotecteduris") +{ + try + { + PasswordSecurity AllLocal({.Password = "123456", + .ProtectMachineLocalRequests = true, + .UnprotectedUris = std::vector({"/free/access", "/free/access"})}); + CHECK(false); + } + catch (const std::runtime_error& Ex) + { + CHECK_EQ(Ex.what(), + std::string("password security unprotected uris does not generate unique hashes. Uri #2 ('/free/access') collides with " + "uri #1 ('/free/access')")); + } +} +void +passwordsecurity_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index a2679f92e..0b5408453 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -7,6 +7,7 @@ # include # include # include +# include namespace zen { @@ -16,6 +17,7 @@ zenhttp_forcelinktests() http_forcelink(); httpclient_forcelink(); forcelink_packageformat(); + passwordsecurity_forcelink(); } } // namespace zen -- cgit v1.2.3 From 0697a2facd63908b45495fa0a1e94c982e34f052 Mon Sep 17 00:00:00 2001 From: zousar Date: Sat, 14 Feb 2026 23:51:54 -0700 Subject: Enhance dependencies to include soft and hard deps --- src/zenserver/frontend/html/pages/entry.js | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 08589b090..212686e42 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -155,7 +155,7 @@ export class Page extends ZenPage if (Object.keys(tree).length != 0) { - const sub_section = section.add_section("deps"); + const sub_section = 
section.add_section("dependencies"); this._build_deps(sub_section, tree); } } @@ -271,16 +271,18 @@ export class Page extends ZenPage for (const field of pkgst_entry) { const field_name = field.get_name(); - if (!field_name.endsWith("importedpackageids")) - continue; - - var dep_name = field_name.slice(0, -18); - if (dep_name.length == 0) - dep_name = "imported"; - - var out = tree[dep_name] = []; - for (var item of field.as_array()) - out.push(item.as_value(BigInt)); + if (field_name == "importedpackageids") + { + var out = tree["hard"] = []; + for (var item of field.as_array()) + out.push(item.as_value(BigInt)); + } + else if (field_name == "softpackagereferences") + { + var out = tree["soft"] = []; + for (var item of field.as_array()) + out.push(item.as_value(BigInt)); + } } return tree; -- cgit v1.2.3 From c40e2c7625cf6aab25862c1c18caeb8577884656 Mon Sep 17 00:00:00 2001 From: zousar Date: Sun, 15 Feb 2026 11:55:17 -0700 Subject: Restore handling for hard/soft name prefixes --- src/zenserver/frontend/html/pages/entry.js | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 212686e42..76afd3e1f 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -271,15 +271,27 @@ export class Page extends ZenPage for (const field of pkgst_entry) { const field_name = field.get_name(); - if (field_name == "importedpackageids") + if (field_name.endsWith("importedpackageids")) { - var out = tree["hard"] = []; + var dep_name = field_name.slice(0, -18); + if (dep_name.length == 0) + dep_name = "hard"; + else + dep_name = "hard." 
+ dep_name; + + var out = tree[dep_name] = []; for (var item of field.as_array()) out.push(item.as_value(BigInt)); } - else if (field_name == "softpackagereferences") + else if (field_name.endsWith("softpackagereferences")) { - var out = tree["soft"] = []; + var dep_name = field_name.slice(0, -21); + if (dep_name.length == 0) + dep_name = "soft"; + else + dep_name = "soft." + dep_name; + + var out = tree[dep_name] = []; for (var item of field.as_array()) out.push(item.as_value(BigInt)); } -- cgit v1.2.3 From 81a6d5e29453db761d058b6418044c8cf04a167e Mon Sep 17 00:00:00 2001 From: zousar Date: Sun, 15 Feb 2026 23:44:17 -0700 Subject: Add support for listing files on oplog entries --- src/zenserver/frontend/html/pages/entry.js | 119 ++++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 76afd3e1f..26ea78142 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -76,6 +76,21 @@ export class Page extends ZenPage return null; } + _is_null_io_hash_string(io_hash) + { + if (!io_hash) + return true; + + for (let char of io_hash) + { + if (char != '0') + { + return false; + } + } + return true; + } + async _build_meta(section, entry) { var tree = {} @@ -142,30 +157,34 @@ export class Page extends ZenPage const name = entry.find("key").as_value(); var section = this.add_section(name); + var has_package_data = false; // tree { var tree = entry.find("$tree"); if (tree == undefined) tree = this._convert_legacy_to_tree(entry); - if (tree == undefined) - return this._display_unsupported(section, entry); - - delete tree["$id"]; - - if (Object.keys(tree).length != 0) + if (tree != undefined) { - const sub_section = section.add_section("dependencies"); - this._build_deps(sub_section, tree); + delete tree["$id"]; + + if (Object.keys(tree).length != 0) + { + const sub_section = 
section.add_section("dependencies"); + this._build_deps(sub_section, tree); + } + has_package_data = true; } } // meta + if (has_package_data) { this._build_meta(section, entry); } // data + if (has_package_data) { const sub_section = section.add_section("data"); const table = sub_section.add_widget( @@ -181,7 +200,7 @@ export class Page extends ZenPage for (const item of pkg_data.as_array()) { - var io_hash, size, raw_size, file_name; + var io_hash = undefined, size = undefined, raw_size = undefined, file_name = undefined; for (const field of item.as_object()) { if (field.is_named("data")) io_hash = field.as_value(); @@ -219,12 +238,94 @@ export class Page extends ZenPage } } + // files + var has_file_data = false; + { + const sub_section = section.add_section("files"); + const table = sub_section.add_widget( + Table, + ["name", "actions"], Table.Flag_PackRight + ); + table.id("filetable"); + for (const field_name of ["files"]) + { + var file_data = entry.find(field_name); + if (file_data == undefined) + continue; + + has_file_data = true; + + for (const item of file_data.as_array()) + { + var io_hash = undefined, cid = undefined, server_path = undefined, client_path = undefined; + for (const field of item.as_object()) + { + if (field.is_named("data")) io_hash = field.as_value(); + else if (field.is_named("id")) cid = field.as_value(); + else if (field.is_named("serverpath")) server_path = field.as_value(); + else if (field.is_named("clientpath")) client_path = field.as_value(); + } + + if (io_hash instanceof Uint8Array) + { + var ret = ""; + for (var x of io_hash) + ret += x.toString(16).padStart(2, "0"); + io_hash = ret; + } + + if (cid instanceof Uint8Array) + { + var ret = ""; + for (var x of cid) + ret += x.toString(16).padStart(2, "0"); + cid = ret; + } + + const row = table.add_row(server_path); + + var base_name = server_path.split("/").pop().split("\\").pop(); + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + if 
(this._is_null_io_hash_string(io_hash)) + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, cid].join("/") + ); + link.first_child().attr("download", `${cid}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-id").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, cid); + } + else + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, io_hash].join("/") + ); + link.first_child().attr("download", `${io_hash}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-hash").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, io_hash); + } + + } + } + } + // props + if (has_package_data) { const object = entry.to_js_object(); var sub_section = section.add_section("props"); sub_section.add_widget(PropTable).add_object(object); } + + if (!has_package_data && !has_file_data) + return this._display_unsupported(section, entry); } _display_unsupported(section, entry) -- cgit v1.2.3 From df806dcb92f0b5c9622586460fc86e698ca03ab6 Mon Sep 17 00:00:00 2001 From: zousar Date: Sun, 15 Feb 2026 23:44:54 -0700 Subject: Change breadcrumbs for oplogs to be more descriptive --- src/zenserver/frontend/html/pages/oplog.js | 2 +- src/zenserver/frontend/html/pages/page.js | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/oplog.js b/src/zenserver/frontend/html/pages/oplog.js index 879fc4c97..a286f8651 100644 --- a/src/zenserver/frontend/html/pages/oplog.js +++ b/src/zenserver/frontend/html/pages/oplog.js @@ -32,7 +32,7 @@ export class Page extends ZenPage this.set_title("oplog - " + oplog); - var section = this.add_section(project + " - " + oplog); + var section = this.add_section(oplog); oplog_info = await oplog_info; this._index_max = oplog_info["opcount"]; diff --git a/src/zenserver/frontend/html/pages/page.js 
b/src/zenserver/frontend/html/pages/page.js index 9a9541904..2f9643008 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -97,7 +97,7 @@ export class ZenPage extends PageBase generate_crumbs() { - const auto_name = this.get_param("page") || "start"; + var auto_name = this.get_param("page") || "start"; if (auto_name == "start") return; @@ -114,15 +114,21 @@ export class ZenPage extends PageBase var project = this.get_param("project"); if (project != undefined) { + auto_name = project; var oplog = this.get_param("oplog"); if (oplog != undefined) { - new_crumb("project", `?page=project&project=${project}`); - if (this.get_param("opkey")) - new_crumb("oplog", `?page=oplog&project=${project}&oplog=${oplog}`); + auto_name = oplog; + new_crumb(project, `?page=project&project=${project}`); + var opkey = this.get_param("opkey") + if (opkey != undefined) + { + auto_name = opkey.split("/").pop().split("\\").pop();; + new_crumb(oplog, `?page=oplog&project=${project}&oplog=${oplog}`); + } } } - new_crumb(auto_name.toLowerCase()); + new_crumb(auto_name); } } -- cgit v1.2.3 From 74a5e2fb8dec43682e81a98c9677aef849ca7cc1 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 16 Feb 2026 15:37:13 +0100 Subject: added ResetConsoleLog (#758) also made sure log initialization calls it to ensure the console output format is retained even if the console logger was set up before logging is initialized --- src/zencore/include/zencore/logging.h | 1 + src/zencore/logging.cpp | 8 ++++++++ src/zenutil/logging.cpp | 5 +++++ 3 files changed, 14 insertions(+) (limited to 'src') diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index afbbbd3ee..74a44d028 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -31,6 +31,7 @@ void FlushLogging(); LoggerRef Default(); void SetDefault(std::string_view NewDefaultLoggerId); LoggerRef ConsoleLog(); +void 
ResetConsoleLog(); void SuppressConsoleLog(); LoggerRef ErrorLog(); void SetErrorLog(std::string_view LoggerId); diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index a6697c443..77e05a909 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -404,6 +404,14 @@ ConsoleLog() return *ConLogger; } +void +ResetConsoleLog() +{ + LoggerRef ConLog = ConsoleLog(); + + ConLog.SpdLogger->set_pattern("%v"); +} + void InitializeLogging() { diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging.cpp index 806b96d52..54ac30c5d 100644 --- a/src/zenutil/logging.cpp +++ b/src/zenutil/logging.cpp @@ -233,6 +233,11 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) LogOptions.LogId, std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart()))); // default to duration prefix + // If the console logger was initialized before, the above will change the output format + // so we need to reset it + + logging::ResetConsoleLog(); + if (g_FileSink) { if (LogOptions.AbsLogFile.extension() == ".json") -- cgit v1.2.3 From ccfcb14ef1b837ed6f752ae4f27e0ef88a5b18da Mon Sep 17 00:00:00 2001 From: zousar Date: Mon, 16 Feb 2026 16:39:44 -0700 Subject: Added custom page for cook.artifacts --- src/zenserver/frontend/html/pages/cookartifacts.js | 385 +++++++++++++++++++++ src/zenserver/frontend/html/pages/entry.js | 14 +- src/zenserver/frontend/html/pages/page.js | 15 +- src/zenserver/frontend/html/zen.css | 18 + 4 files changed, 428 insertions(+), 4 deletions(-) create mode 100644 src/zenserver/frontend/html/pages/cookartifacts.js (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/cookartifacts.js b/src/zenserver/frontend/html/pages/cookartifacts.js new file mode 100644 index 000000000..6c36c7f32 --- /dev/null +++ b/src/zenserver/frontend/html/pages/cookartifacts.js @@ -0,0 +1,385 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +"use strict"; + +import { ZenPage } from "./page.js" +import { Fetcher } from "../util/fetcher.js" +import { Table, Toolbar, PropTable } from "../util/widgets.js" + +//////////////////////////////////////////////////////////////////////////////// +export class Page extends ZenPage +{ + main() + { + this.set_title("cook artifacts"); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); + const artifact_hash = this.get_param("hash"); + + // Fetch the artifact content as JSON + this._artifact = new Fetcher() + .resource("prj", project, "oplog", oplog, artifact_hash + ".json") + .json(); + + // Optionally fetch entry info for display context + if (opkey) + { + this._entry = new Fetcher() + .resource("prj", project, "oplog", oplog, "entries") + .param("opkey", opkey) + .cbo(); + } + + this._build_page(); + } + + // Map CookDependency enum values to display names + _get_dependency_type_name(type_value) + { + const type_names = { + 0: "None", + 1: "File", + 2: "Function", + 3: "TransitiveBuild", + 4: "Package", + 5: "ConsoleVariable", + 6: "Config", + 7: "SettingsObject", + 8: "NativeClass", + 9: "AssetRegistryQuery", + 10: "RedirectionTarget" + }; + return type_names[type_value] || `Unknown (${type_value})`; + } + + // Check if Data content should be expandable + _should_make_expandable(data_string) + { + if (!data_string || data_string.length < 40) + return false; + + // Check if it's JSON array or object + if (!data_string.startsWith('[') && !data_string.startsWith('{')) + return false; + + // Check if formatting would add newlines + try { + const parsed = JSON.parse(data_string); + const formatted = JSON.stringify(parsed, null, 2); + return formatted.includes('\n'); + } catch (e) { + return false; + } + } + + // Get first line of content for collapsed state + _get_first_line(data_string) + { + if (!data_string) + return ""; + + const newline_index = data_string.indexOf('\n'); + if 
(newline_index === -1) + { + // No newline, truncate if too long + return data_string.length > 80 ? data_string.substring(0, 77) + "..." : data_string; + } + return data_string.substring(0, newline_index) + "..."; + } + + // Format JSON with indentation + _format_json(data_string) + { + try { + const parsed = JSON.parse(data_string); + return JSON.stringify(parsed, null, 2); + } catch (e) { + return data_string; + } + } + + // Toggle expand/collapse state + _toggle_data_cell(cell) + { + const is_expanded = cell.attr("expanded") !== null; + const full_data = cell.attr("data-full"); + + // Find the text wrapper span + const text_wrapper = cell.first_child().next_sibling(); + + if (is_expanded) + { + // Collapse: show first line only + const first_line = this._get_first_line(full_data); + text_wrapper.text(first_line); + cell.attr("expanded", null); + } + else + { + // Expand: show formatted JSON + const formatted = this._format_json(full_data); + text_wrapper.text(formatted); + cell.attr("expanded", ""); + } + } + + // Format dependency data based on its structure + _format_dependency(dep_array) + { + const type = dep_array[0]; + const formatted = {}; + + // Common patterns based on the example data: + // Type 2 (Function): [type, name, array, hash] + // Type 4 (Package): [type, path, hash] + // Type 5 (ConsoleVariable): [type, bool, array, hash] + // Type 8 (NativeClass): [type, path, hash] + // Type 9 (AssetRegistryQuery): [type, bool, object, hash] + // Type 10 (RedirectionTarget): [type, path, hash] + + if (dep_array.length > 1) + { + // Most types have a name/path as second element + if (typeof dep_array[1] === "string") + { + formatted.Name = dep_array[1]; + } + else if (typeof dep_array[1] === "boolean") + { + formatted.Value = dep_array[1].toString(); + } + } + + if (dep_array.length > 2) + { + // Third element varies + if (Array.isArray(dep_array[2])) + { + formatted.Data = JSON.stringify(dep_array[2]); + } + else if (typeof dep_array[2] === "object") + { + 
formatted.Data = JSON.stringify(dep_array[2]); + } + else if (typeof dep_array[2] === "string") + { + formatted.Hash = dep_array[2]; + } + } + + if (dep_array.length > 3) + { + // Fourth element is usually the hash + if (typeof dep_array[3] === "string") + { + formatted.Hash = dep_array[3]; + } + } + + return formatted; + } + + async _build_page() + { + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); + const artifact_hash = this.get_param("hash"); + + // Build page title + let title = "Cook Artifacts"; + if (this._entry) + { + try + { + const entry = await this._entry; + const entry_obj = entry.as_object().find("entry").as_object(); + const key = entry_obj.find("key").as_value(); + title = `Cook Artifacts`; + } + catch (e) + { + console.error("Failed to fetch entry:", e); + } + } + + const section = this.add_section(title); + + // Fetch and parse artifact + let artifact; + try + { + artifact = await this._artifact; + } + catch (e) + { + section.text(`Failed to load artifact: ${e.message}`); + return; + } + + // Display artifact info + const info_section = section.add_section("Artifact Info"); + const info_table = info_section.add_widget(Table, ["Property", "Value"], Table.Flag_PackRight); + + if (artifact.Version !== undefined) + info_table.add_row("Version", artifact.Version.toString()); + if (artifact.HasSaveResults !== undefined) + info_table.add_row("HasSaveResults", artifact.HasSaveResults.toString()); + if (artifact.PackageSavedHash !== undefined) + info_table.add_row("PackageSavedHash", artifact.PackageSavedHash); + + // Process SaveBuildDependencies + if (artifact.SaveBuildDependencies && artifact.SaveBuildDependencies.Dependencies) + { + this._build_dependency_section( + section, + "Save Build Dependencies", + artifact.SaveBuildDependencies.Dependencies, + artifact.SaveBuildDependencies.StoredKey + ); + } + + // Process LoadBuildDependencies + if (artifact.LoadBuildDependencies && 
artifact.LoadBuildDependencies.Dependencies) + { + this._build_dependency_section( + section, + "Load Build Dependencies", + artifact.LoadBuildDependencies.Dependencies, + artifact.LoadBuildDependencies.StoredKey + ); + } + + // Process RuntimeDependencies + if (artifact.RuntimeDependencies && artifact.RuntimeDependencies.length > 0) + { + const runtime_section = section.add_section("Runtime Dependencies"); + const runtime_table = runtime_section.add_widget(Table, ["Path"], Table.Flag_PackRight); + for (const dep of artifact.RuntimeDependencies) + { + const row = runtime_table.add_row(dep); + // Make Path clickable to navigate to entry + row.get_cell(0).text(dep).on_click((opkey) => { + window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; + }, dep); + } + } + } + + _build_dependency_section(parent_section, title, dependencies, stored_key) + { + const section = parent_section.add_section(title); + + // Add stored key info + if (stored_key) + { + const key_toolbar = section.add_widget(Toolbar); + key_toolbar.left().add(`Key: ${stored_key}`); + } + + // Group dependencies by type + const dependencies_by_type = {}; + + for (const dep_array of dependencies) + { + if (!Array.isArray(dep_array) || dep_array.length === 0) + continue; + + const type = dep_array[0]; + if (!dependencies_by_type[type]) + dependencies_by_type[type] = []; + + dependencies_by_type[type].push(this._format_dependency(dep_array)); + } + + // Sort types numerically + const sorted_types = Object.keys(dependencies_by_type).map(Number).sort((a, b) => a - b); + + for (const type_value of sorted_types) + { + const type_name = this._get_dependency_type_name(type_value); + const deps = dependencies_by_type[type_value]; + + const type_section = section.add_section(type_name); + + // Determine columns based on available fields + const all_fields = new Set(); + for (const dep of deps) + { + for (const field in dep) + all_fields.add(field); + } + let columns = 
Array.from(all_fields); + + // Remove Hash column for RedirectionTarget as it's not useful + if (type_value === 10) + { + columns = columns.filter(col => col !== "Hash"); + } + + if (columns.length === 0) + { + type_section.text("No data fields"); + continue; + } + + // Create table with dynamic columns + const table = type_section.add_widget(Table, columns, Table.Flag_PackRight); + + // Check if this type should have clickable Name links + const should_link = (type_value === 3 || type_value === 4 || type_value === 10); + const name_col_index = columns.indexOf("Name"); + + for (const dep of deps) + { + const row_values = columns.map(col => dep[col] || ""); + const row = table.add_row(...row_values); + + // Make Name field clickable for Package, TransitiveBuild, and RedirectionTarget + if (should_link && name_col_index >= 0 && dep.Name) + { + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + row.get_cell(name_col_index).text(dep.Name).on_click((opkey) => { + window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; + }, dep.Name); + } + + // Make Data field expandable/collapsible if needed + const data_col_index = columns.indexOf("Data"); + if (data_col_index >= 0 && dep.Data) + { + const data_cell = row.get_cell(data_col_index); + + if (this._should_make_expandable(dep.Data)) + { + // Store full data in attribute + data_cell.attr("data-full", dep.Data); + + // Clear the cell and rebuild with icon + text + data_cell.inner().innerHTML = ""; + + // Create expand/collapse icon + const icon = data_cell.tag("span").classify("zen_expand_icon").text("+"); + icon.on_click(() => { + this._toggle_data_cell(data_cell); + // Update icon text + const is_expanded = data_cell.attr("expanded") !== null; + icon.text(is_expanded ? 
"-" : "+"); + }); + + // Add text content wrapper + const text_wrapper = data_cell.tag("span").classify("zen_data_text"); + const first_line = this._get_first_line(dep.Data); + text_wrapper.text(first_line); + + // Store reference to text wrapper for updates + data_cell.attr("data-text-wrapper", "true"); + } + } + } + } + } +} diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 26ea78142..dca3a5c25 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -138,11 +138,23 @@ export class Page extends ZenPage const project = this.get_param("project"); const oplog = this.get_param("oplog"); + const opkey = this.get_param("opkey"); const link = row.get_cell(0).link( - "/" + ["prj", project, "oplog", oplog, value+".json"].join("/") + (key === "cook.artifacts") ? + `?page=cookartifacts&project=${project}&oplog=${oplog}&opkey=${opkey}&hash=${value}` + : "/" + ["prj", project, "oplog", oplog, value+".json"].join("/") ); const action_tb = new Toolbar(row.get_cell(-1), true); + + // Add "view-raw" button for cook.artifacts + if (key === "cook.artifacts") + { + action_tb.left().add("view-raw").on_click(() => { + window.location = "/" + ["prj", project, "oplog", oplog, value+".json"].join("/"); + }); + } + action_tb.left().add("copy-hash").on_click(async (v) => { await navigator.clipboard.writeText(v); }, value); diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index 2f9643008..3ec0248cb 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -118,13 +118,22 @@ export class ZenPage extends PageBase var oplog = this.get_param("oplog"); if (oplog != undefined) { + new_crumb(auto_name, `?page=project&project=${project}`); auto_name = oplog; - new_crumb(project, `?page=project&project=${project}`); var opkey = this.get_param("opkey") if (opkey != undefined) { - auto_name = 
opkey.split("/").pop().split("\\").pop();; - new_crumb(oplog, `?page=oplog&project=${project}&oplog=${oplog}`); + new_crumb(auto_name, `?page=oplog&project=${project}&oplog=${oplog}`); + auto_name = opkey.split("/").pop().split("\\").pop(); + + // Check if we're viewing cook artifacts + var page = this.get_param("page"); + var hash = this.get_param("hash"); + if (hash != undefined && page == "cookartifacts") + { + new_crumb(auto_name, `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey}`); + auto_name = "cook artifacts"; + } } } } diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index cc53c0519..34c265610 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -172,6 +172,24 @@ a { } } +/* expandable cell ---------------------------------------------------------- */ + +.zen_expand_icon { + cursor: pointer; + margin-right: 0.5em; + color: var(--theme_g1); + font-weight: bold; + user-select: none; +} + +.zen_expand_icon:hover { + color: var(--theme_ln); +} + +.zen_data_text { + user-select: text; +} + /* toolbar ------------------------------------------------------------------ */ .zen_toolbar { -- cgit v1.2.3 From 2159b2ce105935ce4d52a726094f9bbb91537d0c Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 17 Feb 2026 13:56:33 +0100 Subject: misc fixes brought over from sb/proto (#759) * `RwLock::WithSharedLock` and `RwLock::WithExclusiveLock` can now return a value (which is returned by the passed function) * Comma-separated logger specification now correctly deals with commas * `GetSystemMetrics` properly accounts for cores * cpr response formatter passes arguments in the right order * `HttpServerRequest::SetLogRequest` can be used to selectively log HTTP requests --- src/zencore/include/zencore/thread.h | 8 ++++---- src/zencore/logging.cpp | 2 +- src/zencore/system.cpp | 12 ++++++------ src/zenhttp/include/zenhttp/cprutils.h | 4 ++-- src/zenhttp/include/zenhttp/httpserver.h | 
17 +++++++++++++++-- src/zenhttp/servers/httpasio.cpp | 7 ++++++- src/zenhttp/servers/httpsys.cpp | 30 +++++++++++++++--------------- 7 files changed, 49 insertions(+), 31 deletions(-) (limited to 'src') diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h index de8f9399c..a1c68b0b2 100644 --- a/src/zencore/include/zencore/thread.h +++ b/src/zencore/include/zencore/thread.h @@ -61,10 +61,10 @@ public: RwLock* m_Lock; }; - inline void WithSharedLock(auto&& Fun) + inline auto WithSharedLock(auto&& Fun) { SharedLockScope $(*this); - Fun(); + return Fun(); } struct ExclusiveLockScope @@ -85,10 +85,10 @@ public: RwLock* m_Lock; }; - inline void WithExclusiveLock(auto&& Fun) + inline auto WithExclusiveLock(auto&& Fun) { ExclusiveLockScope $(*this); - Fun(); + return Fun(); } private: diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index 77e05a909..e79c4b41c 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -251,7 +251,7 @@ RefreshLogLevels(level::LogLevel* DefaultLevel) if (auto CommaPos = Spec.find_first_of(','); CommaPos != std::string_view::npos) { - LoggerName = Spec.substr(CommaPos + 1); + LoggerName = Spec.substr(0, CommaPos); Spec.remove_prefix(CommaPos + 1); } else diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index b9ac3bdee..e92691781 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -66,15 +66,15 @@ GetSystemMetrics() // Determine physical core count DWORD BufferSize = 0; - BOOL Result = GetLogicalProcessorInformation(nullptr, &BufferSize); + BOOL Result = GetLogicalProcessorInformationEx(RelationAll, nullptr, &BufferSize); if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER) { ThrowSystemError(Error, "Failed to get buffer size for logical processor information"); } - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)Memory::Alloc(BufferSize); + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX Buffer = 
(PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)Memory::Alloc(BufferSize); - Result = GetLogicalProcessorInformation(Buffer, &BufferSize); + Result = GetLogicalProcessorInformationEx(RelationAll, Buffer, &BufferSize); if (!Result) { Memory::Free(Buffer); @@ -84,9 +84,9 @@ GetSystemMetrics() DWORD ProcessorPkgCount = 0; DWORD ProcessorCoreCount = 0; DWORD ByteOffset = 0; - while (ByteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION) <= BufferSize) + while (ByteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX) <= BufferSize) { - const SYSTEM_LOGICAL_PROCESSOR_INFORMATION& Slpi = Buffer[ByteOffset / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION)]; + const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX& Slpi = Buffer[ByteOffset / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)]; if (Slpi.Relationship == RelationProcessorCore) { ProcessorCoreCount++; @@ -95,7 +95,7 @@ GetSystemMetrics() { ProcessorPkgCount++; } - ByteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION); + ByteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX); } Metrics.CoreCount = ProcessorCoreCount; diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index a988346e0..c252a5d99 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -66,10 +66,10 @@ struct fmt::formatter Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Json); } else @@ -82,10 +82,10 @@ struct fmt::formatter Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Body.GetText()); } } diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 60f6bc9f2..cbac06cb6 100644 --- 
a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -39,7 +39,7 @@ public: // Synchronous operations [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } + [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } [[nodiscard]] inline HttpService& Service() const { return m_Service; } @@ -81,6 +81,18 @@ public: inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + inline void SetLogRequest(bool ShouldLog) + { + if (ShouldLog) + { + m_Flags |= kLogRequest; + } + else + { + m_Flags &= ~kLogRequest; + } + } + inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); } /** Read POST/PUT payload for request body, which is always available without delay */ @@ -119,6 +131,7 @@ protected: kSuppressBody = 1 << 1, kHaveRequestId = 1 << 2, kHaveSessionId = 1 << 3, + kLogRequest = 1 << 4, }; mutable uint32_t m_Flags = 0; @@ -254,7 +267,7 @@ public: inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} ~HttpRouterRequest() = default; HttpRouterRequest(const HttpRouterRequest&) = delete; diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 230aac6a8..1f42b05d2 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -147,7 +147,7 @@ inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // 
Logger.set_level(spdlog::level::trace); + // Logger.SetLogLevel(logging::level::Trace); return Logger; } @@ -1264,6 +1264,11 @@ HttpServerConnection::HandleRequest() if (std::unique_ptr Response = std::move(Request.m_Response)) { + if (Request.ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength())); + } + // Transmit the response if (m_RequestData.RequestVerb() == HttpVerb::kHead) diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4df4cd079..5fed94f1c 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -702,21 +702,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) HTTP_CACHE_POLICY CachePolicy; - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.Policy = HttpCachePolicyNocache; CachePolicy.SecondsToLive = 0; // Initial response API call - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags + &HttpResponse, // HttpResponse + &CachePolicy, // CachePolicy + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); m_IsInitialResponse = false; } @@ -724,9 +725,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { // Subsequent response API calls - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // 
BytesSent @@ -1351,7 +1352,6 @@ HttpSysServer::OnRun(bool IsInteractive) bool ShutdownRequested = false; do { - // int WaitTimeout = -1; int WaitTimeout = 100; if (IsInteractive) -- cgit v1.2.3 From 5e1e23e209eec75a396c18f8eee3d93a9e196bfc Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Tue, 17 Feb 2026 14:00:53 +0100 Subject: add http server root password protection (#757) - Feature: Added `--security-config-path` option to zenserver to configure security settings - Expects a path to a .json file - Default is an empty path resulting in no extra security settings and legacy behavior - Current support is a top level filter of incoming http requests restricted to the `password` type - `password` type will check the `Authorization` header and match it to the selected authorization strategy - Currently the security settings is very basic and configured to a fixed username+password at startup { "http" { "root": { "filter": { "type": "password", "config": { "password": { "username": "", "password": "" }, "protect-machine-local-requests": false, "unprotected-uris": [ "/health/", "/health/info", "/health/version" ] } } } } } --- src/zencore/include/zencore/string.h | 16 ++ src/zenhttp/httpclient.cpp | 91 ++++++++++++ src/zenhttp/httpserver.cpp | 3 +- src/zenhttp/include/zenhttp/httpserver.h | 7 +- .../include/zenhttp/security/passwordsecurity.h | 38 ++--- .../zenhttp/security/passwordsecurityfilter.h | 51 +++++++ src/zenhttp/security/passwordsecurity.cpp | 164 +++++++-------------- src/zenhttp/security/passwordsecurityfilter.cpp | 56 +++++++ src/zenhttp/servers/httpasio.cpp | 14 +- src/zenhttp/servers/httpmulti.cpp | 1 + src/zenhttp/servers/httpnull.cpp | 1 + src/zenhttp/servers/httpparser.cpp | 6 + src/zenhttp/servers/httpparser.h | 3 + src/zenhttp/servers/httpplugin.cpp | 18 ++- src/zenhttp/servers/httpsys.cpp | 21 ++- src/zenserver/config/config.cpp | 20 ++- src/zenserver/config/config.h | 13 +- src/zenserver/zenserver.cpp | 50 ++++++- src/zenserver/zenserver.h | 8 
+- 19 files changed, 415 insertions(+), 166 deletions(-) create mode 100644 src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h create mode 100644 src/zenhttp/security/passwordsecurityfilter.cpp (limited to 'src') diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index cbff6454f..5a12ba5d2 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -796,6 +796,22 @@ HashStringDjb2(const std::string_view& InString) return HashValue; } +constexpr uint32_t +HashStringDjb2(const std::span InStrings) +{ + uint32_t HashValue = 5381; + + for (const std::string_view& String : InStrings) + { + for (int CurChar : String) + { + HashValue = HashValue * 33 + CurChar; + } + } + + return HashValue; +} + constexpr uint32_t HashStringAsLowerDjb2(const std::string_view& InString) { diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 16729ce38..d3b59df2b 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -25,6 +25,7 @@ # include # include # include +# include # include "servers/httpasio.h" # include "servers/httpsys.h" @@ -662,6 +663,96 @@ TEST_CASE("httpclient.requestfilter") } } +TEST_CASE("httpclient.password") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + 
TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + SUBCASE("usernamepassword") + { + CbObjectWriter Writer; + { + Writer.BeginObject("basic"); + { + Writer << "username"sv + << "me"; + Writer << "password"sv + << "456123789"; + } + Writer.EndObject(); + Writer << "protect-machine-local-requests" << true; + } + + PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save()); + + PasswordHttpFilter MyFilter(PasswordFilterOptions); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response ForbiddenResponse = Client.Get("/test/yo"); + CHECK(!ForbiddenResponse.IsSuccess()); + CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden); + + HttpClient::Response WithBasicResponse = + Client.Get("/test/yo", + std::pair("Authorization", + fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password))); + CHECK(WithBasicResponse.IsSuccess()); + AsioServer->SetHttpRequestFilter(nullptr); + } + AsioServer->RequestExit(); + } +} void httpclient_forcelink() { diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index d8367fcb2..f2fe4738f 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -1317,7 +1317,8 @@ TEST_CASE("http.common") TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } - 
virtual bool IsLocalMachineRequest() const override { return false; } + virtual bool IsLocalMachineRequest() const override { return false; } + virtual std::string_view GetAuthorizationHeader() const override { return {}; } virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override { diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index cbac06cb6..350532126 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -101,7 +101,8 @@ public: CbObject ReadPayloadObject(); CbPackage ReadPayloadPackage(); - virtual bool IsLocalMachineRequest() const = 0; + virtual bool IsLocalMachineRequest() const = 0; + virtual std::string_view GetAuthorizationHeader() const = 0; /** Respond with payload @@ -162,8 +163,10 @@ public: virtual void OnRequestComplete() = 0; }; -struct IHttpRequestFilter +class IHttpRequestFilter { +public: + virtual ~IHttpRequestFilter() {} enum class Result { Forbidden, diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h index 026c2865b..6b2b548a6 100644 --- a/src/zenhttp/include/zenhttp/security/passwordsecurity.h +++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h @@ -10,43 +10,29 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -struct PasswordSecurityConfiguration -{ - std::string Password; // "password" - bool ProtectMachineLocalRequests = false; // "protect-machine-local-requests" - std::vector UnprotectedUris; // "unprotected-urls" -}; - class PasswordSecurity { public: - PasswordSecurity(const PasswordSecurityConfiguration& Config); + struct Configuration + { + std::string Password; + bool ProtectMachineLocalRequests = false; + std::vector UnprotectedUris; + }; + + explicit PasswordSecurity(const Configuration& Config); [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; } [[nodiscard]] 
inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; } - [[nodiscard]] bool IsUnprotectedUri(std::string_view Uri) const; + [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const; - bool IsAllowed(std::string_view Password, std::string_view Uri, bool IsMachineLocalRequest); + bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest); private: - const PasswordSecurityConfiguration m_Config; - tsl::robin_map m_UnprotectedUrlHashes; + const Configuration m_Config; + tsl::robin_map m_UnprotectedUriHashes; }; -/** - * Expected format (Json) - * { - * "password\": \"1234\", - * "protect-machine-local-requests\": false, - * "unprotected-urls\": [ - * "/health\", - * "/health/info\", - * "/health/version\" - * ] - * } - */ -PasswordSecurityConfiguration ReadPasswordSecurityConfiguration(CbObjectView ConfigObject); - void passwordsecurity_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h new file mode 100644 index 000000000..c098f05ad --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h @@ -0,0 +1,51 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include +#include + +namespace zen { + +class PasswordHttpFilter : public IHttpRequestFilter +{ +public: + static constexpr std::string_view TypeName = "password"; + + struct Configuration + { + PasswordSecurity::Configuration PasswordConfig; + std::string AuthenticationTypeString; + }; + + /** + * Expected format (Json) + * { + * "password": { # "Authorization: Basic " style + * "username": "", + * "password": "" + * }, + * "protect-machine-local-requests": false, + * "unprotected-uris": [ + * "/health/", + * "/health/info", + * "/health/version" + * ] + * } + */ + static Configuration ReadConfiguration(CbObjectView Config); + + explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config) + : m_PasswordSecurity(Config.PasswordConfig) + , m_AuthenticationTypeString(Config.AuthenticationTypeString) + { + } + + virtual Result FilterRequest(HttpServerRequest& Request) override; + +private: + PasswordSecurity m_PasswordSecurity; + const std::string m_AuthenticationTypeString; +}; + +} // namespace zen diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp index 37be9a018..a8fb9c3f5 100644 --- a/src/zenhttp/security/passwordsecurity.cpp +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -13,13 +13,13 @@ namespace zen { using namespace std::literals; -PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) : m_Config(Config) +PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config) { - m_UnprotectedUrlHashes.reserve(m_Config.UnprotectedUris.size()); + m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size()); for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++) { const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index]; - if (auto Result = m_UnprotectedUrlHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) + if (auto Result = 
m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) { throw std::runtime_error(fmt::format( "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')", @@ -32,35 +32,30 @@ PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) } bool -PasswordSecurity::IsUnprotectedUri(std::string_view Uri) const +PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const { if (!m_Config.UnprotectedUris.empty()) { - uint32_t UriHash = HashStringDjb2(Uri); - if (auto It = m_UnprotectedUrlHashes.find(UriHash); It != m_UnprotectedUrlHashes.end()) + uint32_t UriHash = HashStringDjb2(std::array{BaseUri, RelativeUri}); + if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end()) { - if (m_Config.UnprotectedUris[It->second] == Uri) + const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second]; + if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length()) { - return true; + if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri) + { + return true; + } } } } return false; } -PasswordSecurityConfiguration -ReadPasswordSecurityConfiguration(CbObjectView ConfigObject) -{ - return PasswordSecurityConfiguration{ - .Password = std::string(ConfigObject["password"sv].AsString()), - .ProtectMachineLocalRequests = ConfigObject["protect-machine-local-requests"sv].AsBool(), - .UnprotectedUris = compactbinary_helpers::ReadArray("unprotected-urls"sv, ConfigObject)}; -} - bool -PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, bool IsMachineLocalRequest) +PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest) { - if (IsUnprotectedUri(Uri)) + if (IsUnprotectedUri(BaseUri, RelativeUri)) { return true; } @@ -81,119 +76,74 @@ 
PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, b #if ZEN_WITH_TESTS -TEST_CASE("passwordsecurity.readconfig") -{ - auto ReadConfigJson = [](std::string_view Json) { - std::string JsonError; - CbObject Config = LoadCompactBinaryFromJson(Json, JsonError).AsObject(); - REQUIRE(JsonError.empty()); - return Config; - }; - - { - PasswordSecurityConfiguration EmptyConfig = ReadPasswordSecurityConfiguration(CbObject()); - CHECK(EmptyConfig.Password.empty()); - CHECK(!EmptyConfig.ProtectMachineLocalRequests); - CHECK(EmptyConfig.UnprotectedUris.empty()); - } - - { - const std::string_view SimpleConfigJson = - "{\n" - " \"password\": \"1234\"\n" - "}"; - PasswordSecurityConfiguration SimpleConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(SimpleConfigJson)); - CHECK(SimpleConfig.Password == "1234"); - CHECK(!SimpleConfig.ProtectMachineLocalRequests); - CHECK(SimpleConfig.UnprotectedUris.empty()); - } - - { - const std::string_view ComplexConfigJson = - "{\n" - " \"password\": \"1234\",\n" - " \"protect-machine-local-requests\": true,\n" - " \"unprotected-urls\": [\n" - " \"/health\",\n" - " \"/health/info\",\n" - " \"/health/version\"\n" - " ]\n" - "}"; - PasswordSecurityConfiguration ComplexConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(ComplexConfigJson)); - CHECK(ComplexConfig.Password == "1234"); - CHECK(ComplexConfig.ProtectMachineLocalRequests); - CHECK(ComplexConfig.UnprotectedUris == std::vector({"/health", "/health/info", "/health/version"})); - } -} - TEST_CASE("passwordsecurity.allowanything") { PasswordSecurity Anything({}); - CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + 
CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); } TEST_CASE("passwordsecurity.allowalllocal") { PasswordSecurity AllLocal({.Password = "123456"}); - CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); } TEST_CASE("passwordsecurity.allowonlypassword") { PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true}); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - 
CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.allowsomeexternaluris") { PasswordSecurity AllLocal( {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector({"/free/access", "/ok"})}); - CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); - 
CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.allowsomelocaluris") { PasswordSecurity AllLocal( {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector({"/free/access", "/ok"})}); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - 
CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", 
/*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.conflictingunprotecteduris") diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp new file mode 100644 index 000000000..87d8cc275 --- /dev/null +++ b/src/zenhttp/security/passwordsecurityfilter.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/security/passwordsecurityfilter.h" + +#include +#include +#include + +namespace zen { + +using namespace std::literals; + +PasswordHttpFilter::Configuration +PasswordHttpFilter::ReadConfiguration(CbObjectView Config) +{ + Configuration Result; + if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType) + { + Result.AuthenticationTypeString = "Basic "; + std::string_view Username = PasswordType["username"sv].AsString(); + std::string_view Password = PasswordType["password"sv].AsString(); + std::string UsernamePassword = fmt::format("{}:{}", Username, Password); + Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length()))); + Base64::Encode(reinterpret_cast(UsernamePassword.data()), + uint32_t(UsernamePassword.size()), + const_cast(Result.PasswordConfig.Password.data())); + } + Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool(); + Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray("unprotected-uris"sv, Config); + return Result; +} + +IHttpRequestFilter::Result +PasswordHttpFilter::FilterRequest(HttpServerRequest& Request) +{ + std::string_view Password; + 
std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); + size_t AuthorizationHeaderLength = AuthorizationHeader.length(); + if (AuthorizationHeaderLength > m_AuthenticationTypeString.length()) + { + if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) + { + Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + } + } + + bool IsAllowed = + m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest()); + if (IsAllowed) + { + return Result::Accepted; + } + return Result::Forbidden; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 1f42b05d2..1c0ebef90 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -542,7 +542,8 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -1747,6 +1748,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const return m_IsLocalMachineRequest; } +std::string_view +HttpAsioServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + IoBuffer HttpAsioServerRequest::ReadPayload() { @@ -1964,8 +1971,8 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + + return RequestFilter->FilterRequest(Request); } } // namespace zen::asio_http @@ -2080,6 +2087,7 @@ HttpAsioServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + 
m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 850d7d6b9..310ac9dc0 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -82,6 +82,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index db360c5fb..9bb7ef3bc 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -57,6 +57,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 93094e21b..be5befcd2 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -19,6 +19,7 @@ static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); ////////////////////////////////////////////////////////////////////////// // @@ -154,6 +155,10 @@ HttpRequestParser::ParseCurrentHeader() { m_ContentTypeHeaderIndex = CurrentHeaderIndex; } + else if (HeaderHash == HashAuthorization) + { + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + } else if (HeaderHash == HashSession) { m_SessionId = Oid::TryFromHexString(HeaderValue); @@ -357,6 +362,7 @@ HttpRequestParser::ResetState() m_AcceptHeaderIndex = -1; m_ContentTypeHeaderIndex = -1; m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; 
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 0d2664ec5..ff56ca970 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -46,6 +46,8 @@ struct HttpRequestParser std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); } + std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + private: struct HeaderRange { @@ -83,6 +85,7 @@ private: int8_t m_AcceptHeaderIndex; int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; + int8_t m_AuthorizationHeaderIndex; HttpVerb m_RequestVerb; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 4219dc292..8564826d6 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -147,10 +147,10 @@ public: HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection - virtual bool IsLocalMachineRequest() const /* override*/ { return false; } - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual std::string_view GetAuthorizationHeader() const override; + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -636,6 +636,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest() { } +std::string_view +HttpPluginServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + Oid HttpPluginServerRequest::ParseSessionId() const { @@ -831,6 +837,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) if (c == 27 
|| c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -932,8 +939,7 @@ HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + return RequestFilter->FilterRequest(Request); } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 5fed94f1c..14896c803 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -72,6 +72,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In OutString.Append("unknown"); } +class HttpSysServerRequest; + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -102,7 +104,7 @@ public: inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } - IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request); private: int InitializeServer(int BasePort); @@ -319,7 +321,8 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const; + virtual bool IsLocalMachineRequest() const; + virtual std::string_view GetAuthorizationHeader() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -1364,6 +1367,7 @@ HttpSysServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -1861,6 +1865,14 @@ HttpSysServerRequest::IsLocalMachineRequest() const } } +std::string_view +HttpSysServerRequest::GetAuthorizationHeader() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + const 
HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; + return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength); +} + IoBuffer HttpSysServerRequest::ReadPayload() { @@ -2270,7 +2282,7 @@ HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) } IHttpRequestFilter::Result -HttpSysServer::FilterRequest(HttpServerRequest& Request) +HttpSysServer::FilterRequest(HttpSysServerRequest& Request) { if (!m_HttpRequestFilter.load()) { @@ -2282,8 +2294,7 @@ HttpSysServer::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + return RequestFilter->FilterRequest(Request); } Ref diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index 2b77df642..e36352dae 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -140,6 +140,7 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions LuaOptions.AddOption("server.contentdir"sv, ServerOptions.ContentDir, "content-dir"sv); LuaOptions.AddOption("server.debug"sv, ServerOptions.IsDebug, "debug"sv); LuaOptions.AddOption("server.clean"sv, ServerOptions.IsCleanStart, "clean"sv); + LuaOptions.AddOption("server.security.configpath"sv, ServerOptions.SecurityConfigPath, "security-config-path"sv); ////// network @@ -186,6 +187,7 @@ struct ZenServerCmdLineOptions std::string ContentDir; std::string DataDir; std::string BaseSnapshotDir; + std::string SecurityConfigPath; ZenLoggingCmdLineOptions LoggingOptions; @@ -300,6 +302,13 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi cxxopts::value(ServerOptions.HttpConfig.ForceLoopback)->default_value("false"), ""); + options.add_option("network", + "", + "security-config-path", + "Path to http security configuration file", + 
cxxopts::value(SecurityConfigPath), + ""); + #if ZEN_WITH_HTTPSYS options.add_option("httpsys", "", @@ -380,11 +389,12 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig throw std::runtime_error(fmt::format("'--snapshot-dir' ('{}') must be a directory", ServerOptions.BaseSnapshotDir)); } - ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); - ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); - ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); - ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); - ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + ServerOptions.SystemRootDir = MakeSafeAbsolutePath(SystemRootDir); + ServerOptions.DataDir = MakeSafeAbsolutePath(DataDir); + ServerOptions.ContentDir = MakeSafeAbsolutePath(ContentDir); + ServerOptions.ConfigFile = MakeSafeAbsolutePath(ConfigFile); + ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); + ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig); } diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h index 32c22cb05..55aee07f9 100644 --- a/src/zenserver/config/config.h +++ b/src/zenserver/config/config.h @@ -56,12 +56,13 @@ struct ZenServerConfig bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements bool ShouldCrash = false; // Option for testing crash handling bool IsFirstRun = false; - std::filesystem::path ConfigFile; // Path to Lua config file - std::filesystem::path SystemRootDir; // System root directory (used for machine level config) - std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) - std::filesystem::path DataDir; // Root directory for state (used for testing) - std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start) - std::string ChildId; 
// Id assigned by parent process (used for lifetime management) + std::filesystem::path ConfigFile; // Path to Lua config file + std::filesystem::path SystemRootDir; // System root directory (used for machine level config) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path BaseSnapshotDir; // Path to server state snapshot (will be copied into data dir on start) + std::string ChildId; // Id assigned by parent process (used for lifetime management) + std::filesystem::path SecurityConfigPath; // Path to a Json security configuration file #if ZEN_WITH_TRACE bool HasTraceCommandlineOptions = false; diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index d54357368..7f9bf56a9 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -142,6 +143,8 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: ZEN_INFO("Effective concurrency: {} (hw: {})", GetHardwareConcurrency(), std::thread::hardware_concurrency()); + InitializeSecuritySettings(ServerOptions); + m_StatusService.RegisterHandler("status", *this); m_Http->RegisterService(m_StatusService); @@ -386,10 +389,10 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) { // clang-format off std::list> Settings = { - {"DataDir"sv, ServerConfig.DataDir.string()}, - {"AbsLogFile"sv, ServerConfig.LoggingConfig.AbsLogFile.string()}, - {"SystemRootDir"sv, ServerConfig.SystemRootDir.string()}, - {"ContentDir"sv, ServerConfig.ContentDir.string()}, + {"DataDir"sv, fmt::format("{}", ServerConfig.DataDir)}, + {"AbsLogFile"sv, fmt::format("{}", ServerConfig.LoggingConfig.AbsLogFile)}, + {"SystemRootDir"sv, fmt::format("{}", ServerConfig.SystemRootDir)}, + {"ContentDir"sv, fmt::format("{}", ServerConfig.ContentDir)}, 
{"BasePort"sv, fmt::to_string(ServerConfig.BasePort)}, {"IsDebug"sv, fmt::to_string(ServerConfig.IsDebug)}, {"IsCleanStart"sv, fmt::to_string(ServerConfig.IsCleanStart)}, @@ -406,6 +409,7 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) {"Sentry DSN"sv, ServerConfig.SentryConfig.Dsn.empty() ? "not set" : ServerConfig.SentryConfig.Dsn}, {"Sentry Environment"sv, ServerConfig.SentryConfig.Environment}, {"Statsd Enabled"sv, fmt::to_string(ServerConfig.StatsConfig.Enabled)}, + {"SecurityConfigPath"sv, fmt::format("{}", ServerConfig.SecurityConfigPath)}, }; // clang-format on @@ -432,6 +436,44 @@ ZenServerBase::LogSettingsSummary(const ZenServerConfig& ServerConfig) } } +void +ZenServerBase::InitializeSecuritySettings(const ZenServerConfig& ServerOptions) +{ + ZEN_ASSERT(m_Http); + + if (!ServerOptions.SecurityConfigPath.empty()) + { + IoBuffer SecurityJson = ReadFile(ServerOptions.SecurityConfigPath).Flatten(); + std::string_view Json(reinterpret_cast(SecurityJson.GetData()), SecurityJson.GetSize()); + std::string JsonError; + CbObject SecurityConfig = LoadCompactBinaryFromJson(Json, JsonError).AsObject(); + if (!JsonError.empty()) + { + throw std::runtime_error( + fmt::format("Invalid security configuration file at {}. 
'{}'", ServerOptions.SecurityConfigPath, JsonError)); + } + + CbObjectView HttpRootFilterConfig = SecurityConfig["http"sv].AsObjectView()["root"sv].AsObjectView()["filter"sv].AsObjectView(); + if (HttpRootFilterConfig) + { + std::string_view FilterType = HttpRootFilterConfig["type"sv].AsString(); + if (FilterType == PasswordHttpFilter::TypeName) + { + PasswordHttpFilter::Configuration Config = + PasswordHttpFilter::ReadConfiguration(HttpRootFilterConfig["config"].AsObjectView()); + m_HttpRequestFilter = std::make_unique(Config); + m_Http->SetHttpRequestFilter(m_HttpRequestFilter.get()); + } + else + { + throw std::runtime_error(fmt::format("Security configuration file at {} references unknown http root filter type '{}'", + ServerOptions.SecurityConfigPath, + FilterType)); + } + } + } +} + ////////////////////////////////////////////////////////////////////////// ZenServerMain::ZenServerMain(ZenServerConfig& ServerOptions) : m_ServerOptions(ServerOptions) diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index ab7122fcc..efa46f361 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -72,7 +72,10 @@ protected: std::function m_IsReadyFunc; void OnReady(); - Ref m_Http; + Ref m_Http; + + std::unique_ptr m_HttpRequestFilter; + HttpHealthService m_HealthService; HttpStatusService m_StatusService; @@ -107,6 +110,9 @@ protected: // IHttpStatusProvider virtual void HandleStatusRequest(HttpServerRequest& Request) override; + +private: + void InitializeSecuritySettings(const ZenServerConfig& ServerOptions); }; class ZenServerMain -- cgit v1.2.3 From d1324d607e54e2e97d666a2d1ece9ac9495d1eb1 Mon Sep 17 00:00:00 2001 From: zousar Date: Tue, 17 Feb 2026 20:21:26 -0700 Subject: Make files table in entry.js paginated and searchable --- src/zenserver/frontend/html/pages/entry.js | 210 +++++++++++++++++++++++------ 1 file changed, 170 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/entry.js 
b/src/zenserver/frontend/html/pages/entry.js index dca3a5c25..13d5e44e7 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -26,6 +26,9 @@ export class Page extends ZenPage this._indexer = this.load_indexer(project, oplog); + this._files_index_start = Number(this.get_param("files_start", 0)) || 0; + this._files_index_count = Number(this.get_param("files_count", 50)) || 0; + this._build_page(); } @@ -253,20 +256,13 @@ export class Page extends ZenPage // files var has_file_data = false; { - const sub_section = section.add_section("files"); - const table = sub_section.add_widget( - Table, - ["name", "actions"], Table.Flag_PackRight - ); - table.id("filetable"); - for (const field_name of ["files"]) + var file_data = entry.find("files"); + if (file_data != undefined) { - var file_data = entry.find(field_name); - if (file_data == undefined) - continue; - has_file_data = true; + // Extract files into array + this._files_data = []; for (const item of file_data.as_array()) { var io_hash = undefined, cid = undefined, server_path = undefined, client_path = undefined; @@ -294,37 +290,26 @@ export class Page extends ZenPage cid = ret; } - const row = table.add_row(server_path); + this._files_data.push({ + server_path: server_path, + client_path: client_path, + io_hash: io_hash, + cid: cid + }); + } - var base_name = server_path.split("/").pop().split("\\").pop(); - const project = this.get_param("project"); - const oplog = this.get_param("oplog"); - if (this._is_null_io_hash_string(io_hash)) - { - const link = row.get_cell(0).link( - "/" + ["prj", project, "oplog", oplog, cid].join("/") - ); - link.first_child().attr("download", `${cid}_${base_name}`); - - const action_tb = new Toolbar(row.get_cell(-1), true); - action_tb.left().add("copy-id").on_click(async (v) => { - await navigator.clipboard.writeText(v); - }, cid); - } - else - { - const link = row.get_cell(0).link( - "/" + ["prj", project, "oplog", oplog, 
io_hash].join("/") - ); - link.first_child().attr("download", `${io_hash}_${base_name}`); - - const action_tb = new Toolbar(row.get_cell(-1), true); - action_tb.left().add("copy-hash").on_click(async (v) => { - await navigator.clipboard.writeText(v); - }, io_hash); - } + this._files_index_max = this._files_data.length; - } + const sub_section = section.add_section("files"); + this._build_files_nav(sub_section); + + this._files_table = sub_section.add_widget( + Table, + ["name", "actions"], Table.Flag_PackRight + ); + this._files_table.id("filetable"); + + this._build_files_table(this._files_index_start); } } @@ -419,4 +404,149 @@ export class Page extends ZenPage params.set("opkey", opkey); window.location.search = params; } + + _build_files_nav(section) + { + const nav = section.add_widget(Toolbar); + const left = nav.left(); + left.add("|<") .on_click(() => this._on_files_next_prev(-10e10)); + left.add("<<") .on_click(() => this._on_files_next_prev(-10)); + left.add("prev").on_click(() => this._on_files_next_prev( -1)); + left.add("next").on_click(() => this._on_files_next_prev( 1)); + left.add(">>") .on_click(() => this._on_files_next_prev( 10)); + left.add(">|") .on_click(() => this._on_files_next_prev( 10e10)); + + left.sep(); + for (var count of [10, 25, 50, 100]) + { + var handler = (n) => this._on_files_change_count(n); + left.add(count).on_click(handler, count); + } + + const right = nav.right(); + right.add(Friendly.sep(this._files_index_max)); + + right.sep(); + var search_input = right.add("search:", "label").tag("input"); + search_input.on("change", (x) => this._search_files(x.inner().value), search_input); + } + + _build_files_table(index) + { + this._files_index_count = Math.max(this._files_index_count, 1); + index = Math.min(index, this._files_index_max - this._files_index_count); + index = Math.max(index, 0); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + + const end_index = Math.min(index + 
this._files_index_count, this._files_index_max); + + this._files_table.clear(index); + for (var i = index; i < end_index; i++) + { + const file_item = this._files_data[i]; + const row = this._files_table.add_row(file_item.server_path); + + var base_name = file_item.server_path.split("/").pop().split("\\").pop(); + if (this._is_null_io_hash_string(file_item.io_hash)) + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/") + ); + link.first_child().attr("download", `${file_item.cid}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-id").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.cid); + } + else + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/") + ); + link.first_child().attr("download", `${file_item.io_hash}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-hash").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.io_hash); + } + } + + this.set_param("files_start", index); + this.set_param("files_count", this._files_index_count); + this._files_index_start = index; + } + + _on_files_change_count(value) + { + this._files_index_count = parseInt(value); + this._build_files_table(this._files_index_start); + } + + _on_files_next_prev(direction) + { + var index = this._files_index_start + (this._files_index_count * direction); + index = Math.max(0, index); + this._build_files_table(index); + } + + _search_files(needle) + { + if (needle.length == 0) + { + this._build_files_table(this._files_index_start); + return; + } + needle = needle.trim().toLowerCase(); + + this._files_table.clear(this._files_index_start); + + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); + + var added = 0; + const truncate_at = this.get_param("searchmax") || 250; + for (const 
file_item of this._files_data) + { + if (!file_item.server_path.toLowerCase().includes(needle)) + continue; + + const row = this._files_table.add_row(file_item.server_path); + + var base_name = file_item.server_path.split("/").pop().split("\\").pop(); + if (this._is_null_io_hash_string(file_item.io_hash)) + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.cid].join("/") + ); + link.first_child().attr("download", `${file_item.cid}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-id").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.cid); + } + else + { + const link = row.get_cell(0).link( + "/" + ["prj", project, "oplog", oplog, file_item.io_hash].join("/") + ); + link.first_child().attr("download", `${file_item.io_hash}_${base_name}`); + + const action_tb = new Toolbar(row.get_cell(-1), true); + action_tb.left().add("copy-hash").on_click(async (v) => { + await navigator.clipboard.writeText(v); + }, file_item.io_hash); + } + + if (++added >= truncate_at) + { + this._files_table.add_row("...truncated"); + break; + } + } + } } -- cgit v1.2.3 From 1c8948411e68429f613889c7e278bb0422c172a7 Mon Sep 17 00:00:00 2001 From: zousar Date: Tue, 17 Feb 2026 20:46:45 -0700 Subject: Rename the cache section in the web ui --- src/zenserver/frontend/html/pages/start.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/start.js b/src/zenserver/frontend/html/pages/start.js index 4c8789431..2cf12bf12 100644 --- a/src/zenserver/frontend/html/pages/start.js +++ b/src/zenserver/frontend/html/pages/start.js @@ -46,7 +46,7 @@ export class Page extends ZenPage } // cache - var section = this.add_section("z$"); + var section = this.add_section("cache"); section.tag().classify("dropall").text("drop-all").on_click(() => this.drop_all("z$")); -- cgit v1.2.3 From fbd53c5500d4898be9e2c76646f220dd88a96f36 
Mon Sep 17 00:00:00 2001 From: zousar Date: Tue, 17 Feb 2026 21:16:38 -0700 Subject: Dependencies table doesn't reflow the entries page --- src/zenserver/frontend/html/pages/entry.js | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 13d5e44e7..c4746bf52 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -43,25 +43,39 @@ export class Page extends ZenPage return indexer; } - async _build_deps(section, tree) + _build_deps(section, tree) { - const indexer = await this._indexer; + const project = this.get_param("project"); + const oplog = this.get_param("oplog"); for (const dep_name in tree) { const dep_section = section.add_section(dep_name); const table = dep_section.add_widget(Table, ["name", "id"], Table.Flag_PackRight); + for (const dep_id of tree[dep_name]) { - const cell_values = ["", dep_id.toString(16).padStart(16, "0")]; + const hex_id = dep_id.toString(16).padStart(16, "0"); + const cell_values = ["loading...", hex_id]; const row = table.add_row(...cell_values); - var opkey = indexer.lookup_id(dep_id); - row.get_cell(0).text(opkey).on_click((k) => this.view_opkey(k), opkey); + // Asynchronously resolve the name + this._resolve_dep_name(row.get_cell(0), dep_id, project, oplog); } } } + async _resolve_dep_name(cell, dep_id, project, oplog) + { + const indexer = await this._indexer; + const opkey = indexer.lookup_id(dep_id); + + if (opkey) + { + cell.text(opkey).on_click((k) => this.view_opkey(k), opkey); + } + } + _find_iohash_field(container, name) { const found_field = container.find(name); -- cgit v1.2.3 From 425673a0230373a1b91c15c475f8e543ab246bce Mon Sep 17 00:00:00 2001 From: zousar Date: Tue, 17 Feb 2026 21:24:11 -0700 Subject: updatefrontend --- src/zenserver/frontend/html.zip | Bin 163229 -> 182962 bytes 1 file changed, 0 insertions(+), 0 
deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index 5d33302dd..67752fbc2 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ -- cgit v1.2.3 From b55fdf7c1dfe6d3e52b08a160a77472ec1480cf7 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Wed, 18 Feb 2026 08:54:05 +0100 Subject: convert ZEN_ASSERTs to exception to handle corrupt data gracefully (#760) * convert ZEN_ASSERTs to exception to handle corrupt data gracefully --- .../builds/buildstorageoperations.cpp | 68 ++++++++++++++++++---- .../zenremotestore/builds/buildstorageoperations.h | 3 +- 2 files changed, 58 insertions(+), 13 deletions(-) (limited to 'src') diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 2319ad66d..ade431393 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -4083,7 +4083,8 @@ BuildsOperationUpdateFolder::WriteSequenceChunkToCache(BufferedWriteFileCache::L } bool -BuildsOperationUpdateFolder::GetBlockWriteOps(std::span ChunkRawHashes, +BuildsOperationUpdateFolder::GetBlockWriteOps(const IoHash& BlockRawHash, + std::span ChunkRawHashes, std::span ChunkCompressedLengths, std::span> SequenceIndexChunksLeftToWriteCounters, std::span> RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4115,9 +4116,34 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span ChunkR uint64_t VerifyChunkSize; CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(ChunkMemoryView), VerifyChunkHash, VerifyChunkSize); - ZEN_ASSERT(CompressedChunk); - ZEN_ASSERT(VerifyChunkHash == ChunkHash); - ZEN_ASSERT(VerifyChunkSize == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]); + if (!CompressedChunk) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} is not a valid compressed buffer", + 
ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash)); + } + if (VerifyChunkHash != ChunkHash) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} has a mismatching content hash {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + VerifyChunkHash)); + } + if (VerifyChunkSize != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]) + { + throw std::runtime_error( + fmt::format("Chunk {} at {}, size {} in block {} has a mismatching raw size {}, expected {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + VerifyChunkSize, + m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])); + } OodleCompressor ChunkCompressor; OodleCompressionLevel ChunkCompressionLevel; @@ -4138,7 +4164,18 @@ BuildsOperationUpdateFolder::GetBlockWriteOps(std::span ChunkR { Decompressed = CompressedChunk.Decompress().AsIoBuffer(); } - ZEN_ASSERT(Decompressed.GetSize() == m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]); + + if (Decompressed.GetSize() != m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex]) + { + throw std::runtime_error(fmt::format("Chunk {} at {}, size {} in block {} decompressed to size {}, expected {}", + ChunkHash, + OffsetInBlock, + ChunkCompressedSize, + BlockRawHash, + Decompressed.GetSize(), + m_RemoteContent.ChunkedContent.ChunkRawSizes[ChunkIndex])); + } + ZEN_ASSERT_SLOW(ChunkHash == IoHash::HashBuffer(Decompressed)); for (const ChunkedContentLookup::ChunkSequenceLocation* Target : ChunkTargetPtrs) { @@ -4237,7 +4274,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription const std::vector ChunkCompressedLengths = ReadChunkBlockHeader(BlockView.Mid(CompressedBuffer::GetHeaderSizeForNoneEncoder()), HeaderSize); - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, 
RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4252,7 +4290,8 @@ BuildsOperationUpdateFolder::WriteChunksBlockToCache(const ChunkBlockDescription return false; } - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, BlockDescription.ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -4283,7 +4322,8 @@ BuildsOperationUpdateFolder::WritePartialBlockChunksToCache(const ChunkBlockDesc const MemoryView BlockView = BlockMemoryBuffer.GetView(); BlockWriteOps Ops; - if (GetBlockWriteOps(BlockDescription.ChunkRawHashes, + if (GetBlockWriteOps(BlockDescription.BlockHash, + BlockDescription.ChunkRawHashes, BlockDescription.ChunkCompressedLengths, SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndexNeedsCopyFromSourceFlags, @@ -5334,6 +5374,13 @@ BuildsOperationUploadFolder::FetchChunk(const ChunkedFolderContent& Content, ZEN_ASSERT(!ChunkLocations.empty()); CompositeBuffer Chunk = OpenFileCache.GetRange(ChunkLocations[0].SequenceIndex, ChunkLocations[0].Offset, Content.ChunkedContent.ChunkRawSizes[ChunkIndex]); + if (!Chunk) + { + throw std::runtime_error(fmt::format("Unable to read chunk at {}, size {} from '{}'", + ChunkLocations[0].Offset, + Content.ChunkedContent.ChunkRawSizes[ChunkIndex], + Content.Paths[Lookup.SequenceIndexFirstPathIndex[ChunkLocations[0].SequenceIndex]])); + } ZEN_ASSERT_SLOW(IoHash::HashBuffer(Chunk) == ChunkHash); return Chunk; }; @@ -5362,10 +5409,7 @@ BuildsOperationUploadFolder::GenerateBlock(const ChunkedFolderContent& Content, Content.ChunkedContent.ChunkHashes[ChunkIndex], [this, &Content, &Lookup, &OpenFileCache, ChunkIndex](const IoHash& ChunkHash) -> std::pair { CompositeBuffer Chunk = FetchChunk(Content, Lookup, ChunkHash, OpenFileCache); - if (!Chunk) - { - ZEN_ASSERT(false); - } + ZEN_ASSERT(Chunk); uint64_t RawSize = Chunk.GetSize(); const bool ShouldCompressChunk = RawSize >= 
m_Options.MinimumSizeForCompressInBlock && diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 6304159ae..9e5bf8d91 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -339,7 +339,8 @@ private: const uint64_t FileOffset, const uint32_t PathIndex); - bool GetBlockWriteOps(std::span ChunkRawHashes, + bool GetBlockWriteOps(const IoHash& BlockRawHash, + std::span ChunkRawHashes, std::span ChunkCompressedLengths, std::span> SequenceIndexChunksLeftToWriteCounters, std::span> RemoteChunkIndexNeedsCopyFromSourceFlags, -- cgit v1.2.3 From ae9c30841074da9226a76c1eb2fb3a3e29086bf6 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 18 Feb 2026 09:40:35 +0100 Subject: add selective request logging support to http.sys (#762) * implemented selective request logging for http.sys for consistency with asio * fixed traversal of GetLogicalProcessorInformationEx to account for variable-sized records * also adds CPU usage metrics --- src/zencore/include/zencore/system.h | 1 + src/zencore/system.cpp | 169 +++++++++++++++++++++++++++-------- src/zenhttp/servers/httpsys.cpp | 25 +++++- 3 files changed, 156 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index aec2e0ce4..bf3c15d3d 100644 --- a/src/zencore/include/zencore/system.h +++ b/src/zencore/include/zencore/system.h @@ -25,6 +25,7 @@ struct SystemMetrics uint64_t AvailVirtualMemoryMiB = 0; uint64_t PageFileMiB = 0; uint64_t AvailPageFileMiB = 0; + float CpuUsagePercent = 0.0f; }; SystemMetrics GetSystemMetrics(); diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index e92691781..267c87e12 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -13,6 +13,8 @@ 
ZEN_THIRD_PARTY_INCLUDES_START # include # include +# include +# pragma comment(lib, "pdh.lib") ZEN_THIRD_PARTY_INCLUDES_END #elif ZEN_PLATFORM_LINUX # include @@ -65,55 +67,98 @@ GetSystemMetrics() // Determine physical core count - DWORD BufferSize = 0; - BOOL Result = GetLogicalProcessorInformationEx(RelationAll, nullptr, &BufferSize); - if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER) { - ThrowSystemError(Error, "Failed to get buffer size for logical processor information"); - } - - PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)Memory::Alloc(BufferSize); + DWORD BufferSize = 0; + BOOL Result = GetLogicalProcessorInformationEx(RelationAll, nullptr, &BufferSize); + if (int32_t Error = GetLastError(); Error != ERROR_INSUFFICIENT_BUFFER) + { + ThrowSystemError(Error, "Failed to get buffer size for logical processor information"); + } - Result = GetLogicalProcessorInformationEx(RelationAll, Buffer, &BufferSize); - if (!Result) - { - Memory::Free(Buffer); - throw std::runtime_error("Failed to get logical processor information"); - } + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX Buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)Memory::Alloc(BufferSize); - DWORD ProcessorPkgCount = 0; - DWORD ProcessorCoreCount = 0; - DWORD ByteOffset = 0; - while (ByteOffset + sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX) <= BufferSize) - { - const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX& Slpi = Buffer[ByteOffset / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)]; - if (Slpi.Relationship == RelationProcessorCore) + Result = GetLogicalProcessorInformationEx(RelationAll, Buffer, &BufferSize); + if (!Result) { - ProcessorCoreCount++; + Memory::Free(Buffer); + throw std::runtime_error("Failed to get logical processor information"); } - else if (Slpi.Relationship == RelationProcessorPackage) + + DWORD ProcessorPkgCount = 0; + DWORD ProcessorCoreCount = 0; + DWORD LogicalProcessorCount = 0; + + BYTE* Ptr = 
reinterpret_cast(Buffer); + BYTE* const End = Ptr + BufferSize; + while (Ptr < End) { - ProcessorPkgCount++; + const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX& Slpi = *reinterpret_cast(Ptr); + if (Slpi.Relationship == RelationProcessorCore) + { + ++ProcessorCoreCount; + + // Count logical processors (threads) across all processor groups for this core. + // Each core entry lists one GROUP_AFFINITY per group it spans; each set bit + // in the Mask represents one logical processor (HyperThreading sibling). + for (WORD g = 0; g < Slpi.Processor.GroupCount; ++g) + { + LogicalProcessorCount += static_cast(__popcnt64(Slpi.Processor.GroupMask[g].Mask)); + } + } + else if (Slpi.Relationship == RelationProcessorPackage) + { + ++ProcessorPkgCount; + } + Ptr += Slpi.Size; } - ByteOffset += sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX); - } - Metrics.CoreCount = ProcessorCoreCount; - Metrics.CpuCount = ProcessorPkgCount; + Metrics.CoreCount = ProcessorCoreCount; + Metrics.CpuCount = ProcessorPkgCount; + Metrics.LogicalProcessorCount = LogicalProcessorCount; - Memory::Free(Buffer); + Memory::Free(Buffer); + } // Query memory status - MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)}; - GlobalMemoryStatusEx(&MemStatus); + { + MEMORYSTATUSEX MemStatus{.dwLength = sizeof(MEMORYSTATUSEX)}; + GlobalMemoryStatusEx(&MemStatus); + + Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024; + Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024; + Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024; + Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024; + Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024; + Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; + } + + // Query CPU usage using PDH + // + // TODO: This should be changed to not require a Sleep, perhaps by using some + // background metrics gathering mechanism. 
+ + { + PDH_HQUERY QueryHandle = nullptr; + PDH_HCOUNTER CounterHandle = nullptr; - Metrics.SystemMemoryMiB = MemStatus.ullTotalPhys / 1024 / 1024; - Metrics.AvailSystemMemoryMiB = MemStatus.ullAvailPhys / 1024 / 1024; - Metrics.VirtualMemoryMiB = MemStatus.ullTotalVirtual / 1024 / 1024; - Metrics.AvailVirtualMemoryMiB = MemStatus.ullAvailVirtual / 1024 / 1024; - Metrics.PageFileMiB = MemStatus.ullTotalPageFile / 1024 / 1024; - Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; + if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS) + { + if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) == ERROR_SUCCESS) + { + PdhCollectQueryData(QueryHandle); + Sleep(100); + PdhCollectQueryData(QueryHandle); + + PDH_FMT_COUNTERVALUE CounterValue; + if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS) + { + Metrics.CpuUsagePercent = static_cast(CounterValue.doubleValue); + } + } + PdhCloseQuery(QueryHandle); + } + } return Metrics; } @@ -190,6 +235,39 @@ GetSystemMetrics() } } + // Query CPU usage + Metrics.CpuUsagePercent = 0.0f; + if (FILE* Stat = fopen("/proc/stat", "r")) + { + char Line[256]; + unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq; + static unsigned long PrevUser = 0, PrevNice = 0, PrevSystem = 0, PrevIdle = 0, PrevIoWait = 0, PrevIrq = 0, PrevSoftIrq = 0; + + if (fgets(Line, sizeof(Line), Stat)) + { + if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7) + { + unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) - + (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq); + unsigned long IdleDelta = Idle - PrevIdle; + + if (TotalDelta > 0) + { + Metrics.CpuUsagePercent = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; + } + + PrevUser = User; + PrevNice = Nice; + PrevSystem = System; + PrevIdle = Idle; + 
PrevIoWait = IoWait; + PrevIrq = Irq; + PrevSoftIrq = SoftIrq; + } + } + fclose(Stat); + } + // Get memory information long Pages = sysconf(_SC_PHYS_PAGES); long PageSize = sysconf(_SC_PAGE_SIZE); @@ -270,6 +348,25 @@ GetSystemMetrics() sysctlbyname("hw.packages", &Packages, &Size, nullptr, 0); Metrics.CpuCount = Packages > 0 ? Packages : 1; + // Query CPU usage using host_statistics64 + Metrics.CpuUsagePercent = 0.0f; + host_cpu_load_info_data_t CpuLoad; + mach_msg_type_number_t CpuCount = sizeof(CpuLoad) / sizeof(natural_t); + if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &CpuCount) == KERN_SUCCESS) + { + unsigned long TotalTicks = 0; + for (int i = 0; i < CPU_STATE_MAX; ++i) + { + TotalTicks += CpuLoad.cpu_ticks[i]; + } + + if (TotalTicks > 0) + { + unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE]; + Metrics.CpuUsagePercent = 100.0f * (TotalTicks - IdleTicks) / TotalTicks; + } + } + // Get memory information uint64_t MemSize = 0; Size = sizeof(MemSize); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 14896c803..c640ba90b 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -331,6 +331,8 @@ public: virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; virtual bool TryGetRanges(HttpRanges& Ranges) override; + void LogRequest(HttpMessageResponseRequest* Response); + using HttpServerRequest::WriteResponse; HttpSysServerRequest(const HttpSysServerRequest&) = delete; @@ -429,7 +431,8 @@ public: virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; void SuppressResponseBody(); // typically used for HEAD requests - inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + inline uint16_t GetResponseCode() const { return m_ResponseCode; } + inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } private: eastl::fixed_vector m_HttpDataChunks; @@ 
-1886,7 +1889,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); if (SuppressBody()) { @@ -1904,6 +1907,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) # endif SetIsHandled(); + LogRequest(Response); } void @@ -1913,7 +1917,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (SuppressBody()) { @@ -1931,6 +1935,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); +} + +void +HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response) +{ + if (ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", + ToString(RequestVerb()), + m_UriUtf8.c_str(), + Response->GetResponseCode(), + NiceBytes(Response->GetResponseBodySize())); + } } void @@ -1959,6 +1977,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); } void -- cgit v1.2.3 From 149a5c2faa8d59290b8b44717e504532e906aae2 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 18 Feb 2026 11:28:03 +0100 Subject: structured compute basics (#714) this change adds the `zencompute` component, which can be used to distribute work dispatched from UE using the DDB (Derived Data Build) APIs via zenserver this change also adds a distinct zenserver compute mode (`zenserver compute`) which is intended to be used for leaf compute nodes to exercise the compute functionality without directly involving UE, 
a `zen exec` subcommand is also added, which can be used to feed replays through the system all new functionality is considered *experimental* and disabled by default at this time, behind the `zencompute` option in xmake config --- src/zen/cmds/exec_cmd.cpp | 654 ++++++++++++++ src/zen/cmds/exec_cmd.h | 97 ++ src/zen/xmake.lua | 5 +- src/zen/zen.cpp | 39 +- src/zencompute-test/xmake.lua | 9 + src/zencompute-test/zencompute-test.cpp | 32 + src/zencompute/actionrecorder.cpp | 258 ++++++ src/zencompute/actionrecorder.h | 91 ++ src/zencompute/functionrunner.cpp | 112 +++ src/zencompute/functionrunner.h | 207 +++++ src/zencompute/functionservice.cpp | 957 ++++++++++++++++++++ src/zencompute/httpfunctionservice.cpp | 709 +++++++++++++++ src/zencompute/httporchestrator.cpp | 81 ++ .../include/zencompute/functionservice.h | 132 +++ .../include/zencompute/httpfunctionservice.h | 73 ++ .../include/zencompute/httporchestrator.h | 44 + .../include/zencompute/recordingreader.h | 127 +++ src/zencompute/include/zencompute/zencompute.h | 11 + src/zencompute/localrunner.cpp | 722 +++++++++++++++ src/zencompute/localrunner.h | 100 +++ src/zencompute/recordingreader.cpp | 335 +++++++ src/zencompute/remotehttprunner.cpp | 457 ++++++++++ src/zencompute/remotehttprunner.h | 80 ++ src/zencompute/xmake.lua | 11 + src/zencompute/zencompute.cpp | 12 + src/zennet/beacon.cpp | 170 ++++ src/zennet/include/zennet/beacon.h | 38 + src/zennet/include/zennet/statsdclient.h | 2 + src/zennet/statsdclient.cpp | 1 + src/zenserver-test/function-tests.cpp | 34 + src/zenserver/compute/computeserver.cpp | 330 +++++++ src/zenserver/compute/computeserver.h | 106 +++ src/zenserver/compute/computeservice.cpp | 100 +++ src/zenserver/compute/computeservice.h | 36 + src/zenserver/frontend/html/compute.html | 991 +++++++++++++++++++++ src/zenserver/main.cpp | 55 +- src/zenserver/storage/storageconfig.cpp | 1 + src/zenserver/storage/storageconfig.h | 1 + src/zenserver/storage/zenstorageserver.cpp | 21 + 
src/zenserver/storage/zenstorageserver.h | 26 +- src/zenserver/xmake.lua | 4 + src/zenserver/zenserver.cpp | 8 + src/zenserver/zenserver.h | 13 +- src/zentest-appstub/xmake.lua | 3 + src/zentest-appstub/zentest-appstub.cpp | 391 +++++++- 45 files changed, 7639 insertions(+), 47 deletions(-) create mode 100644 src/zen/cmds/exec_cmd.cpp create mode 100644 src/zen/cmds/exec_cmd.h create mode 100644 src/zencompute-test/xmake.lua create mode 100644 src/zencompute-test/zencompute-test.cpp create mode 100644 src/zencompute/actionrecorder.cpp create mode 100644 src/zencompute/actionrecorder.h create mode 100644 src/zencompute/functionrunner.cpp create mode 100644 src/zencompute/functionrunner.h create mode 100644 src/zencompute/functionservice.cpp create mode 100644 src/zencompute/httpfunctionservice.cpp create mode 100644 src/zencompute/httporchestrator.cpp create mode 100644 src/zencompute/include/zencompute/functionservice.h create mode 100644 src/zencompute/include/zencompute/httpfunctionservice.h create mode 100644 src/zencompute/include/zencompute/httporchestrator.h create mode 100644 src/zencompute/include/zencompute/recordingreader.h create mode 100644 src/zencompute/include/zencompute/zencompute.h create mode 100644 src/zencompute/localrunner.cpp create mode 100644 src/zencompute/localrunner.h create mode 100644 src/zencompute/recordingreader.cpp create mode 100644 src/zencompute/remotehttprunner.cpp create mode 100644 src/zencompute/remotehttprunner.h create mode 100644 src/zencompute/xmake.lua create mode 100644 src/zencompute/zencompute.cpp create mode 100644 src/zennet/beacon.cpp create mode 100644 src/zennet/include/zennet/beacon.h create mode 100644 src/zenserver-test/function-tests.cpp create mode 100644 src/zenserver/compute/computeserver.cpp create mode 100644 src/zenserver/compute/computeserver.h create mode 100644 src/zenserver/compute/computeservice.cpp create mode 100644 src/zenserver/compute/computeservice.h create mode 100644 
src/zenserver/frontend/html/compute.html (limited to 'src') diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp new file mode 100644 index 000000000..2d9d0d12e --- /dev/null +++ b/src/zen/cmds/exec_cmd.cpp @@ -0,0 +1,654 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "exec_cmd.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace std::literals; + +namespace eastl { + +template<> +struct hash : public zen::IoHash::Hasher +{ +}; + +} // namespace eastl + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen { + +ExecCommand::ExecCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), ""); + m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), ""); + m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), ""); + m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), ""); + m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), ""); + m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), ""); + m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), ""); + m_Options.add_option("", + "", + "mode", + "Select execution mode (http,inproc,dump,direct,beacon,buildlog)", + cxxopts::value(m_Mode)->default_value("http"), + ""); + m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), ""); + m_Options.parse_positional("mode"); +} + +ExecCommand::~ExecCommand() +{ +} + +void +ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + // Configure + + if (!ParseOptions(argc, argv)) + { + return; + } + + 
m_HostName = ResolveTargetHostSpec(m_HostName); + + if (m_RecordingPath.empty()) + { + throw OptionParseException("replay path is required!", m_Options.help()); + } + + m_VerboseLogging = GlobalOptions.IsVerbose; + m_QuietLogging = m_Quiet && !m_VerboseLogging; + + enum ExecMode + { + kHttp, + kDirect, + kInproc, + kDump, + kBeacon, + kBuildLog + } Mode; + + if (m_Mode == "http"sv) + { + Mode = kHttp; + } + else if (m_Mode == "direct"sv) + { + Mode = kDirect; + } + else if (m_Mode == "inproc"sv) + { + Mode = kInproc; + } + else if (m_Mode == "dump"sv) + { + Mode = kDump; + } + else if (m_Mode == "beacon"sv) + { + Mode = kBeacon; + } + else if (m_Mode == "buildlog"sv) + { + Mode = kBuildLog; + } + else + { + throw OptionParseException("invalid mode specified!", m_Options.help()); + } + + // Gather information from recording path + + std::unique_ptr Reader; + std::unique_ptr UeReader; + + std::filesystem::path RecordingPath{m_RecordingPath}; + + if (!std::filesystem::is_directory(RecordingPath)) + { + throw OptionParseException("replay path should be a directory path!", m_Options.help()); + } + else + { + if (std::filesystem::is_directory(RecordingPath / "cid")) + { + Reader = std::make_unique(RecordingPath); + m_WorkerMap = Reader->ReadWorkers(); + m_ChunkResolver = Reader.get(); + m_RecordingReader = Reader.get(); + } + else + { + UeReader = std::make_unique(RecordingPath); + m_WorkerMap = UeReader->ReadWorkers(); + m_ChunkResolver = UeReader.get(); + m_RecordingReader = UeReader.get(); + } + } + + ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); + + for (auto& Kv : m_WorkerMap) + { + CbObject WorkerDesc = Kv.second.GetObject(); + const IoHash& WorkerId = Kv.first; + + RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); + + if (m_VerboseLogging) + { + zen::ExtendableStringBuilder<1024> ObjStr; +# if 0 + zen::CompactBinaryToJson(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); 
+# else + zen::CompactBinaryToYaml(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr); +# endif + } + } + + if (m_VerboseLogging) + { + EmitFunctionList(m_FunctionList); + } + + // Iterate over work items and dispatch or log them + + int ReturnValue = 0; + + Stopwatch ExecTimer; + + switch (Mode) + { + case kHttp: + // Forward requests to HTTP function service + ReturnValue = HttpExecute(); + break; + + case kDirect: + // Not currently supported + ReturnValue = LocalMessagingExecute(); + break; + + case kInproc: + // Handle execution in-core (by spawning child processes) + ReturnValue = InProcessExecute(); + break; + + case kDump: + // Dump high level information about actions to console + ReturnValue = DumpWorkItems(); + break; + + case kBeacon: + ReturnValue = BeaconExecute(); + break; + + case kBuildLog: + ReturnValue = BuildActionsLog(); + break; + + default: + ZEN_ERROR("Unknown operating mode! No work submitted"); + + ReturnValue = 1; + } + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } +} + +int +ExecCommand::InProcessExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + zen::compute::FunctionServiceSession FunctionSession(Resolver); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + FunctionSession.AddLocalRunner(Resolver, TempPath); + + return ExecUsingSession(FunctionSession); +} + +int +ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession) +{ + struct JobTracker + { + public: + inline void Insert(int LsnField) + { + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.insert(LsnField); + } + + inline bool IsEmpty() const + { + RwLock::SharedLockScope _(Lock); + return PendingJobs.empty(); + } + + inline void Remove(int CompleteLsn) + { + 
RwLock::ExclusiveLockScope _(Lock); + PendingJobs.erase(CompleteLsn); + } + + inline size_t GetSize() const + { + RwLock::SharedLockScope _(Lock); + return PendingJobs.size(); + } + + private: + mutable RwLock Lock; + std::unordered_set PendingJobs; + }; + + JobTracker PendingJobs; + + std::atomic IsDraining{0}; + + auto DrainCompletedJobs = [&] { + if (IsDraining.exchange(1)) + { + return; + } + + auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); + + CbObjectWriter Cbo; + FunctionSession.GetCompleted(Cbo); + + if (CbObject Completed = Cbo.Save()) + { + for (auto& It : Completed["completed"sv]) + { + int32_t CompleteLsn = It.AsInt32(); + + CbPackage ResultPackage; + HttpResponseCode Response = FunctionSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + + if (Response == HttpResponseCode::OK) + { + PendingJobs.Remove(CompleteLsn); + + ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize()); + } + } + } + }; + + // Describe workers + + ZEN_CONSOLE("describing {} workers", m_WorkerMap.size()); + + for (auto Kv : m_WorkerMap) + { + CbPackage WorkerDesc = Kv.second; + + FunctionSession.RegisterWorker(WorkerDesc); + } + + // Then submit work items + + int FailedWorkCounter = 0; + size_t RemainingWorkItems = m_RecordingReader->GetActionCount(); + int SubmittedWorkItems = 0; + + ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + + int OffsetCounter = m_Offset; + int StrideCounter = m_Stride; + + auto ShouldSchedule = [&]() -> bool { + if (m_Limit && SubmittedWorkItems >= m_Limit) + { + // Limit reached, ignore + + return false; + } + + if (OffsetCounter && OffsetCounter--) + { + // Still in offset, ignore + + return false; + } + + if (--StrideCounter == 0) + { + StrideCounter = m_Stride; + + return true; + } + + return false; + }; + + m_RecordingReader->IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + // Enqueue job + + Stopwatch SubmitTimer; + + const int Priority = 
0; + + if (ShouldSchedule()) + { + if (m_VerboseLogging) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + eastl::hash_set ReferencedChunks; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ReferencedChunks.insert(AttachData); + ++AttachmentCount; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", + ActionId, + AttachmentCount, + NiceBytes(AttachmentBytes), + ObjStr); + } + + if (zen::compute::FunctionServiceSession::EnqueueResult EnqueueResult = + FunctionSession.EnqueueAction(ActionObject, Priority)) + { + const int32_t LsnField = EnqueueResult.Lsn; + + --RemainingWorkItems; + ++SubmittedWorkItems; + + if (!m_QuietLogging) + { + ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", + SubmittedWorkItems, + LsnField, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + RemainingWorkItems); + } + + PendingJobs.Insert(LsnField); + } + else + { + if (!m_QuietLogging) + { + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + ZEN_ERROR( + "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). 
Work " + "descriptor " + "at: 'file://{}'", + std::string(FunctionName), + FunctionVersion, + BuildSystemVersion, + ""); + + EmitFunctionListOnce(m_FunctionList); + } + + ++FailedWorkCounter; + } + } + + // Check for completed work + + DrainCompletedJobs(); + }, + 8); + + // Wait until all pending work is complete + + while (!PendingJobs.IsEmpty()) + { + // TODO: improve this logic + zen::Sleep(500); + + DrainCompletedJobs(); + } + + if (FailedWorkCounter) + { + return 1; + } + + return 0; +} + +int +ExecCommand::LocalMessagingExecute() +{ + // Non-HTTP work submission path + + // To be reimplemented using final transport + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +int +ExecCommand::HttpExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::FunctionServiceSession FunctionSession(Resolver); + FunctionSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + + return ExecUsingSession(FunctionSession); +} + +int +ExecCommand::BeaconExecute() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::FunctionServiceSession FunctionSession(Resolver); + FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); + // FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://10.99.9.246:8558"); + + return ExecUsingSession(FunctionSession); +} + +////////////////////////////////////////////////////////////////////////// + +void +ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId) +{ + const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerDesc["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = 
Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } +} + +void +ExecCommand::EmitFunctionListOnce(const std::vector& FunctionList) +{ + if (m_FunctionListEmittedOnce == false) + { + EmitFunctionList(FunctionList); + + m_FunctionListEmittedOnce = true; + } +} + +int +ExecCommand::DumpWorkItems() +{ + std::atomic EmittedCount{0}; + + eastl::hash_map SeenAttachments; // Attachment CID -> count of references + + m_RecordingReader->IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + eastl::hash_map Attachments; + + uint64_t AttachmentBytes = 0; + uint64_t UncompressedAttachmentBytes = 0; + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + Attachments[AttachmentCid] = CompressedData; + + AttachmentBytes += CompressedData.GetCompressedSize(); + UncompressedAttachmentBytes += CompressedData.DecodeRawSize(); + + if (auto [Iter, Inserted] = SeenAttachments.insert({AttachmentCid, 1}); !Inserted) + { + ++Iter->second; + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + +# if 0 + zen::CompactBinaryToJson(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments): {}", ActionId, Attachments.size(), ObjStr); +# else + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {}->{} bytes):\n{}", + ActionId, + Attachments.size(), + AttachmentBytes, + UncompressedAttachmentBytes, + ObjStr); +# endif + + ++EmittedCount; + }, + 1); + 
+ ZEN_CONSOLE("emitted: {} actions", EmittedCount.load()); + + eastl::map> ReferenceHistogram; + + for (const auto& [K, V] : SeenAttachments) + { + if (V > 1) + { + ReferenceHistogram[V].push_back(K); + } + } + + for (const auto& [RefCount, Cids] : ReferenceHistogram) + { + ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount); + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +int +ExecCommand::BuildActionsLog() +{ + ZEN_ASSERT(m_ChunkResolver); + ChunkResolver& Resolver = *m_ChunkResolver; + + if (m_RecordingPath.empty()) + { + throw OptionParseException("need to specify recording path", m_Options.help()); + } + + if (std::filesystem::exists(m_RecordingLogPath)) + { + throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help()); + } + + ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::FunctionServiceSession FunctionSession(Resolver); + FunctionSession.StartRecording(Resolver, m_RecordingLogPath); + + return ExecUsingSession(FunctionSession); +} + +void +ExecCommand::EmitFunctionList(const std::vector& FunctionList) +{ + ZEN_CONSOLE("=== Known functions:\n==========================="); + + ZEN_CONSOLE("{:30} {:36} {:36} {}", "function", "version", "build system", "worker id"); + + for (const FunctionDefinition& Func : FunctionList) + { + ZEN_CONSOLE("{:30} {:36} {:36} {}", Func.FunctionName, Func.FunctionVersion, Func.BuildSystemVersion, Func.WorkerId); + } + + ZEN_CONSOLE("==========================="); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h new file mode 100644 index 000000000..43d092144 --- /dev/null +++ b/src/zen/cmds/exec_cmd.h @@ -0,0 +1,97 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "../zen.h" + +#include +#include +#include +#include + +#include +#include +#include + +namespace zen { +class CbPackage; +class CbObject; +struct IoHash; +class ChunkResolver; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { +class FunctionServiceSession; +} + +namespace zen { + +/** + * Zen CLI command for executing functions from a recording + * + * Mostly for testing and debugging purposes + */ + +class ExecCommand : public ZenCmdBase +{ +public: + ExecCommand(); + ~ExecCommand(); + + static constexpr char Name[] = "exec"; + static constexpr char Description[] = "Execute functions from a recording"; + + virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{Name, Description}; + std::string m_HostName; + std::filesystem::path m_BeaconPath; + std::filesystem::path m_RecordingPath; + std::filesystem::path m_RecordingLogPath; + int m_Offset = 0; + int m_Stride = 1; + int m_Limit = 0; + bool m_Quiet = false; + std::string m_Mode{"http"}; + + struct FunctionDefinition + { + std::string FunctionName; + zen::Guid FunctionVersion; + zen::Guid BuildSystemVersion; + zen::IoHash WorkerId; + }; + + bool m_FunctionListEmittedOnce = false; + void EmitFunctionListOnce(const std::vector& FunctionList); + void EmitFunctionList(const std::vector& FunctionList); + + std::unordered_map m_WorkerMap; + std::vector m_FunctionList; + bool m_VerboseLogging = false; + bool m_QuietLogging = false; + + zen::ChunkResolver* m_ChunkResolver = nullptr; + zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; + + void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); + + int ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession); + + // Execution modes + + int DumpWorkItems(); + int HttpExecute(); + int 
InProcessExecute(); + int LocalMessagingExecute(); + int BeaconExecute(); + int BuildActionsLog(); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua index ab094fef3..f889c3296 100644 --- a/src/zen/xmake.lua +++ b/src/zen/xmake.lua @@ -6,15 +6,12 @@ target("zen") add_files("**.cpp") add_files("zen.cpp", {unity_ignored = true }) add_deps("zencore", "zenhttp", "zenremotestore", "zenstore", "zenutil") + add_deps("zencompute", "zennet") add_deps("cxxopts", "fmt") add_packages("json11") add_includedirs(".") set_symbols("debug") - if is_mode("release") then - set_optimize("fastest") - end - if is_plat("windows") then add_files("zen.rc") add_ldflags("/subsystem:console,5.02") diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 25245c3d2..018f77738 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -11,6 +11,7 @@ #include "cmds/cache_cmd.h" #include "cmds/copy_cmd.h" #include "cmds/dedup_cmd.h" +#include "cmds/exec_cmd.h" #include "cmds/info_cmd.h" #include "cmds/print_cmd.h" #include "cmds/projectstore_cmd.h" @@ -316,22 +317,25 @@ main(int argc, char** argv) } #endif // ZEN_WITH_TRACE - AttachCommand AttachCmd; - BenchCommand BenchCmd; - BuildsCommand BuildsCmd; - CacheDetailsCommand CacheDetailsCmd; - CacheGetCommand CacheGetCmd; - CacheGenerateCommand CacheGenerateCmd; - CacheInfoCommand CacheInfoCmd; - CacheStatsCommand CacheStatsCmd; - CopyCommand CopyCmd; - CopyStateCommand CopyStateCmd; - CreateOplogCommand CreateOplogCmd; - CreateProjectCommand CreateProjectCmd; - DedupCommand DedupCmd; - DownCommand DownCmd; - DropCommand DropCmd; - DropProjectCommand ProjectDropCmd; + AttachCommand AttachCmd; + BenchCommand BenchCmd; + BuildsCommand BuildsCmd; + CacheDetailsCommand CacheDetailsCmd; + CacheGetCommand CacheGetCmd; + CacheGenerateCommand CacheGenerateCmd; + CacheInfoCommand CacheInfoCmd; + CacheStatsCommand CacheStatsCmd; + CopyCommand CopyCmd; + CopyStateCommand CopyStateCmd; + CreateOplogCommand 
CreateOplogCmd; + CreateProjectCommand CreateProjectCmd; + DedupCommand DedupCmd; + DownCommand DownCmd; + DropCommand DropCmd; + DropProjectCommand ProjectDropCmd; +#if ZEN_WITH_COMPUTE_SERVICES + ExecCommand ExecCmd; +#endif // ZEN_WITH_COMPUTE_SERVICES ExportOplogCommand ExportOplogCmd; FlushCommand FlushCmd; GcCommand GcCmd; @@ -388,6 +392,9 @@ main(int argc, char** argv) {"dedup", &DedupCmd, "Dedup files"}, {"down", &DownCmd, "Bring zen server down"}, {"drop", &DropCmd, "Drop cache namespace or bucket"}, +#if ZEN_WITH_COMPUTE_SERVICES + {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, +#endif {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"}, {"gc-stop", &GcStopCmd, "Request cancel of running garbage collection in zen storage"}, {"gc", &GcCmd, "Garbage collect zen storage"}, diff --git a/src/zencompute-test/xmake.lua b/src/zencompute-test/xmake.lua new file mode 100644 index 000000000..64a3c7703 --- /dev/null +++ b/src/zencompute-test/xmake.lua @@ -0,0 +1,9 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zencompute-test") + set_kind("binary") + set_group("tests") + add_headerfiles("**.h") + add_files("*.cpp") + add_deps("zencompute", "zencore") + add_packages("vcpkg::doctest") diff --git a/src/zencompute-test/zencompute-test.cpp b/src/zencompute-test/zencompute-test.cpp new file mode 100644 index 000000000..237812e12 --- /dev/null +++ b/src/zencompute-test/zencompute-test.cpp @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+
+#include
+#include
+#include
+#include
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+# include
+# include
+# include
+#endif
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include
+#endif
+
+int
+main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
+{
+#if ZEN_WITH_TESTS
+    zen::zencompute_forcelinktests();
+
+    zen::logging::InitializeLogging();
+    zen::MaximizeOpenFileCount();
+
+    return ZEN_RUN_TESTS(argc, argv);
+#else
+    return 0;
+#endif
+}
diff --git a/src/zencompute/actionrecorder.cpp b/src/zencompute/actionrecorder.cpp
new file mode 100644
index 000000000..04c4b5141
--- /dev/null
+++ b/src/zencompute/actionrecorder.cpp
@@ -0,0 +1,287 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "actionrecorder.h"
+
+#include "functionrunner.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#if ZEN_PLATFORM_WINDOWS
+# include
+# define ZEN_CONCRT_AVAILABLE 1
+#else
+# define ZEN_CONCRT_AVAILABLE 0
+#endif
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+RecordingFileWriter::RecordingFileWriter()
+{
+}
+
+RecordingFileWriter::~RecordingFileWriter()
+{
+    Close();
+}
+
+// Creates (truncating) the data file at FilePath and its sibling ".ztoc" table-of-contents
+// file, writes the 16-byte data header and starts the "toc" array in the TOC writer
+void
+RecordingFileWriter::Open(std::filesystem::path FilePath)
+{
+    using namespace std::literals;
+
+    m_File.Open(FilePath, BasicFile::Mode::kTruncate);
+    m_File.Write("----DDC2----DATA", 16, 0);
+    m_FileOffset = 16;
+
+    std::filesystem::path TocPath = FilePath.replace_extension(".ztoc");
+    m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate);
+
+    m_TocWriter << "version"sv << 1;
+    m_TocWriter.BeginArray("toc"sv);
+
+    // From here on there is an open "toc" array which Close() has to terminate
+    m_IsOpen = true;
+}
+
+// Terminates the "toc" array and writes the table of contents to the TOC file. Safe to call
+// on a writer that was never (fully) opened, and safe to call more than once
+void
+RecordingFileWriter::Close()
+{
+    if (!m_IsOpen)
+    {
+        return;
+    }
+
+    m_IsOpen = false;
+
+    m_TocWriter.EndArray();
+    CbObject Toc = m_TocWriter.Save();
+
+    std::error_code Ec;
+    m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec);
+
+    if (Ec)
+    {
+        // Close() also runs from the destructor, so report the failure instead of throwing
+        ZEN_WARN("failed writing recording table of contents: {}", Ec.message());
+    }
+}
+
+// Appends Object to the data file and adds a [hash, offset, size] entry for it to the TOC.
+// Throws std::system_error if the write to the data file fails
+void
+RecordingFileWriter::AppendObject(const CbObject& Object, const IoHash& ObjectHash)
+{
+    RwLock::ExclusiveLockScope _(m_FileLock);
+
+    MemoryView ObjectView = Object.GetBuffer().GetView();
+
+    std::error_code Ec;
+    m_File.Write(ObjectView, m_FileOffset, Ec);
+
+    if (Ec)
+    {
+        throw std::system_error(Ec, "failed writing to archive");
+    }
+
+    m_TocWriter.BeginArray();
+    m_TocWriter.AddHash(ObjectHash);
+    m_TocWriter.AddInteger(m_FileOffset);
+    m_TocWriter.AddInteger(gsl::narrow(ObjectView.GetSize()));
+    m_TocWriter.EndArray();
+
+    m_FileOffset += ObjectView.GetSize();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// Sets up the recording directory at RecordingLogPath (CreateDirectories + CleanDirectory),
+// opens the "workers.zdat" / "actions.zdat" archives and initializes a CID store under "cid"
+ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath)
+: m_ChunkResolver(InChunkResolver)
+, m_RecordingLogDir(RecordingLogPath)
+{
+    std::error_code Ec;
+    CreateDirectories(m_RecordingLogDir, Ec);
+
+    if (Ec)
+    {
+        ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message());
+    }
+
+    CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec);
+
+    if (Ec)
+    {
+        ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message());
+    }
+
+    m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat");
+    m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat");
+
+    CidStoreConfiguration CidConfig;
+    CidConfig.RootDirectory = m_RecordingLogDir / "cid";
+    CidConfig.HugeValueThreshold = 128 * 1024 * 1024;
+
+    m_CidStore.Initialize(CidConfig);
+}
+
+ActionRecorder::~ActionRecorder()
+{
+    Shutdown();
+}
+
+void
+ActionRecorder::Shutdown()
+{
+    m_CidStore.Flush();
+}
+
+// Appends the worker descriptor to the workers archive and copies its attachments into the
+// recording's CID store, resolving those not carried by the package via the chunk resolver
+void
+ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage)
+{
+    const IoHash WorkerId = WorkerPackage.GetObjectHash();
+
+    m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId);
+
+    std::unordered_set AddedChunks;
+    uint64_t AddedBytes = 0;
+
+    // First add all attachments from the worker package itself
+
+    for (const CbAttachment& Attachment : WorkerPackage.GetAttachments())
+    {
+        CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+        IoBuffer Data = Buffer.GetCompressed().Flatten().AsIoBuffer();
+
+        const IoHash ChunkHash = Buffer.DecodeRawHash();
+
+        CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly);
+
+        AddedChunks.insert(ChunkHash);
+
+        if (Result.New)
+        {
+            AddedBytes += Data.GetSize();
+        }
+    }
+
+    // Not all attachments will be present in the worker package, so we need to add
+    // all referenced chunks to ensure that the recording is self-contained and not
+    // referencing data in the main CID store
+
+    CbObject WorkerDescriptor = WorkerPackage.GetObject();
+
+    WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) {
+        const IoHash AttachmentCid = AttachmentField.GetValue().AsHash();
+
+        if (!AddedChunks.contains(AttachmentCid))
+        {
+            IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid);
+
+            if (AttachmentData)
+            {
+                CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly);
+
+                if (Result.New)
+                {
+                    AddedBytes += AttachmentData.GetSize();
+                }
+            }
+            else
+            {
+                ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId);
+            }
+
+            AddedChunks.insert(AttachmentCid);
+        }
+    });
+
+    ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes);
+}
+
+// Copies every attachment referenced by the action into the recording's CID store (chunks
+// whose compressor is NotSet are recompressed with Mermaid/Fast first). The action object is
+// appended to the actions archive, and true returned, only if all attachments were resolved
+bool
+ActionRecorder::RecordAction(Ref Action)
+{
+    bool AllGood = true;
+
+    Action->ActionObj.IterateAttachments([&](CbFieldView Field) {
+        IoHash AttachData = Field.AsHash();
+        IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData);
+
+        if (ChunkData)
+        {
+            if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary)
+            {
+                IoHash DecompressedHash;
+                uint64_t RawSize = 0;
+                CompressedBuffer Compressed =
+                    CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize);
+
+                OodleCompressor Compressor;
+                OodleCompressionLevel CompressionLevel;
+                uint64_t BlockSize = 0;
+                if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize))
+                {
+                    if (Compressor == OodleCompressor::NotSet)
+                    {
+                        CompositeBuffer Decompressed = Compressed.DecompressToComposite();
+                        CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed),
+                                                                                   OodleCompressor::Mermaid,
+                                                                                   OodleCompressionLevel::Fast,
+                                                                                   BlockSize);
+
+                        ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer();
+                    }
+                }
+            }
+
+            const uint64_t ChunkSize = ChunkData.GetSize();
+
+            m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly);
+            ++m_ChunkCounter;
+            m_ChunkBytesCounter.fetch_add(ChunkSize);
+        }
+        else
+        {
+            AllGood = false;
+
+            ZEN_WARN("could not resolve chunk {}", AttachData);
+        }
+    });
+
+    if (AllGood)
+    {
+        m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId);
+        ++m_ActionsCounter;
+
+        return true;
+    }
+    else
+    {
+        return false;
+    }
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/actionrecorder.h b/src/zencompute/actionrecorder.h
new file mode 100644
index 000000000..9cc2b44a2
--- /dev/null
+++ b/src/zencompute/actionrecorder.h
@@ -0,0 +1,95 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace zen {
+class CbObject;
+class CbPackage;
+struct IoHash;
+} // namespace zen
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+namespace zen::compute {
+
+//////////////////////////////////////////////////////////////////////////
+
+/** Appends compact binary objects to a data file and keeps a table of contents for them
+ *  which is written to a sibling ".ztoc" file when the writer is closed
+ */
+struct RecordingFileWriter
+{
+    RecordingFileWriter(RecordingFileWriter&&) = delete;
+    RecordingFileWriter& operator=(RecordingFileWriter&&) = delete;
+
+    RwLock m_FileLock;
+    BasicFile m_File;
+    uint64_t m_FileOffset = 0;
+    CbObjectWriter m_TocWriter;
+    BasicFile m_TocFile;
+    bool m_IsOpen = false;  // set once Open() has completed, cleared by Close()
+
+    RecordingFileWriter();
+    ~RecordingFileWriter();
+
+    void Open(std::filesystem::path FilePath);
+    void Close();
+    void AppendObject(const CbObject& Object, const IoHash& ObjectHash);
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Recording "runner" implementation
+ *
+ * This class writes out all actions and their attachments to a recording directory
+ * in a format that can be read back by the RecordingReader.
+ *
+ * The contents of the recording directory will be self-contained, with all referenced
+ * attachments stored in the recording directory itself, so that the recording can be
+ * moved or shared without needing to maintain references to the main CID store.
+ *
+ */
+
+class ActionRecorder
+{
+public:
+    ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath);
+    ~ActionRecorder();
+
+    ActionRecorder(const ActionRecorder&) = delete;
+    ActionRecorder& operator=(const ActionRecorder&) = delete;
+
+    void Shutdown();
+    void RegisterWorker(const CbPackage& WorkerPackage);
+    bool RecordAction(Ref Action);
+
+private:
+    ChunkResolver& m_ChunkResolver;
+    std::filesystem::path m_RecordingLogDir;
+
+    RecordingFileWriter m_WorkersFile;
+    RecordingFileWriter m_ActionsFile;
+    GcManager m_Gc;
+    CidStore m_CidStore{m_Gc};
+    std::atomic m_ChunkCounter{0};
+    std::atomic m_ChunkBytesCounter{0};
+    std::atomic m_ActionsCounter{0};
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zencompute/functionrunner.cpp b/src/zencompute/functionrunner.cpp
new file mode 100644
index 000000000..8e7c12b2b
--- /dev/null
+++ b/src/zencompute/functionrunner.cpp
@@ -0,0 +1,112 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "functionrunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include
+# include
+
+# include
+# include
+
+namespace zen::compute {
+
+FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions")
+{
+}
+
+FunctionRunner::~FunctionRunner() = default;
+
+size_t
+FunctionRunner::QueryCapacity()
+{
+    return 1;
+}
+
+// Default batch submission: submits the actions one at a time via SubmitAction() and
+// returns one result per action, in the same order
+std::vector
+FunctionRunner::SubmitActions(const std::vector>& Actions)
+{
+    std::vector Results;
+    Results.reserve(Actions.size());
+
+    for (const Ref& Action : Actions)
+    {
+        Results.push_back(SubmitAction(Action));
+    }
+
+    return Results;
+}
+
+// When m_DumpActions is set, writes the action object to "<ActionLsn>.ddb" under the
+// actions directory
+void
+FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject)
+{
+    if (m_DumpActions)
+    {
+        std::string UniqueId = fmt::format("{}.ddb", ActionLsn);
+        std::filesystem::path Path = m_ActionsPath / UniqueId;
+
+        zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer()));
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+RunnerAction::RunnerAction(FunctionServiceSession* OwnerSession) : m_OwnerSession(OwnerSession)
+{
+    this->Timestamps[static_cast(State::New)] = DateTime::Now().GetTicks();
+}
+
+RunnerAction::~RunnerAction()
+{
+}
+
+// Advances the action to NewState and notifies the owning session. States only move
+// forward (in State enumerator order), any other request is ignored
+void
+RunnerAction::SetActionState(State NewState)
+{
+    ZEN_ASSERT(NewState < State::_Count);
+
+    State CurrentState = m_ActionState.load();
+
+    // Only forward transitions are accepted. Requests for the current or an earlier state
+    // are dropped without touching the timestamps, so a stale request cannot overwrite
+    // the time at which that state was actually entered
+    while (NewState > CurrentState)
+    {
+        // Stamp before publishing the state so that the timestamp is in place by the time
+        // the new state becomes visible
+        this->Timestamps[static_cast(NewState)] = DateTime::Now().GetTicks();
+
+        if (m_ActionState.compare_exchange_strong(CurrentState, NewState))
+        {
+            m_OwnerSession->PostUpdate(this);
+
+            return;
+        }
+
+        // The exchange failed and refreshed CurrentState, re-validate against it
+    }
+}
+
+void
+RunnerAction::SetResult(CbPackage&& Result)
+{
+    m_Result = std::move(Result);
+}
+
+CbPackage&
+RunnerAction::GetResult()
+{
+    ZEN_ASSERT(IsCompleted());
+    return m_Result;
+}
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file
diff --git a/src/zencompute/functionrunner.h b/src/zencompute/functionrunner.h
new file mode 100644
index 000000000..6fd0d84cc
--- /dev/null
+++ b/src/zencompute/functionrunner.h
@@ -0,0 +1,206 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include
+# include
+
+namespace zen::compute {
+
+/** Outcome of offering an action to a runner; Reason carries a human-readable explanation */
+struct SubmitResult
+{
+    bool IsAccepted = false;
+    std::string Reason;
+};
+
+/** Base interface for classes implementing a remote execution "runner"
+ */
+class FunctionRunner : public RefCounted
+{
+    FunctionRunner(FunctionRunner&&) = delete;
+    FunctionRunner& operator=(FunctionRunner&&) = delete;
+
+public:
+    FunctionRunner(std::filesystem::path BasePath);
+    virtual ~FunctionRunner() = 0;
+
+    virtual void Shutdown() = 0;
+    virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0;
+
+    [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) = 0;
+    [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0;
+    [[nodiscard]] virtual bool IsHealthy() = 0;
+    [[nodiscard]] virtual size_t QueryCapacity();
+    [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions);
+
+protected:
+    std::filesystem::path m_ActionsPath;
+    bool m_DumpActions = false;
+    void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject);
+};
+
+/** A set of runners of the same kind, guarded by a reader/writer lock, with loosely
+ *  round-robin action submission across the members
+ */
+template
+struct RunnerGroup
+{
+    void AddRunner(RunnerType* Runner)
+    {
+        m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); });
+    }
+    size_t QueryCapacity()
+    {
+        size_t TotalCapacity = 0;
+        m_RunnersLock.WithSharedLock([&] {
+            for (const auto& Runner : m_Runners)
+            {
+                TotalCapacity += Runner->QueryCapacity();
+            }
+        });
+        return TotalCapacity;
+    }
+
+    // Offers the action to each runner in turn, starting at the round-robin cursor, and
+    // returns the first accepting runner's result. The cursor then moves past that runner
+    SubmitResult SubmitAction(Ref Action)
+    {
+        RwLock::SharedLockScope _(m_RunnersLock);
+
+        const int RunnerCount = gsl::narrow(m_Runners.size());
+
+        if (RunnerCount == 0)
+        {
+            return {.IsAccepted = false, .Reason = "No runners available"};
+        }
+
+        // Reduce the cursor into range up front so the scan below is bounded to exactly
+        // one pass over the runners regardless of the stored value
+        const int StartIndex = m_NextSubmitIndex.load(std::memory_order_acquire) % RunnerCount;
+
+        for (int Attempt = 0; Attempt < RunnerCount; ++Attempt)
+        {
+            const int Index = (StartIndex + Attempt) % RunnerCount;
+
+            SubmitResult Result = m_Runners[Index]->SubmitAction(Action);
+
+            if (Result.IsAccepted)
+            {
+                m_NextSubmitIndex = (Index + 1) % RunnerCount;
+
+                return Result;
+            }
+        }
+
+        return {.IsAccepted = false, .Reason = "No runner accepted the action"};
+    }
+
+    size_t GetSubmittedActionCount()
+    {
+        RwLock::SharedLockScope _(m_RunnersLock);
+
+        size_t TotalCount = 0;
+
+        for (const auto& Runner : m_Runners)
+        {
+            TotalCount += Runner->GetSubmittedActionCount();
+        }
+
+        return TotalCount;
+    }
+
+    void RegisterWorker(CbPackage Worker)
+    {
+        RwLock::SharedLockScope _(m_RunnersLock);
+
+        for (auto& Runner : m_Runners)
+        {
+            Runner->RegisterWorker(Worker);
+        }
+    }
+
+    void Shutdown()
+    {
+        RwLock::SharedLockScope _(m_RunnersLock);
+
+        for (auto& Runner : m_Runners)
+        {
+            Runner->Shutdown();
+        }
+    }
+
+private:
+    RwLock m_RunnersLock;
+    std::vector> m_Runners;
+    std::atomic m_NextSubmitIndex{0};
+};
+
+/**
+ * This represents an action going through different stages of scheduling and execution.
+ */
+struct RunnerAction : public RefCounted
+{
+    explicit RunnerAction(FunctionServiceSession* OwnerSession);
+    ~RunnerAction();
+
+    int ActionLsn = 0;
+    WorkerDesc Worker;
+    IoHash ActionId;
+    CbObject ActionObj;
+    int Priority = 0;
+
+    // Declared in lifecycle order; SetActionState() only permits moves to a later enumerator
+    enum class State
+    {
+        New,
+        Pending,
+        Running,
+        Completed,
+        Failed,
+        _Count
+    };
+
+    static const char* ToString(State _)
+    {
+        switch (_)
+        {
+            case State::New:
+                return "New";
+            case State::Pending:
+                return "Pending";
+            case State::Running:
+                return "Running";
+            case State::Completed:
+                return "Completed";
+            case State::Failed:
+                return "Failed";
+            default:
+                return "Unknown";
+        }
+    }
+
+    uint64_t Timestamps[static_cast(State::_Count)] = {};
+
+    State ActionState() const { return m_ActionState; }
+    void SetActionState(State NewState);
+
+    bool IsSuccess() const { return ActionState() == State::Completed; }
+    bool IsCompleted() const { return ActionState() == State::Completed || ActionState() == State::Failed; }
+
+    void SetResult(CbPackage&& Result);
+    CbPackage& GetResult();
+
+private:
+    std::atomic m_ActionState = State::New;
+    FunctionServiceSession* m_OwnerSession = nullptr;
+    CbPackage m_Result;
+};
+
+} // namespace zen::compute
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
\ No newline at end of file
diff --git a/src/zencompute/functionservice.cpp b/src/zencompute/functionservice.cpp
new file mode 100644
index 000000000..0698449e9
--- /dev/null
+++ b/src/zencompute/functionservice.cpp
@@ -0,0 +1,957 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencompute/functionservice.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "functionrunner.h"
+# include "actionrecorder.h"
+# include "localrunner.h"
+# include "remotehttprunner.h"
+
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+# include
+
+# include
+# include
+# include
+# include
+# include
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# include
+ZEN_THIRD_PARTY_INCLUDES_END
+
+using namespace std::literals;
+
+namespace zen::compute {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct FunctionServiceSession::Impl
+{
+    FunctionServiceSession* m_FunctionServiceSession;
+    ChunkResolver& m_ChunkResolver;
+    LoggerRef m_Log{logging::Get("apply")};
+
+    Impl(FunctionServiceSession* InFunctionServiceSession, ChunkResolver& InChunkResolver)
+    : m_FunctionServiceSession(InFunctionServiceSession)
+    , m_ChunkResolver(InChunkResolver)
+    {
+        m_SchedulingThread = std::thread{&Impl::MonitorThreadFunction, this};
+    }
+
+    void Shutdown();
+    bool IsHealthy();
+
+    LoggerRef Log() { return m_Log; }
+
+    std::atomic_bool m_AcceptActions = true;
+
+    struct FunctionDefinition
+    {
+        std::string FunctionName;
+        Guid FunctionVersion;
+        Guid BuildSystemVersion;
+        IoHash WorkerId;
+    };
+
+    void EmitStats(CbObjectWriter& Cbo)
+    {
+        m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); });
+        m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); });
+        m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); });
+        Cbo << "actions_submitted"sv << GetSubmittedActionCount();
+        EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo);
+    }
+
+    void RegisterWorker(CbPackage Worker);
+    WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId);
+
+    std::atomic m_ActionsCounter = 0; // sequence number
+
+    RwLock m_PendingLock;
+    std::map> m_PendingActions;
+
+    RwLock m_RunningLock;
+    std::unordered_map> m_RunningMap;
+
+    RwLock m_ResultsLock;
+    std::unordered_map> m_ResultsMap;
+    metrics::Meter m_ResultRate;
+    std::atomic m_RetiredCount{0};
+
+    HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage);
+    HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage);
+
+    std::atomic m_ShutdownRequested{false};
+
+    std::thread m_SchedulingThread;
+    std::atomic m_SchedulingThreadEnabled{true};
+    Event m_SchedulingThreadEvent;
+
+    void MonitorThreadFunction();
+    void SchedulePendingActions();
+
+    // Workers (m_WorkerMap and m_FunctionList are read and written under m_WorkerLock)
+
+    RwLock m_WorkerLock;
+    std::unordered_map m_WorkerMap;
+    std::vector m_FunctionList;
+    std::vector GetKnownWorkerIds();
+
+    // Runners
+
+    RunnerGroup m_LocalRunnerGroup;
+    RunnerGroup m_RemoteRunnerGroup;
+
+    EnqueueResult EnqueueAction(CbObject ActionObject, int Priority);
+    EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority);
+
+    void GetCompleted(CbWriter& Cbo);
+
+    // Recording
+
+    void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath);
+    void StopRecording();
+
+    std::unique_ptr m_Recorder;
+
+    // History tracking
+
+    RwLock m_ActionHistoryLock;
+    std::deque m_ActionHistory;
+    size_t m_HistoryLimit = 1000;
+
+    std::vector GetActionHistory(int Limit);
+
+    //
+
+    [[nodiscard]] size_t QueryCapacity();
+
+    [[nodiscard]] SubmitResult SubmitAction(Ref Action);
+    [[nodiscard]] std::vector SubmitActions(const std::vector>& Actions);
+    [[nodiscard]] size_t GetSubmittedActionCount();
+
+    // Updates
+
+    RwLock m_UpdatedActionsLock;
+    std::vector> m_UpdatedActions;
+
+    void HandleActionUpdates();
+    void PostUpdate(RunnerAction* Action);
+
+    void ShutdownRunners();
+};
+
+bool
+FunctionServiceSession::Impl::IsHealthy()
+{
+    return true;
+}
+
+void
+FunctionServiceSession::Impl::Shutdown()
+{
+    m_AcceptActions = false;
+    m_ShutdownRequested = true;
+
+    m_SchedulingThreadEnabled = false;
+    m_SchedulingThreadEvent.Set();
+    if (m_SchedulingThread.joinable())
+    {
+        m_SchedulingThread.join();
+    }
+
+    ShutdownRunners();
+}
+
+void
+FunctionServiceSession::Impl::ShutdownRunners()
+{
+    m_LocalRunnerGroup.Shutdown();
+    m_RemoteRunnerGroup.Shutdown();
+}
+
+void
+FunctionServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath)
+{
+    ZEN_INFO("starting recording to '{}'", RecordingPath);
+
+    m_Recorder = std::make_unique(InCidStore, RecordingPath);
+
+    ZEN_INFO("started recording to '{}'", RecordingPath);
+}
+
+void
+FunctionServiceSession::Impl::StopRecording()
+{
+    ZEN_INFO("stopping recording");
+
+    m_Recorder = nullptr;
+
+    ZEN_INFO("stopped recording");
+}
+
+std::vector
+FunctionServiceSession::Impl::GetActionHistory(int Limit)
+{
+    RwLock::SharedLockScope _(m_ActionHistoryLock);
+
+    if (Limit > 0 && static_cast(Limit) < m_ActionHistory.size())
+    {
+        return std::vector(m_ActionHistory.end() - Limit, m_ActionHistory.end());
+    }
+
+    return std::vector(m_ActionHistory.begin(), m_ActionHistory.end());
+}
+
+void
+FunctionServiceSession::Impl::RegisterWorker(CbPackage Worker)
+{
+    RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+    const IoHash& WorkerId = Worker.GetObject().GetHash();
+
+    if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second)
+    {
+        // Note that since the convention currently is that WorkerId is equal to the hash
+        // of the worker descriptor there is no chance that we get a second write with a
+        // different descriptor. Thus we only need to call this the first time, when the
+        // worker is added
+
+        m_LocalRunnerGroup.RegisterWorker(Worker);
+        m_RemoteRunnerGroup.RegisterWorker(Worker);
+
+        if (m_Recorder)
+        {
+            m_Recorder->RegisterWorker(Worker);
+        }
+
+        CbObject WorkerObj = Worker.GetObject();
+
+        // Populate worker database
+
+        const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid();
+
+        for (auto& Item : WorkerObj["functions"sv])
+        {
+            CbObjectView Function = Item.AsObjectView();
+
+            std::string_view FunctionName = Function["name"sv].AsString();
+            const Guid FunctionVersion = Function["version"sv].AsUuid();
+
+            m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName},
+                                                           .FunctionVersion = FunctionVersion,
+                                                           .BuildSystemVersion = WorkerBuildSystemVersion,
+                                                           .WorkerId = WorkerId});
+        }
+    }
+}
+
+WorkerDesc
+FunctionServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId)
+{
+    RwLock::SharedLockScope _(m_WorkerLock);
+
+    if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
+    {
+        const CbPackage& Desc = It->second;
+        return {Desc, WorkerId};
+    }
+
+    return {};
+}
+
+std::vector
+FunctionServiceSession::Impl::GetKnownWorkerIds()
+{
+    std::vector WorkerIds;
+
+    m_WorkerLock.WithSharedLock([&] {
+        WorkerIds.reserve(m_WorkerMap.size());  // sized under the lock, RegisterWorker() may be adding entries
+        for (const auto& [WorkerId, _] : m_WorkerMap)
+        {
+            WorkerIds.push_back(WorkerId);
+        }
+    });
+
+    return WorkerIds;
+}
+
+FunctionServiceSession::EnqueueResult
+FunctionServiceSession::Impl::EnqueueAction(CbObject ActionObject, int Priority)
+{
+    std::string_view FunctionName = ActionObject["Function"sv].AsString();
+    const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid();
+    const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid();
+
+    auto MakeError = [&](std::string_view Message) -> EnqueueResult {
+        CbObjectWriter Writer;
+        Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion;
+        Writer << "error" << Message;
+        return {0, Writer.Save()};
+    };
+    // Resolve function to worker. The worker tables are written under m_WorkerLock (see
+    // RegisterWorker), so read them under it and copy the package out before submitting
+    IoHash WorkerId{IoHash::Zero};
+    CbPackage WorkerPackage;
+    bool HaveWorkerPackage = false;
+
+    m_WorkerLock.WithSharedLock([&] {
+        for (const FunctionDefinition& FuncDef : m_FunctionList)
+        {
+            if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion &&
+                FuncDef.BuildSystemVersion == BuildSystemVersion)
+            {
+                WorkerId = FuncDef.WorkerId;
+                break;
+            }
+        }
+        if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end())
+        {
+            WorkerPackage = It->second;
+            HaveWorkerPackage = true;
+        }
+    });
+
+    if (WorkerId == IoHash::Zero)
+    {
+        return MakeError("no worker matches the action specification");
+    }
+
+    if (!HaveWorkerPackage)
+    {
+        return MakeError("no worker found despite match");
+    }
+
+    return EnqueueResolvedAction(WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority);
+}
+
+FunctionServiceSession::EnqueueResult
+FunctionServiceSession::Impl::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority)
+{
+    const int ActionLsn = ++m_ActionsCounter;
+
+    Ref Pending{new RunnerAction(m_FunctionServiceSession)};
+
+    Pending->ActionLsn = ActionLsn;
+    Pending->Worker = Worker;
+    Pending->ActionId = ActionObj.GetHash();
+    Pending->ActionObj = ActionObj;
+    Pending->Priority = RequestPriority;
+
+    SubmitResult SubResult = SubmitAction(Pending);
+
+    if (SubResult.IsAccepted)
+    {
+        // Great, the job is being taken care of by the runner
+        ZEN_DEBUG("direct schedule LSN {}", Pending->ActionLsn);
+    }
+    else
+    {
+        ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn);
+
+        Pending->SetActionState(RunnerAction::State::Pending);
+    }
+
+    if (m_Recorder)
+    {
+        m_Recorder->RecordAction(Pending);
+    }
+
+    CbObjectWriter Writer;
+    Writer << "lsn" << Pending->ActionLsn;
+    Writer << "worker" << Pending->Worker.WorkerId;
+    Writer << "action" << Pending->ActionId;
+
+    return {Pending->ActionLsn, Writer.Save()};
+}
+
+SubmitResult
+FunctionServiceSession::Impl::SubmitAction(Ref Action) +{ + // Loosely round-robin scheduling of actions across runners. + // + // It's not entirely clear what this means given that submits + // can come in across multiple threads, but it's probably better + // than always starting with the first runner. + // + // Longer term we should track the state of the individual + // runners and make decisions accordingly. + + SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); + if (Result.IsAccepted) + { + return Result; + } + + return m_RemoteRunnerGroup.SubmitAction(Action); +} + +size_t +FunctionServiceSession::Impl::GetSubmittedActionCount() +{ + return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); +} + +HttpResponseCode +FunctionServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) + { + return HttpResponseCode::Accepted; + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) + { + return HttpResponseCode::Accepted; + } + } + + return HttpResponseCode::NotFound; +} + +HttpResponseCode +FunctionServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + // This lock is held for the 
duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) + { + if (It->second->ActionId == ActionId) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + for (const auto& [K, Pending] : m_PendingActions) + { + if (Pending->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + for (const auto& [K, v] : m_RunningMap) + { + if (v->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + return HttpResponseCode::NotFound; +} + +void +FunctionServiceSession::Impl::GetCompleted(CbWriter& Cbo) +{ + Cbo.BeginArray("completed"); + + m_ResultsLock.WithSharedLock([&] { + for (auto& Kv : m_ResultsMap) + { + Cbo << Kv.first; + } + }); + + Cbo.EndArray(); +} + +# define ZEN_BATCH_SCHEDULER 1 + +void +FunctionServiceSession::Impl::SchedulePendingActions() +{ + int ScheduledCount = 0; + size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + + static Stopwatch DumpRunningTimer; + + auto _ = MakeGuard([&] { + ZEN_INFO("scheduled {} pending actions. 
{} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); + + if (DumpRunningTimer.GetElapsedTimeMs() > 30000) + { + DumpRunningTimer.Reset(); + + std::set RunningList; + m_RunningLock.WithSharedLock([&] { + for (auto& [K, V] : m_RunningMap) + { + RunningList.insert(K); + } + }); + + ExtendableStringBuilder<1024> RunningString; + for (int i : RunningList) + { + if (RunningString.Size()) + { + RunningString << ", "; + } + + RunningString.Append(IntNum(i)); + } + + ZEN_INFO("running: {}", RunningString); + } + }); + +# if ZEN_BATCH_SCHEDULER + size_t Capacity = QueryCapacity(); + + if (!Capacity) + { + _.Dismiss(); + + return; + } + + std::vector> ActionsToSchedule; + + // Pull actions to schedule from the pending queue, we will try to submit these to the runner outside of the lock + + m_PendingLock.WithExclusiveLock([&] { + if (m_ShutdownRequested) + { + return; + } + + if (m_PendingActions.empty()) + { + return; + } + + size_t NumActionsToSchedule = std::min(Capacity, m_PendingActions.size()); + + auto PendingIt = m_PendingActions.begin(); + const auto PendingEnd = m_PendingActions.end(); + + while (NumActionsToSchedule && PendingIt != PendingEnd) + { + const Ref& Pending = PendingIt->second; + + switch (Pending->ActionState()) + { + case RunnerAction::State::Pending: + ActionsToSchedule.push_back(Pending); + break; + + case RunnerAction::State::Running: + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + break; + + default: + case RunnerAction::State::New: + ZEN_WARN("unexpected state {} for pending action {}", static_cast(Pending->ActionState()), Pending->ActionLsn); + break; + } + + ++PendingIt; + --NumActionsToSchedule; + } + + PendingCount = m_PendingActions.size(); + }); + + if (ActionsToSchedule.empty()) + { + _.Dismiss(); + return; + } + + ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + + auto SubmitResults = 
SubmitActions(ActionsToSchedule); + + // Move successfully scheduled actions to the running map and remove + // from pending queue. It's actually possible that by the time we get + // to this stage some of the actions may have already completed, so + // they should not always be added to the running map + + eastl::hash_set ScheduledActions; + + for (size_t i = 0; i < ActionsToSchedule.size(); ++i) + { + const Ref& Pending = ActionsToSchedule[i]; + const SubmitResult& SubResult = SubmitResults[i]; + + if (SubResult.IsAccepted) + { + ScheduledActions.insert(Pending->ActionLsn); + } + } + + ScheduledCount += (int)ActionsToSchedule.size(); + +# else + m_PendingLock.WithExclusiveLock([&] { + while (!m_PendingActions.empty()) + { + if (m_ShutdownRequested) + { + return; + } + + // Here it would be good if we could decide to pop immediately to avoid + // holding the lock while creating processes etc + const Ref& Pending = m_PendingActions.begin()->second; + FunctionRunner::SubmitResult SubResult = SubmitAction(Pending); + + if (SubResult.IsAccepted) + { + // Great, the job is being taken care of by the runner + + ZEN_DEBUG("action {} ({}) PENDING -> RUNNING", Pending->ActionId, Pending->ActionLsn); + + m_RunningLock.WithExclusiveLock([&] { + m_RunningMap.insert({Pending->ActionLsn, Pending}); + + RunningCount = m_RunningMap.size(); + }); + + m_PendingActions.pop_front(); + + PendingCount = m_PendingActions.size(); + ++ScheduledCount; + } + else + { + // Runner could not accept the job, leave it on the pending queue + + return; + } + } + }); +# endif +} + +void +FunctionServiceSession::Impl::MonitorThreadFunction() +{ + SetCurrentThreadName("FunctionServiceSession_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + int TimeoutMs = 1000; + + if (m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); })) + { + TimeoutMs = 100; + } + + const bool Timedout = m_SchedulingThreadEvent.Wait(TimeoutMs); + + if 
(m_SchedulingThreadEnabled == false) + { + return; + } + + HandleActionUpdates(); + + // Schedule pending actions + + SchedulePendingActions(); + + if (!Timedout) + { + m_SchedulingThreadEvent.Reset(); + } + } while (m_SchedulingThreadEnabled); +} + +void +FunctionServiceSession::Impl::PostUpdate(RunnerAction* Action) +{ + m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); +} + +void +FunctionServiceSession::Impl::HandleActionUpdates() +{ + std::vector> UpdatedActions; + + m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); + + std::unordered_set SeenLsn; + std::unordered_set RunningLsn; + + for (Ref& Action : UpdatedActions) + { + const int ActionLsn = Action->ActionLsn; + + if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) + { + switch (Action->ActionState()) + { + case RunnerAction::State::Pending: + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + break; + + case RunnerAction::State::Running: + m_PendingLock.WithExclusiveLock([&] { + m_RunningLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); + }); + }); + ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); + break; + + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + m_ResultsLock.WithExclusiveLock([&] { + m_ResultsMap[ActionLsn] = Action; + + m_PendingLock.WithExclusiveLock([&] { + m_RunningLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + m_ActionHistoryLock.WithExclusiveLock([&] { + ActionHistoryEntry Entry{.Lsn = ActionLsn, + .ActionId = Action->ActionId, + .WorkerId = Action->Worker.WorkerId, + .ActionDescriptor = Action->ActionObj, + .Succeeded = Action->ActionState() == RunnerAction::State::Completed}; + + 
std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps)); + + m_ActionHistory.push_back(std::move(Entry)); + + if (m_ActionHistory.size() > m_HistoryLimit) + { + m_ActionHistory.pop_front(); + } + }); + }); + m_RetiredCount.fetch_add(1); + m_ResultRate.Mark(1); + ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", + Action->ActionId, + ActionLsn, + Action->ActionState() == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); + break; + } + } + } +} + +size_t +FunctionServiceSession::Impl::QueryCapacity() +{ + return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); +} + +std::vector +FunctionServiceSession::Impl::SubmitActions(const std::vector>& Actions) +{ + std::vector Results; + + for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +////////////////////////////////////////////////////////////////////////// + +FunctionServiceSession::FunctionServiceSession(ChunkResolver& InChunkResolver) +{ + m_Impl = std::make_unique(this, InChunkResolver); +} + +FunctionServiceSession::~FunctionServiceSession() +{ + Shutdown(); +} + +bool +FunctionServiceSession::IsHealthy() +{ + return m_Impl->IsHealthy(); +} + +void +FunctionServiceSession::Shutdown() +{ + m_Impl->Shutdown(); +} + +void +FunctionServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) +{ + m_Impl->StartRecording(InResolver, RecordingPath); +} + +void +FunctionServiceSession::StopRecording() +{ + m_Impl->StopRecording(); +} + +void +FunctionServiceSession::EmitStats(CbObjectWriter& Cbo) +{ + m_Impl->EmitStats(Cbo); +} + +std::vector +FunctionServiceSession::GetKnownWorkerIds() +{ + return m_Impl->GetKnownWorkerIds(); +} + +WorkerDesc +FunctionServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) +{ + return m_Impl->GetWorkerDescriptor(WorkerId); +} + +void +FunctionServiceSession::AddLocalRunner(ChunkResolver& 
InChunkResolver, std::filesystem::path BasePath) +{ + m_Impl->m_LocalRunnerGroup.AddRunner(new LocalProcessRunner(InChunkResolver, BasePath)); +} + +void +FunctionServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) +{ + m_Impl->m_RemoteRunnerGroup.AddRunner(new RemoteHttpRunner(InChunkResolver, BasePath, HostName)); +} + +FunctionServiceSession::EnqueueResult +FunctionServiceSession::EnqueueAction(CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueAction(ActionObject, Priority); +} + +FunctionServiceSession::EnqueueResult +FunctionServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedAction(Worker, ActionObj, RequestPriority); +} + +void +FunctionServiceSession::RegisterWorker(CbPackage Worker) +{ + m_Impl->RegisterWorker(Worker); +} + +HttpResponseCode +FunctionServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + return m_Impl->GetActionResult(ActionLsn, OutResultPackage); +} + +HttpResponseCode +FunctionServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + return m_Impl->FindActionResult(ActionId, OutResultPackage); +} + +std::vector +FunctionServiceSession::GetActionHistory(int Limit) +{ + return m_Impl->GetActionHistory(Limit); +} + +void +FunctionServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +FunctionServiceSession::PostUpdate(RunnerAction* Action) +{ + m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +function_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpfunctionservice.cpp b/src/zencompute/httpfunctionservice.cpp new file mode 100644 index 000000000..09a9684a7 --- /dev/null +++ b/src/zencompute/httpfunctionservice.cpp @@ -0,0 +1,709 @@ +// Copyright Epic 
Games, Inc. All Rights Reserved. + +#include "zencompute/httpfunctionservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; + +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; + +HttpFunctionService::HttpFunctionService(CidStore& InCidStore, + IHttpStatsService& StatsService, + [[maybe_unused]] const std::filesystem::path& BaseDir) +: m_CidStore(InCidStore) +, m_StatsService(StatsService) +, m_Log(logging::Get("apply")) +, m_BaseDir(BaseDir) +, m_FunctionService(InCidStore) +{ + m_FunctionService.AddLocalRunner(InCidStore, m_BaseDir / "local"); + + m_StatsService.RegisterHandler("apply", *this); + + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_FunctionService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const IoHash& WorkerId : m_FunctionService.GetKnownWorkerIds()) + { + Cbo << WorkerId; + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + 
m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + if (WorkerDesc Desc = m_FunctionService.GetWorkerDescriptor(WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); + } + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject WorkerSpec = HttpReq.ReadPayloadObject(); + + // Determine which pieces are missing and need to be transmitted + + HashKeySet ChunkSet; + + WorkerSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerSpec); + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + m_FunctionService.RegisterWorker(WorkerPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + ResponseWriter.AddHash(Hash); + }); + + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); + CbObject WorkerSpec = WorkerSpecPackage.GetObject(); + + std::span Attachments = WorkerSpecPackage.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for 
(const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_FunctionService.RegisterWorker(WorkerSpecPackage); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + m_FunctionService.GetCompleted(Cbo); + + SystemMetrics Sm = GetSystemMetricsForReporting(); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_FunctionService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv 
<< Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = std::stoi(std::string{Req.GetCapture(1)}); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_FunctionService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + + return HttpReq.WriteResponse(ResponseCode); + } + break; + + case HttpVerb::kPost: + { + // Add support for cancellation, priority changes + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. 
The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_FunctionService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_FunctionService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (FunctionServiceSession::EnqueueResult Result = + m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if 
(FunctionServiceSession::EnqueueResult Result = + m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::vector WorkerIds = m_FunctionService.GetKnownWorkerIds(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + for (const IoHash& WorkerId : WorkerIds) + { + Cbo.BeginObject(); + + Cbo << "id" << WorkerId; + + const auto& Descriptor = m_FunctionService.GetWorkerDescriptor(WorkerId); + + Cbo << "descriptor" << Descriptor.Descriptor.GetObject(); + + Cbo.EndObject(); + } + + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + SystemMetrics Sm = GetSystemMetricsForReporting(); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + m_FunctionService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + m_FunctionService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); +} + +HttpFunctionService::~HttpFunctionService() +{ + m_StatsService.UnregisterHandler("apply", 
*this); +} + +void +HttpFunctionService::Shutdown() +{ + m_FunctionService.Shutdown(); +} + +const char* +HttpFunctionService::BaseUri() const +{ + return "/apply/"; +} + +void +HttpFunctionService::HandleRequest(HttpServerRequest& Request) +{ + metrics::OperationTiming::Scope $(m_HttpRequests); + + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpFunctionService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + m_FunctionService.EmitStats(Cbo); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +void +httpfunction_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp new file mode 100644 index 000000000..39e7e60d7 --- /dev/null +++ b/src/zencompute/httporchestrator.cpp @@ -0,0 +1,81 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zencompute/httporchestrator.h" + +#include +#include + +namespace zen::compute { + +HttpOrchestratorService::HttpOrchestratorService() : m_Log(logging::Get("orch")) +{ + m_Router.RegisterRoute( + "provision", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "uri" << Worker.BaseUri; + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObject Data = HttpReq.ReadPayloadObject(); + + std::string_view WorkerId = Data["id"].AsString(""); + std::string_view WorkerUri = Data["uri"].AsString(""); + + if (WorkerId.empty() || WorkerUri.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto& Worker = m_KnownWorkers[std::string(WorkerId)]; + Worker.BaseUri = WorkerUri; + Worker.LastSeen.Reset(); + }); + + HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); +} + +HttpOrchestratorService::~HttpOrchestratorService() +{ +} + +const char* +HttpOrchestratorService::BaseUri() const +{ + return "/orch/"; +} + +void +HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/functionservice.h b/src/zencompute/include/zencompute/functionservice.h new file mode 100644 index 000000000..1deb99fd5 --- /dev/null +++ b/src/zencompute/include/zencompute/functionservice.h @@ -0,0 +1,132 @@ +// 
Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include + +# include + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class FunctionServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. 
+ */ +class FunctionServiceSession final +{ +public: + FunctionServiceSession(ChunkResolver& InChunkResolver); + ~FunctionServiceSession(); + + void Shutdown(); + bool IsHealthy(); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + + void GetCompleted(CbWriter&); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + bool Succeeded; + uint64_t Timestamps[5] = {}; + }; + + [[nodiscard]] std::vector GetActionHistory(int Limit = 100); + + // Stats reporting + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr m_Impl; 
+}; + +void function_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpfunctionservice.h b/src/zencompute/include/zencompute/httpfunctionservice.h new file mode 100644 index 000000000..6e2344ae6 --- /dev/null +++ b/src/zencompute/include/zencompute/httpfunctionservice.h @@ -0,0 +1,73 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/functionservice.h" + +# include +# include +# include +# include +# include +# include + +# include +# include +# include + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +class HttpFunctionService; +class FunctionService; + +/** + * HTTP interface for compute function service + */ +class HttpFunctionService : public HttpService, public IHttpStatsProvider +{ +public: + HttpFunctionService(CidStore& InCidStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir); + ~HttpFunctionService(); + + void Shutdown(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + +protected: + CidStore& m_CidStore; + IHttpStatsService& m_StatsService; + LoggerRef Log() { return m_Log; } + +private: + LoggerRef m_Log; + std::filesystem ::path m_BaseDir; + HttpRequestRouter m_Router; + FunctionServiceSession m_FunctionService; + + // Metrics + + metrics::OperationTiming m_HttpRequests; +}; + +void httpfunction_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h new file mode 100644 index 000000000..168c6d7fe --- /dev/null +++ 
b/src/zencompute/include/zencompute/httporchestrator.h @@ -0,0 +1,44 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include + +#include + +namespace zen::compute { + +/** + * Mock orchestrator service, for testing dynamic provisioning + * + * Non-copyable HttpService. Known workers are kept in m_KnownWorkers, presumably + * guarded by m_KnownWorkersLock (the code using them is not visible in this header). + * + * NOTE(review): KnownWorker::BaseUri is a std::string_view, i.e. it does not own its + * characters. Confirm that whatever storage it views outlives the map entry; if it + * views request data, an owning std::string would be needed instead. + */ + +class HttpOrchestratorService : public HttpService +{ +public: + HttpOrchestratorService(); + ~HttpOrchestratorService(); + + HttpOrchestratorService(const HttpOrchestratorService&) = delete; + HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + +private: + HttpRequestRouter m_Router; + LoggerRef m_Log; + + struct KnownWorker + { + std::string_view BaseUri; + Stopwatch LastSeen; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map m_KnownWorkers; +}; + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h new file mode 100644 index 000000000..bf1aff125 --- /dev/null +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved.
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + /** + * Abstract, non-copyable interface shared by the recording readers declared in this + * header: ReadWorkers() returns the recorded workers, IterateActions() invokes + * Callback for the recorded actions with a requested degree of parallelism, and + * GetActionCount() reports the number of actions. The destructor is declared pure + * virtual, so a definition has to be supplied out of line. + */ +class RecordingReaderBase +{ + RecordingReaderBase(const RecordingReaderBase&) = delete; + RecordingReaderBase& operator=(const RecordingReaderBase&) = delete; + +public: + RecordingReaderBase() = default; + virtual ~RecordingReaderBase() = 0; + virtual std::unordered_map ReadWorkers() = 0; + virtual void IterateActions(std::function&& Callback, int TargetParallelism) = 0; + virtual size_t GetActionCount() const = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Reader for recordings done via the zencompute recording system, which + * have a shared chunk store and a log of actions with pointers into the + * chunk store for their data.
+ */ +class RecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit RecordingReader(const std::filesystem::path& RecordingPath); + ~RecordingReader(); + + virtual std::unordered_map ReadWorkers() override; + + virtual void IterateActions(std::function&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + +private: + std::filesystem::path m_RecordingLogDir; + BasicFile m_WorkerDataFile; + BasicFile m_ActionDataFile; + GcManager m_Gc; + CidStore m_CidStore{m_Gc}; + + // ChunkResolver interface + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + + /* One table-of-contents entry: an action id plus an offset/size pair */ + struct ActionEntry + { + IoHash ActionId; + uint64_t Offset; + uint64_t Size; + }; + + std::vector m_Actions; + + void ScanActions(); +}; + +////////////////////////////////////////////////////////////////////////// + /** + * Minimal in-memory ChunkResolver: Add() stores a buffer under a CID and + * FindChunkByCid() looks it up again. Non-copyable. + */ +struct LocalResolver : public ChunkResolver +{ + LocalResolver(const LocalResolver&) = delete; + LocalResolver& operator=(const LocalResolver&) = delete; + + LocalResolver() = default; + ~LocalResolver() = default; + + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + void Add(const IoHash& Cid, IoBuffer Data); + +private: + RwLock MapLock; + std::unordered_map Attachments; +}; + +/** + * This is a reader for UE/DDB recordings, which have a different layout on + * disk (no shared chunk store) + */ +class UeRecordingReader : public RecordingReaderBase, public ChunkResolver +{ +public: + explicit UeRecordingReader(const std::filesystem::path& RecordingPath); + ~UeRecordingReader(); + + virtual std::unordered_map ReadWorkers() override; + virtual void IterateActions(std::function&& Callback, + int TargetParallelism) override; + virtual size_t GetActionCount() const override; + virtual IoBuffer FindChunkByCid(const IoHash& DecompressedId) override; + +private: + std::filesystem::path m_RecordingDir; + LocalResolver m_LocalResolver; + std::vector m_WorkDirs; + + CbPackage
ReadAction(std::filesystem::path WorkDir); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h new file mode 100644 index 000000000..6dc32eeea --- /dev/null +++ b/src/zencompute/include/zencompute/zencompute.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen { + +void zencompute_forcelinktests(); + +} diff --git a/src/zencompute/localrunner.cpp b/src/zencompute/localrunner.cpp new file mode 100644 index 000000000..9a27f3f3d --- /dev/null +++ b/src/zencompute/localrunner.cpp @@ -0,0 +1,722 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +namespace zen::compute { + +using namespace std::literals; + /** + * Derives the worker ("workers") and sandbox ("scratch") directories from BaseDir, + * sets the concurrency limit to twice the reported logical processor count, cleans + * any existing actions / sandbox / worker directories, starts the monitor thread + * and, as the last step, starts accepting actions. + * + * NOTE(review): m_ActionsPath is not declared by LocalProcessRunner itself - + * presumably it comes from FunctionRunner; confirm. + */ +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("local_exec")) +, m_ChunkResolver(Resolver) +, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) +, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +{ + SystemMetrics Sm = GetSystemMetricsForReporting(); + + m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); + + bool DidCleanup = false; + + if (std::filesystem::is_directory(m_ActionsPath)) + { + ZEN_INFO("Cleaning '{}'", m_ActionsPath); + + std::error_code Ec; + CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); + } + + DidCleanup = true; + } + + if (std::filesystem::is_directory(m_SandboxPath)) + { + ZEN_INFO("Cleaning '{}'", m_SandboxPath); +
std::error_code Ec; + CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); + } + + DidCleanup = true; + } + + // We clean out all workers on startup since we can't know they are good. They could be bad + // due to tampering, malware (which I also mean to include AV and antimalware software) or + // other processes we have no control over + if (std::filesystem::is_directory(m_WorkerPath)) + { + ZEN_INFO("Cleaning '{}'", m_WorkerPath); + std::error_code Ec; + CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); + } + + DidCleanup = true; + } + + if (DidCleanup) + { + ZEN_INFO("Cleanup complete"); + } + + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; + +# if ZEN_PLATFORM_WINDOWS + // Suppress any error dialogs caused by missing dependencies + UINT OldMode = ::SetErrorMode(0); + ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); +# endif + + m_AcceptNewActions = true; +} + /** + * Runs Shutdown(); a std::exception thrown by it is logged as a warning instead of + * escaping the destructor. + */ +LocalProcessRunner::~LocalProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); + } +} + /** + * Stops accepting actions, asks the monitor thread to stop (flag + event) and joins + * it, then cancels whatever is still running. + */ +void +LocalProcessRunner::Shutdown() +{ + m_AcceptNewActions = false; + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } + + CancelRunningActions(); +} + /** + * Creates and returns a new directory under m_SandboxPath, named after the next + * value of the sandbox counter. + */ +std::filesystem::path +LocalProcessRunner::CreateNewSandbox() +{ + std::string UniqueId = std::to_string(++m_SandboxCounter); + std::filesystem::path Path = m_SandboxPath / UniqueId; + zen::CreateDirectories(Path); + + return Path; +} + +void +LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + if (m_DumpActions) + { + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + const IoHash& WorkerId =
WorkerPackage.GetObjectHash(); + + std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); + + ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) { + std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); + zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); + }); + + ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); + } +} + /** + * Returns how many more actions can currently be started: 0 when new actions are + * not accepted or the running map is at the limit, otherwise the remaining headroom. + */ +size_t +LocalProcessRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return 0; + } + + size_t RunningCount = m_RunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + /** + * Submits each action in order through SubmitAction() and returns the individual + * results in the same order. + */ +std::vector +LocalProcessRunner::SubmitActions(const std::vector>& Actions) +{ + std::vector Results; + + for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + /** + * Starts one action: creates a sandbox, writes "build.action" and the action's + * input chunks into it, manifests the worker, and (Windows only) spawns the worker + * executable with the sandbox as its current directory before recording it in the + * running map. + * + * Returns IsAccepted = false without side effects when actions are not being + * accepted or the running map is already at the limit. Throws std::runtime_error + * when an input chunk cannot be resolved. + * + * NOTE(review): the capacity check holds only a shared lock that is released before + * the action is inserted much later, so concurrent submissions can overshoot + * m_MaxRunningActions - confirm whether callers serialize submissions. + */ +SubmitResult +LocalProcessRunner::SubmitAction(Ref Action) +{ + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return SubmitResult{.IsAccepted = false}; + } + + if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + std::filesystem::path SandboxPath =
CreateNewSandbox(); + + CbPackage WorkerPackage = Action->Worker.Descriptor; + + std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); + + // Write out action + + zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + + // Manifest inputs in sandbox + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); + } + + zen::WriteFile(FilePath, DataBuffer); + }); + +# if ZEN_PLATFORM_WINDOWS + // Set up environment variables + + StringBuilder<1024> EnvironmentBlock; + + CbObject WorkerDescription = WorkerPackage.GetObject(); + + for (auto& It : WorkerDescription["environment"sv]) + { + EnvironmentBlock.Append(It.AsString()); + EnvironmentBlock.Append('\0'); + } + EnvironmentBlock.Append('\0'); + EnvironmentBlock.Append('\0'); + + // Execute process - this spawns the child process immediately without waiting + // for completion + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + ExtendableWideStringBuilder<512> CommandLine; + CommandLine.Append(L'"'); + CommandLine.Append(ExePath.c_str()); + CommandLine.Append(L'"'); + CommandLine.Append(L" -Build=build.action"); + + LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; + BOOL bInheritHandles = FALSE; + DWORD dwCreationFlags = 0; + + STARTUPINFO StartupInfo{}; + StartupInfo.cb = sizeof StartupInfo; + + PROCESS_INFORMATION ProcessInformation{}; + + ZEN_DEBUG("Executing: {}", WideToUtf8(CommandLine.c_str())); + + CommandLine.EnsureNulTerminated(); + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, +
lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), // Environment block + SandboxPath.c_str(), // Current directory + &StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + // TODO: this is probably not the best way to report failure. The return + // object should include a failure state and context + + zen::ThrowLastError("Unable to launch process" /* TODO: Add context */); + } + + CloseHandle(ProcessInformation.hThread); + + Ref NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->SandboxPath = std::move(SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RunningMap[ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); +# else + ZEN_UNUSED(ActionId); + + ZEN_NOT_IMPLEMENTED(); + + int ExitCode = 0; +# endif + + return SubmitResult{.IsAccepted = true}; +} + +size_t +LocalProcessRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RunningMap.size(); +} + +std::filesystem::path +LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) +{ + RwLock::SharedLockScope _(m_WorkerLock); + + std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); + + if (!std::filesystem::exists(WorkerDir)) + { + _.ReleaseNow(); + + RwLock::ExclusiveLockScope $(m_WorkerLock); + + if (!std::filesystem::exists(WorkerDir)) + { + ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); + } + } + + return WorkerDir; +} + +/** + * Writes one worker file, described by FileEntry ("name", "hash", "size"), below + * SandboxRootPath in decompressed form. + * + * The compressed chunk is taken from FromPackage when it is attached there and is + * otherwise looked up through the chunk resolver. ChunkReferenceCallback is invoked + * with the chunk hash and the compressed buffer before the file is written. + * + * Throws std::runtime_error when a resolver-provided chunk is missing, is not a + * valid compressed buffer, or has a raw size different from the "size" field. + * + * NOTE(review): "name" is appended to SandboxRootPath as-is; confirm that worker + * descriptors are trusted, or reject names that escape the sandbox root. + */ +void +LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function& ChunkReferenceCallback) +{ + std::string_view Name = FileEntry["name"sv].AsString(); + const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); +
+ CompressedBuffer Compressed; + + if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) + { + Compressed = Attachment->AsCompressedBinary(); + } + else + { + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); + } + + uint64_t DataRawSize = 0; + IoHash DataRawHash; + Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); + + // Same validity check GatherActionOutputs applies to the buffers it parses + if (!Compressed) + { + throw std::runtime_error(fmt::format("worker chunk '{}' is not a valid compressed buffer", ChunkHash)); + } + + if (DataRawSize != Size) + { + // Report the raw size, which is the value actually compared against the spec + throw std::runtime_error( + fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataRawSize, Size)); + } + } + + ChunkReferenceCallback(ChunkHash, Compressed); + + std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + + SharedBuffer Decompressed = Compressed.Decompress(); + zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); +} + +/** + * Lays out a worker below SandboxPath: every entry of "executables" and "files" is + * decompressed to disk, every entry of "dirs" is created, and the worker description + * itself is written to "worker.zcb". + */ +void +LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function&& ChunkReferenceCallback) +{ + CbObject WorkerDescription = WorkerPackage.GetObject(); + + // Manifest worker in Sandbox + + for (auto& It : WorkerDescription["executables"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + std::string_view Name = It.AsString(); + std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + zen::CreateDirectories(DirPath); + } + + for (auto& It : WorkerDescription["files"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); +} + +/** + * Reads "build.output" from the sandbox plus every attachment it references from + * the "Outputs" directory and returns them as one package. Throws std::system_error + * when a file cannot be read and std::runtime_error when an attachment is not a + * valid compressed buffer. + */ +CbPackage +LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) +{ + std::filesystem::path OutputFile = SandboxPath /
"build.output"; + FileContents OutputData = zen::ReadFile(OutputFile); + + if (OutputData.ErrorCode) + { + throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); + } + + CbPackage OutputPackage; + CbObject Output = zen::LoadCompactBinaryObject(OutputData.Flatten()); + + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalRawAttachmentBytes = 0; + + Output.IterateAttachments([&](CbFieldView Field) { + IoHash Hash = Field.AsHash(); + std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; + FileContents ChunkData = zen::ReadFile(OutputPath); + + if (ChunkData.ErrorCode) + { + throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); + } + + uint64_t ChunkDataRawSize = 0; + IoHash ChunkDataHash; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); + + if (!AttachmentBuffer) + { + throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); + } + + TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + TotalRawAttachmentBytes += ChunkDataRawSize; + + CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); + OutputPackage.AddAttachment(Attachment); + }); + + OutputPackage.SetObject(Output); + + ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", + OutputPackage.GetAttachments().size(), + NiceBytes(TotalAttachmentBytes), + NiceBytes(TotalRawAttachmentBytes)); + + return OutputPackage; +} + +void +LocalProcessRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("LocalProcessRunner_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + // On Windows it's possible to wait on process handles, so we wait for either a process to exit + // or for the monitor event to be signaled (which indicates we should check for cancellation + //
or shutdown). This could be further improved by using a completion port and registering process + // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing + // with an enormous number of concurrent processes. + // + // On other platforms we just wait on the monitor event and poll for process exits at intervals. +# if ZEN_PLATFORM_WINDOWS + auto WaitOnce = [&] { + HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; + + uint32_t NumHandles = 0; + + WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); + + // At most MAXIMUM_WAIT_OBJECTS - 1 process handles fit next to the event; any + // further processes are only noticed by the sweep that follows the 1000 ms timeout + m_RunningLock.WithSharedLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) + { + Ref Action = It->second; + + WaitHandles[NumHandles++] = Action->ProcessHandle; + } + }); + + DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); + + // return true if a handle was signaled: a signaled wait reports an index in + // [WAIT_OBJECT_0, WAIT_OBJECT_0 + NumHandles), so the upper bound is exclusive + return (WaitResult < WAIT_OBJECT_0 + NumHandles); + }; +# else + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; +# endif + + while (!WaitOnce()) + { + if (m_MonitorThreadEnabled == false) + { + return; + } + + SweepRunningActions(); + } + + // Signal received + + SweepRunningActions(); + } while (m_MonitorThreadEnabled); +} + +void +LocalProcessRunner::CancelRunningActions() +{ + Stopwatch Timer; + std::unordered_map> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // For expedience we initiate the process termination for all known + // processes before attempting to wait for them to exit.
+ + std::vector TerminatedLsnList; + + for (const auto& Kv : RunningMap) + { + Ref Action = Kv.second; + + // Terminate running process + +# if ZEN_PLATFORM_WINDOWS + BOOL Success = TerminateProcess(Action->ProcessHandle, 222); + + if (Success) + { + TerminatedLsnList.push_back(Kv.first); + } + else + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Action->Action->ActionLsn, GetSystemErrorAsString(LastError)); + } + } +# else + ZEN_NOT_IMPLEMENTED("need to implement process termination"); +# endif + } + + // We only post results for processes we have terminated, in order + // to avoid multiple results getting posted for the same action + + for (int Lsn : TerminatedLsnList) + { + if (auto It = RunningMap.find(Lsn); It != RunningMap.end()) + { + Ref Running = It->second; + +# if ZEN_PLATFORM_WINDOWS + if (Running->ProcessHandle != INVALID_HANDLE_VALUE) + { + DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); + + if (WaitResult != WAIT_OBJECT_0) + { + ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); + } + else + { + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + } + } +# endif + + // Clean up and post error result + + DeleteDirectories(Running->SandboxPath); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", TerminatedLsnList.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +void +LocalProcessRunner::SweepRunningActions() +{ + std::vector> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + // TODO: It would be good to not hold the exclusive lock while making + // system calls and other expensive operations. 
+ + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref Action = It->second; + +# if ZEN_PLATFORM_WINDOWS + DWORD ExitCode = 0; + BOOL IsSuccess = GetExitCodeProcess(Action->ProcessHandle, &ExitCode); + + if (IsSuccess && ExitCode != STILL_ACTIVE) + { + CloseHandle(Action->ProcessHandle); + Action->ProcessHandle = INVALID_HANDLE_VALUE; + + CompletedActions.push_back(std::move(Action)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } +# else + // TODO: implement properly for Mac/Linux + + ZEN_UNUSED(Action); +# endif + } + }); + + // Notify outer. Note that this has to be done without holding any local locks + // otherwise we may end up with deadlocks. + + for (Ref Running : CompletedActions) + { + const int ActionLsn = Running->Action->ActionLsn; + + if (Running->ExitCode == 0) + { + try + { + // Gather outputs + + CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); + + Running->Action->SetResult(std::move(OutputPackage)); + Running->Action->SetActionState(RunnerAction::State::Completed); + + // We can delete the files at this point + if (!DeleteDirectories(Running->SandboxPath)) + { + ZEN_WARN("Unable to delete directory '{}', this will continue to exist until service restart", Running->SandboxPath); + } + + // Success -- continue with next iteration of the loop + continue; + } + catch (std::exception& Ex) + { + ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + } + } + + // Failed - for now this is indicated with an empty package in + // the results map. We can clean out the sandbox directory immediately. 
+ + std::error_code Ec; + DeleteDirectories(Running->SandboxPath, Ec); + + if (Ec) + { + ZEN_WARN("Unable to delete sandbox directory '{}': {}", Running->SandboxPath, Ec.message()); + } + + Running->Action->SetActionState(RunnerAction::State::Failed); + } +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/localrunner.h b/src/zencompute/localrunner.h new file mode 100644 index 000000000..35f464805 --- /dev/null +++ b/src/zencompute/localrunner.h @@ -0,0 +1,100 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/functionservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include +# include +# include +# include +# include + +# include +# include +# include + +namespace zen { +class CbPackage; +} + +namespace zen::compute { + +/** Direct process spawner + + This runner simply sets up a directory structure for each job and + creates a process to perform the computation in it. It is not very + efficient and is intended mostly for testing. 
+ + */ + +class LocalProcessRunner : public FunctionRunner +{ + LocalProcessRunner(LocalProcessRunner&&) = delete; + LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; + +public: + LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir); + ~LocalProcessRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; + [[nodiscard]] virtual bool IsHealthy() override { return true; } + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; + +protected: + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + + /* Book-keeping for one spawned process. ProcessHandle holds the native process + handle as a void*; an ExitCode of 0 is what the sweep treats as success. The + struct has no destructor, so whoever removes an entry must release the handle. */ + struct RunningAction : public RefCounted + { + Ref Action; + void* ProcessHandle = nullptr; + int ExitCode = 0; + std::filesystem::path SandboxPath; + }; + + std::atomic_bool m_AcceptNewActions; + ChunkResolver& m_ChunkResolver; + RwLock m_WorkerLock; + std::filesystem::path m_WorkerPath; + std::atomic m_SandboxCounter = 0; + std::filesystem::path m_SandboxPath; + int32_t m_MaxRunningActions = 64; // arbitrary limit for testing + + // if used in conjunction with m_ResultsLock, this lock must be taken *after* + // m_ResultsLock to avoid deadlocks + // NOTE(review): m_ResultsLock is not declared in this class - confirm where it lives + RwLock m_RunningLock; + std::unordered_map> m_RunningMap; + + std::thread m_MonitorThread; + std::atomic m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + void SweepRunningActions(); + void CancelRunningActions(); + + std::filesystem::path CreateNewSandbox(); + void ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function&& ChunkReferenceCallback); + std::filesystem::path ManifestWorker(const WorkerDesc& Worker); + CbPackage GatherActionOutputs(std::filesystem::path SandboxPath); + + void
DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function& ChunkReferenceCallback); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/recordingreader.cpp b/src/zencompute/recordingreader.cpp new file mode 100644 index 000000000..1c1a119cf --- /dev/null +++ b/src/zencompute/recordingreader.cpp @@ -0,0 +1,335 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/recordingreader.h" + +#include +#include +#include +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# include +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_PLATFORM_WINDOWS +# define ZEN_BUILD_ACTION L"Build.action" +# define ZEN_WORKER_UCB L"worker.ucb" +# else +# define ZEN_BUILD_ACTION "Build.action" +# define ZEN_WORKER_UCB "worker.ucb" +# endif + +////////////////////////////////////////////////////////////////////////// + /** + * Traversal visitor that records the parent directory of every file named + * ZEN_BUILD_ACTION in WorkDirs and of every file named ZEN_WORKER_UCB in + * WorkerDirs. VisitDirectory always returns true. + */ +struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor +{ + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) + { + ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick); + + if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0) + { + WorkDirs.push_back(Parent); + } + else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0) + { + WorkerDirs.push_back(Parent); + } + } + + virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) + { + ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes); + + return true; + } + + std::vector WorkerDirs; + std::vector WorkDirs; +}; +
+////////////////////////////////////////////////////////////////////////// + +/** + * Invokes Func(Item) for every element of Array. + * + * When ConcRT is available and TargetParallelism is greater than 1 the elements are + * visited through concurrency::parallel_for_each, so Func may run on several threads + * and in any order; otherwise the elements are visited sequentially, in order. + * + * Array is taken by const reference so the caller's container is not copied per call. + * No explicit partitioner is constructed: a simple_partitioner sized as + * Array.size() / TargetParallelism would get a chunk size of 0 whenever the array has + * fewer elements than the requested parallelism, which PPL does not accept, and it was + * never passed to parallel_for_each anyway. + */ +void +IterateOverArray(const auto& Array, auto Func, int TargetParallelism) +{ +# if ZEN_CONCRT_AVAILABLE + if (TargetParallelism > 1) + { + concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }); + + return; + } +# else + ZEN_UNUSED(TargetParallelism); +# endif + + for (const auto& Item : Array) + { + Func(Item); + } +} + +////////////////////////////////////////////////////////////////////////// + +RecordingReaderBase::~RecordingReaderBase() = default; + +////////////////////////////////////////////////////////////////////////// + +RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath) +{ + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +RecordingReader::~RecordingReader() +{ + m_CidStore.Flush(); +} + +size_t +RecordingReader::GetActionCount() const +{ + return m_Actions.size(); +} + +IoBuffer +RecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId)) + { + return Chunk; + } + + ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId); + + return {}; +} + +std::unordered_map +RecordingReader::ReadWorkers() +{ + std::unordered_map WorkerMap; + + { + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc"); + CbObject Toc = TocFile.Object; + + m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView Entry = It.AsArrayView(); + CbFieldViewIterator Vit = Entry.CreateViewIterator(); + + const IoHash WorkerId = Vit++->AsHash(); + const uint64_t Offset =
Vit++->AsInt64(0); + const uint64_t Size = Vit++->AsInt64(0); + + IoBuffer WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size); + CbObject WorkerDesc = LoadCompactBinaryObject(WorkerRange); + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + } + }); + } + } + + // Scan actions as well (this should be called separately, ideally) + + ScanActions(); + + return WorkerMap; +} + /** + * Loads "actions.ztoc", opens "actions.zdat" for reading and appends one + * ActionEntry (id, offset, size) per table-of-contents row to m_Actions. + * + * NOTE(review): entries are appended without clearing m_Actions first, so running + * this more than once on the same reader (it is invoked from ReadWorkers) would + * duplicate every action - confirm ReadWorkers is only called once per reader. + */ +void +RecordingReader::ScanActions() +{ + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc"); + CbObject Toc = TocFile.Object; + + m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView ArrayEntry = It.AsArrayView(); + CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator(); + + ActionEntry Entry; + Entry.ActionId = Vit++->AsHash(); + Entry.Offset = Vit++->AsInt64(0); + Entry.Size = Vit++->AsInt64(0); + + m_Actions.push_back(Entry); + } +} + /** + * Calls Callback with the action object and action id of every scanned action; the + * object is loaded from the action data file using the entry's offset and size. + */ +void +RecordingReader::IterateActions(std::function&& Callback, int TargetParallelism) +{ + IterateOverArray( + m_Actions, + [&](const ActionEntry& Entry) { + CbObject ActionDesc = LoadCompactBinaryObject(m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size)); + + Callback(ActionDesc, Entry.ActionId); + }, + TargetParallelism); +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer +LocalResolver::FindChunkByCid(const IoHash& DecompressedId) +{ +
RwLock::SharedLockScope _(MapLock); + if (auto It = Attachments.find(DecompressedId); It != Attachments.end()) + { + return It->second; + } + + return {}; +} + +void +LocalResolver::Add(const IoHash& Cid, IoBuffer Data) +{ + RwLock::ExclusiveLockScope _(MapLock); + Data.SetContentType(ZenContentType::kCompressedBinary); + Attachments[Cid] = Data; +} + +/// + +UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath) +{ +} + +UeRecordingReader::~UeRecordingReader() +{ +} + +size_t +UeRecordingReader::GetActionCount() const +{ + return m_WorkDirs.size(); +} + +IoBuffer +UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + return m_LocalResolver.FindChunkByCid(DecompressedId); +} + +std::unordered_map +UeRecordingReader::ReadWorkers() +{ + std::unordered_map WorkerMap; + + FileSystemTraversal Traversal; + RecordingTreeVisitor Visitor; + Traversal.TraverseFileSystem(m_RecordingDir, Visitor); + + m_WorkDirs = std::move(Visitor.WorkDirs); + + for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs) + { + CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb"); + CbObject WorkerDesc = WorkerFile.Object; + const IoHash& WorkerId = WorkerFile.Hash; + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten(); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + } + + return WorkerMap; +} + +void +UeRecordingReader::IterateActions(std::function&& Callback, int ParallelismTarget) +{ + IterateOverArray( + m_WorkDirs, + [&](const 
std::filesystem::path& WorkDir) { + CbPackage WorkPackage = ReadAction(WorkDir); + CbObject ActionObject = WorkPackage.GetObject(); + const IoHash& ActionId = WorkPackage.GetObjectHash(); + + Callback(ActionObject, ActionId); + }, + ParallelismTarget); +} + +CbPackage +UeRecordingReader::ReadAction(std::filesystem::path WorkDir) +{ + CbPackage WorkPackage; + std::filesystem::path WorkDescPath = WorkDir / "Build.action"; + CbObjectFromFile ActionFile = LoadCompactBinaryObject(WorkDescPath); + CbObject& ActionObject = ActionFile.Object; + + WorkPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten(); + + m_LocalResolver.Add(AttachmentCid, AttachmentData); + + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + ZEN_ASSERT(AttachmentCid == RawHash); + WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + + return WorkPackage; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/remotehttprunner.cpp new file mode 100644 index 000000000..98ced5fe8 --- /dev/null +++ b/src/zencompute/remotehttprunner.cpp @@ -0,0 +1,457 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "remotehttprunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +////////////////////////////////////////////////////////////////////////// + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("http_exec")) +, m_ChunkResolver{InChunkResolver} +, m_BaseUrl{fmt::format("{}/apply", HostName)} +, m_Http(m_BaseUrl) +{ + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; +} + +RemoteHttpRunner::~RemoteHttpRunner() +{ + Shutdown(); +} + +void +RemoteHttpRunner::Shutdown() +{ + // TODO: should cleanly drain/cancel pending work + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +void +RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + CbPackage WorkerDesc = WorkerPackage; + + std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); + + HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); + + if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) + { + HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); + + if (DescResponse.StatusCode == HttpResponseCode::NotFound) + { + CbPackage Pkg = WorkerDesc; + + // Build response package by sending only the attachments + // the other end needs. We start with the full package and + // remove the attachments which are not needed. 
+ + { + std::unordered_set Needed; + + CbObject Response = DescResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + Needed.insert(NeedHash); + } + + std::unordered_set ToRemove; + + for (const CbAttachment& Attachment : Pkg.GetAttachments()) + { + const IoHash& Hash = Attachment.GetHash(); + + if (Needed.find(Hash) == Needed.end()) + { + ToRemove.insert(Hash); + } + } + + for (const IoHash& Hash : ToRemove) + { + int RemovedCount = Pkg.RemoveAttachment(Hash); + + ZEN_ASSERT(RemovedCount == 1); + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); + + if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + } + else if (!IsHttpSuccessCode(DescResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + else + { + ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); + } + } + else if (WorkerResponse.StatusCode == HttpResponseCode::OK) + { + // Already known from a previous run + } + else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", + WorkerId, + m_Http.GetBaseUri(), + WorkerUrl, + (int)WorkerResponse.StatusCode, + ToString(WorkerResponse.StatusCode)); + + // TODO: propagate error + } +} + +size_t +RemoteHttpRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + size_t RunningCount = m_RemoteRunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + +std::vector +RemoteHttpRunner::SubmitActions(const std::vector>& Actions) +{ + std::vector Results; + + 
for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +SubmitResult +RemoteHttpRunner::SubmitAction(Ref Action) +{ + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + // Enqueue job + + CbObject Result; + + HttpClient::Response WorkResponse = m_Http.Post("/jobs", ActionObj); + HttpResponseCode WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (WorkResponseCode == HttpResponseCode::NotFound) + { + // Not all attachments are present + + // Build response package including all required attachments + + CbPackage Pkg; + Pkg.SetObject(ActionObj); + + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + // No such attachment + + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + // Post resulting package + + 
HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri()); + + // TODO: include more information about the failure in the response + + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (PayloadResponse.StatusCode == HttpResponseCode::OK) + { + Result = PayloadResponse.AsObject(); + } + else + { + // Unexpected response + + const int ResponseStatusCode = (int)PayloadResponse.StatusCode; + + ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})", + ActionId, + m_Http.GetBaseUri(), + ResponseStatusCode, + ToString(ResponseStatusCode)); + + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}/jobs", + ResponseStatusCode, + ToString(ResponseStatusCode), + m_Http.GetBaseUri())}; + } + } + + if (Result) + { + if (const int32_t LsnField = Result["lsn"].AsInt32(0)) + { + HttpRunningAction NewAction; + NewAction.Action = Action; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; + } + } + + return {}; +} + +bool +RemoteHttpRunner::IsHealthy() +{ + if (HttpClient::Response Ready = m_Http.Get("/ready")) + { + return true; + } + else + { + // TODO: use response to propagate context + return false; + } +} + +size_t +RemoteHttpRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RemoteRunningMap.size(); +} + +void +RemoteHttpRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("RemoteHttpRunner_Monitor"); + + do + { + const int NormalWaitingTime = 1000; + int WaitTimeMs = NormalWaitingTime; + auto WaitOnce = 
[&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { + const size_t RetiredCount = SweepRunningActions(); + + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) + { + WaitTimeMs = NormalWaitingTime / 4; + } + else + { + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } + } + }); + }; + + while (!WaitOnce()) + { + SweepOnce(); + } + + // Signal received - this may mean we should quit + + SweepOnce(); + } while (m_MonitorThreadEnabled); +} + +size_t +RemoteHttpRunner::SweepRunningActions() +{ + std::vector CompletedActions; + + // Poll remote for list of completed actions + + HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv); + + if (CbObject Completed = ResponseCompleted.AsObject()) + { + for (auto& FieldIt : Completed["completed"sv]) + { + const int32_t CompleteLsn = FieldIt.AsInt32(); + + if (HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn))) + { + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.ActionResults = ResponseJob.AsPackage(); + CompletedAction.Success = true; + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); + } + } + + if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) + { + // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0)) + if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0)) + { + const int32_t NewCap = zen::Max(4, CpuCount); + + if 
(m_MaxRunningActions > NewCap) + { + ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions); + + m_MaxRunningActions = NewCap; + } + } + } + } + + // Notify outer. Note that this has to be done without holding any local locks + // otherwise we may end up with deadlocks. + + for (HttpRunningAction& HttpAction : CompletedActions) + { + const int ActionLsn = HttpAction.Action->ActionLsn; + + if (HttpAction.Success) + { + ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", HttpAction.Action->ActionId, ActionLsn, HttpAction.RemoteActionLsn); + + HttpAction.Action->SetActionState(RunnerAction::State::Completed); + + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); + } + else + { + HttpAction.Action->SetActionState(RunnerAction::State::Failed); + } + } + + return CompletedActions.size(); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/remotehttprunner.h b/src/zencompute/remotehttprunner.h new file mode 100644 index 000000000..1e885da3d --- /dev/null +++ b/src/zencompute/remotehttprunner.h @@ -0,0 +1,80 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "zencompute/functionservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include +# include +# include +# include + +# include +# include +# include + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** HTTP-based runner + + This implements a DDC remote compute execution strategy via REST API + + */ + +class RemoteHttpRunner : public FunctionRunner +{ + RemoteHttpRunner(RemoteHttpRunner&&) = delete; + RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; + +public: + RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName); + ~RemoteHttpRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; + [[nodiscard]] virtual bool IsHealthy() override; + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; + +protected: + LoggerRef Log() { return m_Log; } + +private: + LoggerRef m_Log; + ChunkResolver& m_ChunkResolver; + std::string m_BaseUrl; + HttpClient m_Http; + + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + + struct HttpRunningAction + { + Ref Action; + int RemoteActionLsn = 0; // Remote LSN + bool Success = false; + CbPackage ActionResults; + }; + + RwLock m_RunningLock; + std::unordered_map m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn + + std::thread m_MonitorThread; + std::atomic m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + size_t SweepRunningActions(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua new file mode 100644 index 000000000..c710b662d --- /dev/null +++ b/src/zencompute/xmake.lua 
@@ -0,0 +1,11 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zencompute') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") + add_packages("vcpkg::gsl-lite") + add_packages("vcpkg::spdlog", "vcpkg::cxxopts") diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp new file mode 100644 index 000000000..633250f4e --- /dev/null +++ b/src/zencompute/zencompute.cpp @@ -0,0 +1,12 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencompute/zencompute.h" + +namespace zen { + +void +zencompute_forcelinktests() +{ +} + +} // namespace zen diff --git a/src/zennet/beacon.cpp b/src/zennet/beacon.cpp new file mode 100644 index 000000000..394a4afbb --- /dev/null +++ b/src/zennet/beacon.cpp @@ -0,0 +1,170 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct FsBeacon::Impl +{ + Impl(std::filesystem::path ShareRoot); + ~Impl(); + + void EnsureValid(); + + void AddGroup(std::string_view GroupId, CbObject Metadata); + void ScanGroup(std::string_view GroupId, std::vector& OutSessions); + void ReadMetadata(std::string_view GroupId, const std::vector& InSessions, std::vector& OutMetadata); + +private: + std::filesystem::path m_ShareRoot; + zen::Oid m_SessionId; + + struct GroupData + { + CbObject Metadata; + BasicFile LockFile; + }; + + std::map m_Registration; + + std::filesystem::path GetSessionMarkerPath(std::string_view GroupId, const Oid& SessionId) + { + Oid::String_t SessionIdString; + SessionId.ToString(SessionIdString); + + return m_ShareRoot / GroupId / SessionIdString; + } +}; + +FsBeacon::Impl::Impl(std::filesystem::path ShareRoot) : 
m_ShareRoot(ShareRoot), m_SessionId(GetSessionId()) +{ +} + +FsBeacon::Impl::~Impl() +{ +} + +void +FsBeacon::Impl::EnsureValid() +{ +} + +void +FsBeacon::Impl::AddGroup(std::string_view GroupId, CbObject Metadata) +{ + zen::CreateDirectories(m_ShareRoot / GroupId); + std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, m_SessionId); + + GroupData& Group = m_Registration[std::string(GroupId)]; + + Group.Metadata = Metadata; + + std::error_code Ec; + Group.LockFile.Open(MarkerFile, + BasicFile::Mode::kTruncate | BasicFile::Mode::kPreventDelete | + BasicFile::Mode::kPreventWrite /* | BasicFile::Mode::kDeleteOnClose */, + Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open beacon marker file '{}' for write", MarkerFile)); + } + + Group.LockFile.WriteAll(Metadata.GetBuffer().AsIoBuffer(), Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to write to beacon marker file '{}'", MarkerFile)); + } + + Group.LockFile.Flush(); +} + +void +FsBeacon::Impl::ScanGroup(std::string_view GroupId, std::vector& OutSessions) +{ + DirectoryContent Dc; + zen::GetDirectoryContent(m_ShareRoot / GroupId, zen::DirectoryContentFlags::IncludeFiles, /* out */ Dc); + + for (const std::filesystem::path& FilePath : Dc.Files) + { + std::filesystem::path File = FilePath.filename(); + + std::error_code Ec; + if (std::filesystem::remove(FilePath, Ec) == false) + { + auto FileString = File.generic_string(); + + if (FileString.length() != Oid::StringLength) + continue; + + if (const Oid SessionId = Oid::FromHexString(FileString)) + { + if (std::filesystem::file_size(File, Ec) > 0) + { + OutSessions.push_back(SessionId); + } + } + } + } +} + +void +FsBeacon::Impl::ReadMetadata(std::string_view GroupId, const std::vector& InSessions, std::vector& OutMetadata) +{ + for (const Oid& SessionId : InSessions) + { + const std::filesystem::path MarkerFile = GetSessionMarkerPath(GroupId, SessionId); + + if (CbObject Metadata = 
LoadCompactBinaryObject(MarkerFile).Object) + { + OutMetadata.push_back(std::move(Metadata)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +FsBeacon::FsBeacon(std::filesystem::path ShareRoot) : m_Impl(std::make_unique(ShareRoot)) +{ +} + +FsBeacon::~FsBeacon() +{ +} + +void +FsBeacon::AddGroup(std::string_view GroupId, CbObject Metadata) +{ + m_Impl->AddGroup(GroupId, Metadata); +} + +void +FsBeacon::ScanGroup(std::string_view GroupId, std::vector& OutSessions) +{ + m_Impl->ScanGroup(GroupId, OutSessions); +} + +void +FsBeacon::ReadMetadata(std::string_view GroupId, const std::vector& InSessions, std::vector& OutMetadata) +{ + m_Impl->ReadMetadata(GroupId, InSessions, OutMetadata); +} + +////////////////////////////////////////////////////////////////////////// + +} // namespace zen diff --git a/src/zennet/include/zennet/beacon.h b/src/zennet/include/zennet/beacon.h new file mode 100644 index 000000000..a8d4805cb --- /dev/null +++ b/src/zennet/include/zennet/beacon.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +#include +#include +#include +#include + +namespace zen { + +class CbObject; + +/** File-system based peer discovery + + Intended to be used with an SMB file share as the root. 
+ */ + +class FsBeacon +{ +public: + FsBeacon(std::filesystem::path ShareRoot); + ~FsBeacon(); + + void AddGroup(std::string_view GroupId, CbObject Metadata); + void ScanGroup(std::string_view GroupId, std::vector& OutSessions); + void ReadMetadata(std::string_view GroupId, const std::vector& InSessions, std::vector& OutMetadata); + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen diff --git a/src/zennet/include/zennet/statsdclient.h b/src/zennet/include/zennet/statsdclient.h index c378e49ce..7688c132c 100644 --- a/src/zennet/include/zennet/statsdclient.h +++ b/src/zennet/include/zennet/statsdclient.h @@ -8,6 +8,8 @@ #include #include +#undef SendMessage + namespace zen { class StatsTransportBase diff --git a/src/zennet/statsdclient.cpp b/src/zennet/statsdclient.cpp index fe5ca4dda..a0e8cb6ce 100644 --- a/src/zennet/statsdclient.cpp +++ b/src/zennet/statsdclient.cpp @@ -12,6 +12,7 @@ ZEN_THIRD_PARTY_INCLUDES_START #include #include +#undef SendMessage ZEN_THIRD_PARTY_INCLUDES_END namespace zen { diff --git a/src/zenserver-test/function-tests.cpp b/src/zenserver-test/function-tests.cpp new file mode 100644 index 000000000..559387fa2 --- /dev/null +++ b/src/zenserver-test/function-tests.cpp @@ -0,0 +1,34 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include + +#if ZEN_WITH_TESTS + +# include +# include +# include +# include +# include + +# include "zenserver-test.h" + +namespace zen::tests { + +using namespace std::literals; + +TEST_CASE("function.run") +{ + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + Instance.SpawnServer(13337); + + ZEN_INFO("Waiting..."); + + Instance.WaitUntilReady(); +} + +} // namespace zen::tests + +#endif diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp new file mode 100644 index 000000000..173f56386 --- /dev/null +++ b/src/zenserver/compute/computeserver.cpp @@ -0,0 +1,330 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "computeserver.h" +#include +#include "computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +void +ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) +{ + Options.add_option("compute", + "", + "upstream-notification-endpoint", + "Endpoint URL for upstream notifications", + cxxopts::value(m_ServerOptions.UpstreamNotificationEndpoint)->default_value(""), + ""); + + Options.add_option("compute", + "", + "instance-id", + "Instance ID for use in notifications", + cxxopts::value(m_ServerOptions.InstanceId)->default_value(""), + ""); +} + +void +ZenComputeServerConfigurator::AddConfigOptions(LuaConfig::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenComputeServerConfigurator::ApplyOptions(cxxopts::Options& Options) +{ + ZEN_UNUSED(Options); +} + +void +ZenComputeServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) +{ + ZEN_UNUSED(LuaOptions); +} + +void +ZenComputeServerConfigurator::ValidateOptions() +{ +} + 
+/////////////////////////////////////////////////////////////////////////// + +ZenComputeServer::ZenComputeServer() +{ +} + +ZenComputeServer::~ZenComputeServer() +{ + Cleanup(); +} + +int +ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry) +{ + ZEN_TRACE_CPU("ZenComputeServer::Initialize"); + ZEN_MEMSCOPE(GetZenserverTag()); + + ZEN_INFO(ZEN_APP_NAME " initializing in HUB server mode"); + + const int EffectiveBasePort = ZenServerBase::Initialize(ServerConfig, ServerEntry); + if (EffectiveBasePort < 0) + { + return EffectiveBasePort; + } + + // This is a workaround to make sure we can have automated tests. Without + // this the ranges for different child zen hub processes could overlap with + // the main test range. + ZenServerEnvironment::SetBaseChildId(1000); + + m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; + + InitializeState(ServerConfig); + InitializeServices(ServerConfig); + RegisterServices(ServerConfig); + + ZenServerBase::Finalize(); + + return EffectiveBasePort; +} + +void +ZenComputeServer::Cleanup() +{ + ZEN_TRACE_CPU("ZenStorageServer::Cleanup"); + ZEN_INFO(ZEN_APP_NAME " cleaning up"); + try + { + m_IoContext.stop(); + if (m_IoRunner.joinable()) + { + m_IoRunner.join(); + } + + if (m_Http) + { + m_Http->Close(); + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("exception thrown during Cleanup() in {}: '{}'", ZEN_APP_NAME, Ex.what()); + } +} + +void +ZenComputeServer::InitializeState(const ZenComputeServerConfig& ServerConfig) +{ + ZEN_UNUSED(ServerConfig); +} + +void +ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) +{ + ZEN_INFO("initializing storage"); + + CidStoreConfiguration Config; + Config.RootDirectory = m_DataRoot / "cas"; + + m_CidStore = std::make_unique(m_GcManager); + m_CidStore->Initialize(Config); + + ZEN_INFO("instantiating API service"); + m_ApiService = std::make_unique(*m_Http); + + ZEN_INFO("instantiating compute 
service"); + m_ComputeService = std::make_unique(ServerConfig.DataDir / "compute"); + + // Ref Runner; + // Runner = zen::compute::CreateLocalRunner(*m_CidStore, ServerConfig.DataDir / "runner"); + + // TODO: (re)implement default configuration here + + ZEN_INFO("instantiating function service"); + m_FunctionService = + std::make_unique(*m_CidStore, m_StatsService, ServerConfig.DataDir / "functions"); +} + +void +ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig) +{ + ZEN_UNUSED(ServerConfig); + + if (m_ComputeService) + { + m_Http->RegisterService(*m_ComputeService); + } + + if (m_ApiService) + { + m_Http->RegisterService(*m_ApiService); + } + + if (m_FunctionService) + { + m_Http->RegisterService(*m_FunctionService); + } +} + +void +ZenComputeServer::Run() +{ + if (m_ProcessMonitor.IsActive()) + { + CheckOwnerPid(); + } + + if (!m_TestMode) + { + // clang-format off + ZEN_INFO( R"(__________ _________ __ )" "\n" + R"(\____ /____ ____ \_ ___ \ ____ _____ ______ __ ___/ |_ ____ )" "\n" + R"( / // __ \ / \/ \ \/ / _ \ / \\____ \| | \ __\/ __ \ )" "\n" + R"( / /\ ___/| | \ \___( <_> ) Y Y \ |_> > | /| | \ ___/ )" "\n" + R"(/_______ \___ >___| /\______ /\____/|__|_| / __/|____/ |__| \___ >)" "\n" + R"( \/ \/ \/ \/ \/|__| \/ )"); + // clang-format on + + ExtendableStringBuilder<256> BuildOptions; + GetBuildOptions(BuildOptions, '\n'); + ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + } + + ZEN_INFO(ZEN_APP_NAME " now running as COMPUTE (pid: {})", GetCurrentProcessId()); + +# if ZEN_PLATFORM_WINDOWS + if (zen::windows::IsRunningOnWine()) + { + ZEN_INFO("detected Wine session - " ZEN_APP_NAME " is not formally tested on Wine and may therefore not work or perform well"); + } +# endif + +# if ZEN_USE_SENTRY + ZEN_INFO("sentry crash handler {}", m_UseSentry ? 
"ENABLED" : "DISABLED"); + if (m_UseSentry) + { + SentryIntegration::ClearCaches(); + } +# endif + + if (m_DebugOptionForcedCrash) + { + ZEN_DEBUG_BREAK(); + } + + const bool IsInteractiveMode = IsInteractiveSession(); // &&!m_TestMode; + + SetNewState(kRunning); + + OnReady(); + + m_Http->Run(IsInteractiveMode); + + SetNewState(kShuttingDown); + + ZEN_INFO(ZEN_APP_NAME " exiting"); +} + +////////////////////////////////////////////////////////////////////////////////// + +ZenComputeServerMain::ZenComputeServerMain(ZenComputeServerConfig& ServerOptions) +: ZenServerMain(ServerOptions) +, m_ServerOptions(ServerOptions) +{ +} + +void +ZenComputeServerMain::DoRun(ZenServerState::ZenServerEntry* Entry) +{ + ZenComputeServer Server; + Server.SetDataRoot(m_ServerOptions.DataDir); + Server.SetContentRoot(m_ServerOptions.ContentDir); + Server.SetTestMode(m_ServerOptions.IsTest); + Server.SetDedicatedMode(m_ServerOptions.IsDedicated); + + const int EffectiveBasePort = Server.Initialize(m_ServerOptions, Entry); + if (EffectiveBasePort == -1) + { + // Server.Initialize has already logged what the issue is - just exit with failure code here. 
+ std::exit(1); + } + + Entry->EffectiveListenPort = uint16_t(EffectiveBasePort); + if (EffectiveBasePort != m_ServerOptions.BasePort) + { + ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort); + m_ServerOptions.BasePort = EffectiveBasePort; + } + + std::unique_ptr ShutdownThread; + std::unique_ptr ShutdownEvent; + + ExtendableStringBuilder<64> ShutdownEventName; + ShutdownEventName << "Zen_" << m_ServerOptions.BasePort << "_Shutdown"; + ShutdownEvent.reset(new NamedEvent{ShutdownEventName}); + + // Monitor shutdown signals + + ShutdownThread.reset(new std::thread{[&] { + SetCurrentThreadName("shutdown_mon"); + + ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}' for process {}", ShutdownEventName, zen::GetCurrentProcessId()); + + if (ShutdownEvent->Wait()) + { + ZEN_INFO("shutdown signal for pid {} received", zen::GetCurrentProcessId()); + Server.RequestExit(0); + } + else + { + ZEN_INFO("shutdown signal wait() failed"); + } + }}); + + auto CleanupShutdown = MakeGuard([&ShutdownEvent, &ShutdownThread] { + ReportServiceStatus(ServiceStatus::Stopping); + + if (ShutdownEvent) + { + ShutdownEvent->Set(); + } + if (ShutdownThread && ShutdownThread->joinable()) + { + ShutdownThread->join(); + } + }); + + // If we have a parent process, establish the mechanisms we need + // to be able to communicate readiness with the parent + + Server.SetIsReadyFunc([&] { + std::error_code Ec; + m_LockFile.Update(MakeLockData(true), Ec); + ReportServiceStatus(ServiceStatus::Running); + NotifyReady(); + }); + + Server.Run(); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h new file mode 100644 index 000000000..625140b23 --- /dev/null +++ b/src/zenserver/compute/computeserver.h @@ -0,0 +1,106 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "zenserver.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include + +namespace cxxopts { +class Options; +} +namespace zen::LuaConfig { +struct Options; +} + +namespace zen::compute { +class HttpFunctionService; +} + +namespace zen { + +class CidStore; +class HttpApiService; +class HttpComputeService; + +struct ZenComputeServerConfig : public ZenServerConfig +{ + std::string UpstreamNotificationEndpoint; + std::string InstanceId; // For use in notifications +}; + +struct ZenComputeServerConfigurator : public ZenServerConfiguratorBase +{ + ZenComputeServerConfigurator(ZenComputeServerConfig& ServerOptions) + : ZenServerConfiguratorBase(ServerOptions) + , m_ServerOptions(ServerOptions) + { + } + + ~ZenComputeServerConfigurator() = default; + +private: + virtual void AddCliOptions(cxxopts::Options& Options) override; + virtual void AddConfigOptions(LuaConfig::Options& Options) override; + virtual void ApplyOptions(cxxopts::Options& Options) override; + virtual void OnConfigFileParsed(LuaConfig::Options& LuaOptions) override; + virtual void ValidateOptions() override; + + ZenComputeServerConfig& m_ServerOptions; +}; + +class ZenComputeServerMain : public ZenServerMain +{ +public: + ZenComputeServerMain(ZenComputeServerConfig& ServerOptions); + virtual void DoRun(ZenServerState::ZenServerEntry* Entry) override; + + ZenComputeServerMain(const ZenComputeServerMain&) = delete; + ZenComputeServerMain& operator=(const ZenComputeServerMain&) = delete; + + typedef ZenComputeServerConfig Config; + typedef ZenComputeServerConfigurator Configurator; + +private: + ZenComputeServerConfig& m_ServerOptions; +}; + +/** + * The compute server handles DDC build function execution requests + * only. It's intended to be used on a pure compute resource and does + * not handle any storage tasks. The actual scheduling happens upstream + * in a storage server instance. 
+ */ + +class ZenComputeServer : public ZenServerBase +{ + ZenComputeServer& operator=(ZenComputeServer&&) = delete; + ZenComputeServer(ZenComputeServer&&) = delete; + +public: + ZenComputeServer(); + ~ZenComputeServer(); + + int Initialize(const ZenComputeServerConfig& ServerConfig, ZenServerState::ZenServerEntry* ServerEntry); + void Run(); + void Cleanup(); + +private: + HttpStatsService m_StatsService; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr m_CidStore; + std::unique_ptr m_ComputeService; + std::unique_ptr m_ApiService; + std::unique_ptr m_FunctionService; + + void InitializeState(const ZenComputeServerConfig& ServerConfig); + void InitializeServices(const ZenComputeServerConfig& ServerConfig); + void RegisterServices(const ZenComputeServerConfig& ServerConfig); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/computeservice.cpp b/src/zenserver/compute/computeservice.cpp new file mode 100644 index 000000000..2c0bc0ae9 --- /dev/null +++ b/src/zenserver/compute/computeservice.cpp @@ -0,0 +1,100 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +# include +ZEN_THIRD_PARTY_INCLUDES_END + +# include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct ResourceMetrics +{ + uint64_t DiskUsageBytes = 0; + uint64_t MemoryUsageBytes = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpComputeService::Impl +{ + Impl(const Impl&) = delete; + Impl& operator=(const Impl&) = delete; + + Impl(); + ~Impl(); + + void Initialize(std::filesystem::path BaseDir) { ZEN_UNUSED(BaseDir); } + + void Cleanup() {} + +private: +}; + +HttpComputeService::Impl::Impl() +{ +} + +HttpComputeService::Impl::~Impl() +{ +} + +/////////////////////////////////////////////////////////////////////////// + +HttpComputeService::HttpComputeService(std::filesystem::path BaseDir) : m_Impl(std::make_unique()) +{ + using namespace std::literals; + + m_Impl->Initialize(BaseDir); + + m_Router.RegisterRoute( + "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.BeginArray("modules"); + Obj.EndArray(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "stats", + [this](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); +} + +HttpComputeService::~HttpComputeService() +{ +} + +const char* +HttpComputeService::BaseUri() const +{ + return "/compute/"; +} + +void +HttpComputeService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/computeservice.h b/src/zenserver/compute/computeservice.h new file mode 100644 index 000000000..339200dd8 --- /dev/null +++ 
b/src/zenserver/compute/computeservice.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_COMPUTE_SERVICES +namespace zen { + +/** ZenServer Compute Service + * + * Manages a set of compute workers for use in UEFN content worker + * + */ +class HttpComputeService : public zen::HttpService +{ +public: + HttpComputeService(std::filesystem::path BaseDir); + ~HttpComputeService(); + + HttpComputeService(const HttpComputeService&) = delete; + HttpComputeService& operator=(const HttpComputeService&) = delete; + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + HttpRequestRouter m_Router; + + struct Impl; + + std::unique_ptr m_Impl; +}; + +} // namespace zen +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/frontend/html/compute.html b/src/zenserver/frontend/html/compute.html new file mode 100644 index 000000000..668189fe5 --- /dev/null +++ b/src/zenserver/frontend/html/compute.html @@ -0,0 +1,991 @@ + + + + + + Zen Compute Dashboard + + + + +
+
+
+

Zen Compute Dashboard

+
Last updated: Never
+
+
+
+ Checking... +
+
+ +
+ + +
Action Queue
+
+
+
Pending Actions
+
-
+
Waiting to be scheduled
+
+
+
Running Actions
+
-
+
Currently executing
+
+
+
Completed Actions
+
-
+
Results available
+
+
+ + +
+
Action Queue History
+
+ +
+
+ + +
Performance Metrics
+
+
Completion Rate
+
+
+
-
+
1 min rate
+
+
+
-
+
5 min rate
+
+
+
-
+
15 min rate
+
+
+
+
+ Total Retired + - +
+
+ Mean Rate + - +
+
+
+ + +
Workers
+
+
Worker Status
+
+ Registered Workers + - +
+ +
+ + +
Recent Actions
+
+
Action History
+
No actions recorded yet.
+ +
+ + +
System Resources
+
+
+
CPU Usage
+
-
+
Percent
+
+
+
+
+ +
+
+
+ Packages + - +
+
+ Physical Cores + - +
+
+ Logical Processors + - +
+
+
+
+
Memory
+
+ Used + - +
+
+ Total + - +
+
+
+
+
+
+
Disk
+
+ Used + - +
+
+ Total + - +
+
+
+
+
+
+
+ + + + diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 1a929b026..ee783d2a6 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -23,6 +23,9 @@ #include #include "diag/logging.h" + +#include "compute/computeserver.h" + #include "storage/storageconfig.h" #include "storage/zenstorageserver.h" @@ -61,11 +64,19 @@ namespace zen { #if ZEN_PLATFORM_WINDOWS -template +/** Windows Service wrapper for Zen servers + * + * This class wraps a Zen server main entry point (the Main template parameter) + * into a Windows Service by implementing the WindowsService interface. + * + * The Main type needs to implement the virtual functions from the ZenServerMain + * base class, which provides the actual server logic. + */ +template class ZenWindowsService : public WindowsService { public: - ZenWindowsService(typename T::Config& ServerOptions) : m_EntryPoint(ServerOptions) {} + ZenWindowsService(typename Main::Config& ServerOptions) : m_EntryPoint(ServerOptions) {} ZenWindowsService(const ZenWindowsService&) = delete; ZenWindowsService& operator=(const ZenWindowsService&) = delete; @@ -73,7 +84,7 @@ public: virtual int Run() override { return m_EntryPoint.Run(); } private: - T m_EntryPoint; + Main m_EntryPoint; }; #endif // ZEN_PLATFORM_WINDOWS @@ -84,6 +95,23 @@ private: namespace zen { +/** Application main entry point template + * + * This function handles common application startup tasks while allowing + * different server types to be plugged in via the Main template parameter. + * + * On Windows, this function also handles platform-specific service + * installation and uninstallation. + * + * The Main type needs to implement the virtual functions from the ZenServerMain + * base class, which provides the actual server logic. 
+ * + * The Main type is also expected to provide the following members: + * + * typedef Config -- Server configuration type, derived from ZenServerConfig + * typedef Configurator -- Server configuration handler type, implements ZenServerConfiguratorBase + * + */ template int AppMain(int argc, char* argv[]) @@ -241,7 +269,12 @@ main(int argc, char* argv[]) auto _ = zen::MakeGuard([] { // Allow some time for worker threads to unravel, in an effort - // to prevent shutdown races in TLS object destruction + // to prevent shutdown races in TLS object destruction, mainly due to + // threads which we don't directly control (Windows thread pool) and + // therefore can't join. + // + // This isn't a great solution, but for now it seems to help reduce + // shutdown crashes observed in some situations. WaitForThreads(1000); }); @@ -249,6 +282,7 @@ main(int argc, char* argv[]) { kHub, kStore, + kCompute, kTest } ServerMode = kStore; @@ -258,10 +292,14 @@ main(int argc, char* argv[]) { ServerMode = kHub; } - else if (argv[1] == "store"sv) + else if ((argv[1] == "store"sv) || (argv[1] == "storage"sv)) { ServerMode = kStore; } + else if (argv[1] == "compute"sv) + { + ServerMode = kCompute; + } else if (argv[1] == "test"sv) { ServerMode = kTest; @@ -280,6 +318,13 @@ main(int argc, char* argv[]) break; case kHub: return AppMain(argc, argv); + case kCompute: +#if ZEN_WITH_COMPUTE_SERVICES + return AppMain(argc, argv); +#else + fprintf(stderr, "compute services are not compiled in!\n"); + exit(5); +#endif default: case kStore: return AppMain(argc, argv); diff --git a/src/zenserver/storage/storageconfig.cpp b/src/zenserver/storage/storageconfig.cpp index 0f8ab1e98..089b6b572 100644 --- a/src/zenserver/storage/storageconfig.cpp +++ b/src/zenserver/storage/storageconfig.cpp @@ -797,6 +797,7 @@ ZenStorageServerCmdLineOptions::AddCacheOptions(cxxopts::Options& options, ZenSt cxxopts::value(ServerOptions.StructuredCacheConfig.MemMaxAgeSeconds)->default_value("86400"), ""); + 
options.add_option("compute", "", "lie-cpus", "Lie to upstream about CPU capabilities", cxxopts::value(ServerOptions.LieCpu), ""); options.add_option("cache", "", "cache-bucket-maxblocksize", diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h index d59d05cf6..b408b0c26 100644 --- a/src/zenserver/storage/storageconfig.h +++ b/src/zenserver/storage/storageconfig.h @@ -156,6 +156,7 @@ struct ZenStorageServerConfig : public ZenServerConfig ZenWorkspacesConfig WorksSpacesConfig; std::filesystem::path PluginsConfigFile; // Path to plugins config file bool ObjectStoreEnabled = false; + bool ComputeEnabled = true; std::string ScrubOptions; }; diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index 2b74395c3..ff854b72d 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -182,6 +182,13 @@ ZenStorageServer::RegisterServices() #endif // ZEN_WITH_VFS m_Http->RegisterService(*m_AdminService); + +#if ZEN_WITH_COMPUTE_SERVICES + if (m_HttpFunctionService) + { + m_Http->RegisterService(*m_HttpFunctionService); + } +#endif } void @@ -267,6 +274,16 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions m_BuildStoreService = std::make_unique(m_StatusService, m_StatsService, *m_BuildStore); } +#if ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeEnabled) + { + ZEN_OTEL_SPAN("InitializeComputeService"); + + m_HttpFunctionService = + std::make_unique(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); + } +#endif + #if ZEN_WITH_VFS m_VfsServiceImpl = std::make_unique(); m_VfsServiceImpl->AddService(Ref(m_ProjectStore)); @@ -805,6 +822,10 @@ ZenStorageServer::Cleanup() Flush(); +#if ZEN_WITH_COMPUTE_SERVICES + m_HttpFunctionService.reset(); +#endif + m_AdminService.reset(); m_VfsService.reset(); m_VfsServiceImpl.reset(); diff --git a/src/zenserver/storage/zenstorageserver.h 
b/src/zenserver/storage/zenstorageserver.h index 5ccb587d6..456447a2a 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,6 +24,10 @@ #include "vfs/vfsservice.h" #include "workspaces/httpworkspaces.h" +#if ZEN_WITH_COMPUTE_SERVICES +# include +#endif + namespace zen { class ZenStorageServer : public ZenServerBase @@ -34,11 +39,6 @@ public: ZenStorageServer(); ~ZenStorageServer(); - void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } - void SetTestMode(bool State) { m_TestMode = State; } - void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } - void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } - int Initialize(const ZenStorageServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry); void Run(); void Cleanup(); @@ -48,14 +48,9 @@ private: void InitializeStructuredCache(const ZenStorageServerConfig& ServerOptions); void Flush(); - bool m_IsDedicatedMode = false; - bool m_TestMode = false; - bool m_DebugOptionForcedCrash = false; - std::string m_StartupScrubOptions; - CbObject m_RootManifest; - std::filesystem::path m_DataRoot; - std::filesystem::path m_ContentRoot; - asio::steady_timer m_StateMarkerTimer{m_IoContext}; + std::string m_StartupScrubOptions; + CbObject m_RootManifest; + asio::steady_timer m_StateMarkerTimer{m_IoContext}; void EnqueueStateMarkerTimer(); void CheckStateMarker(); @@ -95,6 +90,11 @@ private: std::unique_ptr m_BuildStoreService; std::unique_ptr m_VfsService; std::unique_ptr m_AdminService; + std::unique_ptr m_ApiService; + +#if ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr m_HttpFunctionService; +#endif }; struct ZenStorageServerConfigurator; diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 6ee80dc62..9ab51beb2 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -2,7 +2,11 @@ target("zenserver") set_kind("binary") + 
if enable_unity then + add_rules("c++.unity_build", {batchsize = 4}) + end add_deps("zencore", + "zencompute", "zenhttp", "zennet", "zenremotestore", diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 7f9bf56a9..7bf6126df 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -145,6 +146,13 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: InitializeSecuritySettings(ServerOptions); + if (ServerOptions.LieCpu) + { + SetCpuCountForReporting(ServerOptions.LieCpu); + + ZEN_INFO("Reporting concurrency: {}", ServerOptions.LieCpu); + } + m_StatusService.RegisterHandler("status", *this); m_Http->RegisterService(m_StatusService); diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index efa46f361..5a8a079c0 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -43,6 +43,11 @@ public: void SetIsReadyFunc(std::function&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); } + void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } + void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } + void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } + void SetTestMode(bool State) { m_TestMode = State; } + protected: int Initialize(const ZenServerConfig& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry); void Finalize(); @@ -55,6 +60,10 @@ protected: bool m_UseSentry = false; bool m_IsPowerCycle = false; + bool m_IsDedicatedMode = false; + bool m_TestMode = false; + bool m_DebugOptionForcedCrash = false; + std::thread m_IoRunner; asio::io_context m_IoContext; void EnsureIoRunner(); @@ -72,6 +81,9 @@ protected: std::function m_IsReadyFunc; void OnReady(); + std::filesystem::path m_DataRoot; // Root directory for server state + std::filesystem::path m_ContentRoot; // Root directory for frontend content + Ref m_Http; std::unique_ptr 
m_HttpRequestFilter; @@ -114,7 +126,6 @@ protected: private: void InitializeSecuritySettings(const ZenServerConfig& ServerOptions); }; - class ZenServerMain { public: diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua index 97615e322..db3ff2e2d 100644 --- a/src/zentest-appstub/xmake.lua +++ b/src/zentest-appstub/xmake.lua @@ -5,6 +5,9 @@ target("zentest-appstub") set_group("tests") add_headerfiles("**.h") add_files("*.cpp") + add_deps("zencore") + add_packages("vcpkg::gsl-lite") -- this should ideally be propagated by the zencore dependency + add_packages("vcpkg::mimalloc") if is_os("linux") then add_syslinks("pthread") diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 24cf21e97..926580d96 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -1,33 +1,408 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#include +#include +#include +#include +#include +#include +#include + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include +#endif + +#include + #include +#include #include #include #include +#include +#include +#include #include -using namespace std::chrono_literals; +using namespace std::literals; +using namespace zen; + +#if !defined(_MSC_VER) +# define _strnicmp strncasecmp // TEMPORARY WORKAROUND - should not be using this +#endif + +// Some basic functions to implement some test "compute" functions + +std::string +Rot13Function(std::string_view InputString) +{ + std::string OutputString{InputString}; + + std::transform(OutputString.begin(), + OutputString.end(), + OutputString.begin(), + [](std::string::value_type c) -> std::string::value_type { + if (c >= 'a' && c <= 'z') + { + return 'a' + (c - 'a' + 13) % 26; + } + else if (c >= 'A' && c <= 'Z') + { + return 'A' + (c - 'A' + 13) % 26; + } + else + { + return c; + } + }); + + return OutputString; +} + +std::string +ReverseFunction(std::string_view InputString) +{ + 
std::string OutputString{InputString}; + std::reverse(OutputString.begin(), OutputString.end()); + return OutputString; +} + +std::string +IdentityFunction(std::string_view InputString) +{ + return std::string{InputString}; +} + +std::string +NullFunction(std::string_view) +{ + return {}; +} + +zen::CbObject +DescribeFunctions() +{ + CbObjectWriter Versions; + Versions << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv); + + Versions.BeginArray("Functions"sv); + Versions.BeginObject(); + Versions << "Name"sv + << "Null"sv; + Versions << "Version"sv << Guid::FromString("00000000-0000-0000-0000-000000000000"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Identity"sv; + Versions << "Version"sv << Guid::FromString("11111111-1111-1111-1111-111111111111"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Rot13"sv; + Versions << "Version"sv << Guid::FromString("13131313-1313-1313-1313-131313131313"sv); + Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Reverse"sv; + Versions << "Version"sv << Guid::FromString("31313131-3131-3131-3131-313131313131"sv); + Versions.EndObject(); + Versions.EndArray(); + + return Versions.Save(); +} + +struct ContentResolver +{ + std::filesystem::path InputsRoot; + + CompressedBuffer ResolveChunk(IoHash Hash, uint64_t ExpectedSize) + { + std::filesystem::path ChunkPath = InputsRoot / Hash.ToHexString(); + IoBuffer ChunkBuffer = IoBufferBuilder::MakeFromFile(ChunkPath); + + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer AsCompressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkBuffer), RawHash, RawSize); + + if (RawSize != ExpectedSize) + { + throw std::runtime_error( + fmt::format("chunk size mismatch - expected {}, got {} for '{}'", ExpectedSize, ChunkBuffer.Size(), ChunkPath)); + } + if (RawHash != Hash) + { + throw std::runtime_error(fmt::format("chunk hash mismatch - expected {}, got {} for 
'{}'", Hash, RawHash, ChunkPath)); + } + + return AsCompressed; + } +}; + +zen::CbPackage +ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) +{ + auto Apply = [&](auto Func) { + zen::CbPackage Result; + auto Source = Action["Inputs"sv].AsObjectView()["Source"sv].AsObjectView(); + + IoHash InputRawHash = Source["RawHash"sv].AsHash(); + uint64_t InputRawSize = Source["RawSize"sv].AsUInt64(); + + zen::CompressedBuffer InputData = ChunkResolver.ResolveChunk(InputRawHash, InputRawSize); + SharedBuffer Input = InputData.Decompress(); + + std::string Output = Func(std::string_view(static_cast(Input.GetData()), Input.GetSize())); + zen::CompressedBuffer OutputData = + zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Output), OodleCompressor::Selkie, OodleCompressionLevel::HyperFast4); + IoHash OutputRawHash = OutputData.DecodeRawHash(); + + CbAttachment OutputAttachment(std::move(OutputData), OutputRawHash); + + CbObjectWriter Cbo; + Cbo.BeginArray("Values"sv); + Cbo.BeginObject(); + Cbo << "Id" << Oid{1, 2, 3}; + Cbo.AddAttachment("RawHash", OutputAttachment); + Cbo << "RawSize" << Output.size(); + Cbo.EndObject(); + Cbo.EndArray(); + + Result.SetObject(Cbo.Save()); + Result.AddAttachment(std::move(OutputAttachment)); + return Result; + }; + + std::string_view Function = Action["Function"sv].AsString(); + + if (Function == "Rot13"sv) + { + return Apply(Rot13Function); + } + else if (Function == "Reverse"sv) + { + return Apply(ReverseFunction); + } + else if (Function == "Identity"sv) + { + return Apply(IdentityFunction); + } + else if (Function == "Null"sv) + { + return Apply(NullFunction); + } + else + { + return {}; + } +} + +/* This implements a minimal application to help testing of process launch-related + functionality + + It also mimics the DDC2 worker command line interface, so it may be used to + exercise compute infrastructure. 
+ */ int main(int argc, char* argv[]) { int ExitCode = 0; - for (int i = 0; i < argc; ++i) + try { - if (std::strncmp(argv[i], "-t=", 3) == 0) + std::filesystem::path BasePath = std::filesystem::current_path(); + std::filesystem::path InputPath = std::filesystem::current_path() / "Inputs"; + std::filesystem::path OutputPath = std::filesystem::current_path() / "Outputs"; + std::filesystem::path VersionPath = std::filesystem::current_path() / "Versions"; + std::vector ActionPaths; + + /* + GetSwitchValues(TEXT("-B="), ActionPathPatterns); + GetSwitchValues(TEXT("-Build="), ActionPathPatterns); + + GetSwitchValues(TEXT("-I="), InputDirectoryPaths); + GetSwitchValues(TEXT("-Input="), InputDirectoryPaths); + + GetSwitchValues(TEXT("-O="), OutputDirectoryPaths); + GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths); + + GetSwitchValues(TEXT("-V="), VersionPaths); + GetSwitchValues(TEXT("-Version="), VersionPaths); + */ + + auto SplitArg = [](const char* Arg) -> std::string_view { + std::string_view ArgView{Arg}; + if (auto SplitPos = ArgView.find_first_of('='); SplitPos != std::string_view::npos) + { + return ArgView.substr(SplitPos + 1); + } + else + { + return {}; + } + }; + + auto ParseIntArg = [](std::string_view Arg) -> int { + int Rv = 0; + const auto Result = std::from_chars(Arg.data(), Arg.data() + Arg.size(), Rv); + + if (Result.ec != std::errc{}) + { + throw std::invalid_argument(fmt::format("bad argument (not an integer): {}", Arg).c_str()); + } + + return Rv; + }; + + for (int i = 1; i < argc; ++i) + { + std::string_view Arg = argv[i]; + + if (Arg.compare(0, 1, "-")) + { + continue; + } + + if (std::strncmp(argv[i], "-t=", 3) == 0) + { + const int SleepTime = std::atoi(argv[i] + 3); + + printf("[zentest] sleeping for %ds...\n", SleepTime); + + std::this_thread::sleep_for(SleepTime * 1s); + } + else if (std::strncmp(argv[i], "-f=", 3) == 0) + { + // Force a "failure" process exit code to return to the invoker + + // This may throw for invalid arguments, 
which makes this useful for + // testing exception handling + std::string_view ErrorArg = SplitArg(argv[i]); + ExitCode = ParseIntArg(ErrorArg); + } + else if ((_strnicmp(argv[i], "-input=", 7) == 0) || (_strnicmp(argv[i], "-i=", 3) == 0)) + { + /* mimic DDC2 + + GetSwitchValues(TEXT("-I="), InputDirectoryPaths); + GetSwitchValues(TEXT("-Input="), InputDirectoryPaths); + */ + + std::string_view InputArg = SplitArg(argv[i]); + InputPath = InputArg; + } + else if ((_strnicmp(argv[i], "-output=", 8) == 0) || (_strnicmp(argv[i], "-o=", 3) == 0)) + { + /* mimic DDC2 handling of where files storing output chunk files are directed + + GetSwitchValues(TEXT("-O="), OutputDirectoryPaths); + GetSwitchValues(TEXT("-Output="), OutputDirectoryPaths); + */ + + std::string_view OutputArg = SplitArg(argv[i]); + OutputPath = OutputArg; + } + else if ((_strnicmp(argv[i], "-version=", 8) == 0) || (_strnicmp(argv[i], "-v=", 3) == 0)) + { + /* mimic DDC2 + + GetSwitchValues(TEXT("-V="), VersionPaths); + GetSwitchValues(TEXT("-Version="), VersionPaths); + */ + + std::string_view VersionArg = SplitArg(argv[i]); + VersionPath = VersionArg; + } + else if ((_strnicmp(argv[i], "-build=", 7) == 0) || (_strnicmp(argv[i], "-b=", 3) == 0)) + { + /* mimic DDC2 + + GetSwitchValues(TEXT("-B="), ActionPathPatterns); + GetSwitchValues(TEXT("-Build="), ActionPathPatterns); + */ + + std::string_view BuildActionArg = SplitArg(argv[i]); + std::filesystem::path ActionPath{BuildActionArg}; + ActionPaths.push_back(ActionPath); + + ExitCode = 0; + } + } + + // Emit version information + + if (!VersionPath.empty()) { - const int SleepTime = std::atoi(argv[i] + 3); + CbObjectWriter Version; + + Version << "BuildSystemVersion" << Guid::FromString("17fe280d-ccd8-4be8-a9d1-89c944a70969"sv); + + Version.BeginArray("Functions"); + + Version.BeginObject(); + Version << "Name" + << "Rot13" + << "Version" << Guid::FromString("13131313-1313-1313-1313-131313131313"sv); + Version.EndObject(); - printf("[zentest] sleeping 
for %ds...\n", SleepTime); + Version.BeginObject(); + Version << "Name" + << "Reverse" + << "Version" << Guid::FromString("98765432-1000-0000-0000-000000000000"sv); + Version.EndObject(); - std::this_thread::sleep_for(SleepTime * 1s); + Version.BeginObject(); + Version << "Name" + << "Identity" + << "Version" << Guid::FromString("11111111-1111-1111-1111-111111111111"sv); + Version.EndObject(); + + Version.BeginObject(); + Version << "Name" + << "Null" + << "Version" << Guid::FromString("00000000-0000-0000-0000-000000000000"sv); + Version.EndObject(); + + Version.EndArray(); + CbObject VersionObject = Version.Save(); + + BinaryWriter Writer; + zen::SaveCompactBinary(Writer, VersionObject); + zen::WriteFile(VersionPath, IoBufferBuilder::MakeFromMemory(Writer.GetView())); } - else if (std::strncmp(argv[i], "-f=", 3) == 0) + + // Evaluate actions + + ContentResolver Resolver; + Resolver.InputsRoot = InputPath; + + for (std::filesystem::path ActionPath : ActionPaths) { - ExitCode = std::atoi(argv[i] + 3); + IoBuffer ActionDescBuffer = ReadFile(ActionPath).Flatten(); + CbObject ActionDesc = LoadCompactBinaryObject(ActionDescBuffer); + CbPackage Result = ExecuteFunction(ActionDesc, Resolver); + CbObject ResultObject = Result.GetObject(); + + BinaryWriter Writer; + zen::SaveCompactBinary(Writer, ResultObject); + zen::WriteFile(ActionPath.replace_extension(".output"), IoBufferBuilder::MakeFromMemory(Writer.GetView())); + + // Also marshal outputs + + for (const auto& Attachment : Result.GetAttachments()) + { + const CompositeBuffer& AttachmentBuffer = Attachment.AsCompressedBinary().GetCompressed(); + zen::WriteFile(OutputPath / Attachment.GetHash().ToHexString(), AttachmentBuffer.Flatten().AsIoBuffer()); + } } } + catch (std::exception& Ex) + { + printf("[zentest] exception caught in main: '%s'\n", Ex.what()); + + ExitCode = 99; + } printf("[zentest] exiting with exit code: %d\n", ExitCode); -- cgit v1.2.3 From a948ff9570a5a9d8ec424639cba6f973247a0372 Mon Sep 17 00:00:00 
2001 From: zousar Date: Wed, 18 Feb 2026 23:15:09 -0700 Subject: entry.js handles missing/native items more gracefully --- src/zenserver/frontend/html/pages/cookartifacts.js | 20 ++++++++++++++++---- src/zenserver/frontend/html/pages/entry.js | 16 ++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html/pages/cookartifacts.js b/src/zenserver/frontend/html/pages/cookartifacts.js index 6c36c7f32..f2ae094b9 100644 --- a/src/zenserver/frontend/html/pages/cookartifacts.js +++ b/src/zenserver/frontend/html/pages/cookartifacts.js @@ -261,13 +261,25 @@ export class Page extends ZenPage { const row = runtime_table.add_row(dep); // Make Path clickable to navigate to entry - row.get_cell(0).text(dep).on_click((opkey) => { - window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; - }, dep); + if (this._should_link_dependency(dep)) + { + row.get_cell(0).text(dep).on_click((opkey) => { + window.location = `?page=entry&project=${project}&oplog=${oplog}&opkey=${opkey.toLowerCase()}`; + }, dep); + } } } } + _should_link_dependency(name) + { + // Exclude dependencies starting with /Script/ (code-defined entries) - case insensitive + if (name && name.toLowerCase().startsWith("/script/")) + return false; + + return true; + } + _build_dependency_section(parent_section, title, dependencies, stored_key) { const section = parent_section.add_section(title); @@ -338,7 +350,7 @@ export class Page extends ZenPage const row = table.add_row(...row_values); // Make Name field clickable for Package, TransitiveBuild, and RedirectionTarget - if (should_link && name_col_index >= 0 && dep.Name) + if (should_link && name_col_index >= 0 && dep.Name && this._should_link_dependency(dep.Name)) { const project = this.get_param("project"); const oplog = this.get_param("oplog"); diff --git a/src/zenserver/frontend/html/pages/entry.js b/src/zenserver/frontend/html/pages/entry.js index 
c4746bf52..f418b17ba 100644 --- a/src/zenserver/frontend/html/pages/entry.js +++ b/src/zenserver/frontend/html/pages/entry.js @@ -181,6 +181,22 @@ export class Page extends ZenPage async _build_page() { var entry = await this._entry; + + // Check if entry exists + if (!entry || entry.as_object().find("entry") == null) + { + const opkey = this.get_param("opkey"); + var section = this.add_section("Entry Not Found"); + section.tag("p").text(`The entry "${opkey}" is not present in this dataset.`); + section.tag("p").text("This could mean:"); + const list = section.tag("ul"); + list.tag("li").text("The entry is for an instance defined in code"); + list.tag("li").text("The entry has not been added to the oplog yet"); + list.tag("li").text("The entry key is misspelled"); + list.tag("li").text("The entry was removed or never existed"); + return; + } + entry = entry.as_object().find("entry").as_object(); const name = entry.find("key").as_value(); -- cgit v1.2.3 From a1f158e14761767f83469e9e522cf542f9ad91e2 Mon Sep 17 00:00:00 2001 From: zousar Date: Wed, 18 Feb 2026 23:17:56 -0700 Subject: updatefrontend --- src/zenserver/frontend/html.zip | Bin 182962 -> 183939 bytes 1 file changed, 0 insertions(+), 0 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index 67752fbc2..d70a5a62b 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ -- cgit v1.2.3 From cae12611580c6c28b1362fa28181b8f388516a47 Mon Sep 17 00:00:00 2001 From: zousar Date: Thu, 19 Feb 2026 13:55:44 -0700 Subject: icon and header logo changes --- src/UnrealEngine.ico | Bin 65288 -> 0 bytes src/zen.ico | Bin 0 -> 12957 bytes src/zen/zen.rc | 2 +- src/zenserver/frontend/html/epicgames.ico | Bin 0 -> 65288 bytes src/zenserver/frontend/html/favicon.ico | Bin 65288 -> 12957 bytes src/zenserver/frontend/html/pages/page.js | 24 ++++++------------------ src/zenserver/frontend/html/zen.css | 16 +++++++++++++++- 
src/zenserver/zenserver.rc | 2 +- 8 files changed, 23 insertions(+), 21 deletions(-) delete mode 100644 src/UnrealEngine.ico create mode 100644 src/zen.ico create mode 100644 src/zenserver/frontend/html/epicgames.ico (limited to 'src') diff --git a/src/UnrealEngine.ico b/src/UnrealEngine.ico deleted file mode 100644 index 1cfa301a2..000000000 Binary files a/src/UnrealEngine.ico and /dev/null differ diff --git a/src/zen.ico b/src/zen.ico new file mode 100644 index 000000000..f7fb251b5 Binary files /dev/null and b/src/zen.ico differ diff --git a/src/zen/zen.rc b/src/zen/zen.rc index 661d75011..0617681a7 100644 --- a/src/zen/zen.rc +++ b/src/zen/zen.rc @@ -7,7 +7,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US #pragma code_page(1252) -101 ICON "..\\UnrealEngine.ico" +101 ICON "..\\zen.ico" VS_VERSION_INFO VERSIONINFO FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 diff --git a/src/zenserver/frontend/html/epicgames.ico b/src/zenserver/frontend/html/epicgames.ico new file mode 100644 index 000000000..1cfa301a2 Binary files /dev/null and b/src/zenserver/frontend/html/epicgames.ico differ diff --git a/src/zenserver/frontend/html/favicon.ico b/src/zenserver/frontend/html/favicon.ico index 1cfa301a2..f7fb251b5 100644 Binary files a/src/zenserver/frontend/html/favicon.ico and b/src/zenserver/frontend/html/favicon.ico differ diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index 3ec0248cb..3c2d3619a 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -70,24 +70,12 @@ export class ZenPage extends PageBase { var root = parent.tag().id("branding"); - const zen_store = root.tag("pre").id("logo").text( - "_________ _______ __\n" + - "\\____ /___ ___ / ___// |__ ___ ______ ____\n" + - " / __/ __ \\ / \\ \\___ \\\\_ __// \\\\_ \\/ __ \\\n" + - " / \\ __// | \\/ \\| | ( - )| |\\/\\ __/\n" + - "/______/\\___/\\__|__/\\______/|__| \\___/ |__| \\___|" - 
); - zen_store.tag().id("go_home").on_click(() => window.location.search = ""); - - root.tag("img").attr("src", "favicon.ico").id("ue_logo"); - - /* - _________ _______ __ - \____ /___ ___ / ___// |__ ___ ______ ____ - / __/ __ \ / \ \___ \\_ __// \\_ \/ __ \ - / \ __// | \/ \| | ( - )| |\/\ __/ - /______/\___/\__|__/\______/|__| \___/ |__| \___| - */ + const logo_container = root.tag("div").id("logo"); + logo_container.tag("img").attr("src", "favicon.ico").id("zen_icon"); + logo_container.tag("span").id("zen_text").text("zenserver"); + logo_container.tag().id("go_home").on_click(() => window.location.search = ""); + + root.tag("img").attr("src", "epicgames.ico").id("epic_logo"); } set_title(...args) diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index 34c265610..702bf9aa6 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -365,6 +365,20 @@ a { margin: auto; user-select: none; position: relative; + display: flex; + align-items: center; + gap: 0.8em; + + #zen_icon { + width: 3em; + height: 3em; + } + + #zen_text { + font-size: 2em; + font-weight: bold; + letter-spacing: 0.05em; + } #go_home { width: 100%; @@ -379,7 +393,7 @@ a { filter: drop-shadow(0 0.15em 0.1em var(--theme_p2)); } - #ue_logo { + #epic_logo { position: absolute; top: 1em; right: 0; diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc index e0003ea8f..f353bd9cc 100644 --- a/src/zenserver/zenserver.rc +++ b/src/zenserver/zenserver.rc @@ -28,7 +28,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US // Icon with lowest ID value placed first to ensure application icon // remains consistent on all systems. 
-IDI_ICON1 ICON "..\\UnrealEngine.ico" +IDI_ICON1 ICON "..\\zen.ico" #endif // English (United States) resources ///////////////////////////////////////////////////////////////////////////// -- cgit v1.2.3 From ee26e5af2ced0987fbdf666dc6bce7c2074e925f Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 20 Feb 2026 09:05:23 +0100 Subject: GC - fix handling of attachment ranges, http access token expiration, lock file retry logic (#766) * GC - fix handling of attachment ranges * fix trace/log strings * fix HTTP access token expiration time logic * added missing lock retry in zenserver startup --- src/zenhttp/httpclientauth.cpp | 2 +- src/zenhttp/servers/httpparser.cpp | 9 ++++++--- src/zenserver/compute/computeserver.cpp | 6 +++--- src/zenserver/hub/zenhubserver.cpp | 2 +- src/zenserver/zenserver.cpp | 2 ++ src/zenstore/gc.cpp | 7 ++++--- src/zenstore/include/zenstore/gc.h | 2 +- 7 files changed, 18 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index 72df12d02..02e1b57e2 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -170,7 +170,7 @@ namespace zen { namespace httpclientauth { time_t UTCTime = timegm(&Time); HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); - ExpireTime += std::chrono::microseconds(Millisecond); + ExpireTime += std::chrono::milliseconds(Millisecond); return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index be5befcd2..f0485aa25 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -226,6 +226,8 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) NormalizedUrl.append(Url, UrlIndex); } + // NOTE: this check is redundant given the enclosing if, + // need to verify the intent of this code if 
(!LastCharWasSeparator) { NormalizedUrl.push_back('/'); @@ -310,6 +312,7 @@ HttpRequestParser::OnHeadersComplete() if (ContentLength) { + // TODO: should sanity-check content length here m_BodyBuffer = IoBuffer(ContentLength); } @@ -329,9 +332,9 @@ HttpRequestParser::OnHeadersComplete() int HttpRequestParser::OnBody(const char* Data, size_t Bytes) { - if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size()) { - ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); return 1; } @@ -342,7 +345,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { if (m_BodyPosition != m_BodyBuffer.Size()) { - ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); return 1; } } diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 173f56386..0f9ef0287 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -82,7 +82,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ ZEN_TRACE_CPU("ZenComputeServer::Initialize"); ZEN_MEMSCOPE(GetZenserverTag()); - ZEN_INFO(ZEN_APP_NAME " initializing in HUB server mode"); + ZEN_INFO(ZEN_APP_NAME " initializing in COMPUTE server mode"); const int EffectiveBasePort = ZenServerBase::Initialize(ServerConfig, ServerEntry); if (EffectiveBasePort < 0) @@ -91,7 +91,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ } // This is a workaround to make sure we can have automated tests. Without - // this the ranges for different child zen hub processes could overlap with + // this the ranges for different child zen compute processes could overlap with // the main test range. 
ZenServerEnvironment::SetBaseChildId(1000); @@ -109,7 +109,7 @@ ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ void ZenComputeServer::Cleanup() { - ZEN_TRACE_CPU("ZenStorageServer::Cleanup"); + ZEN_TRACE_CPU("ZenComputeServer::Cleanup"); ZEN_INFO(ZEN_APP_NAME " cleaning up"); try { diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 7a4ba951d..d0a0db417 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -105,7 +105,7 @@ ZenHubServer::Initialize(const ZenHubServerConfig& ServerConfig, ZenServerState: void ZenHubServer::Cleanup() { - ZEN_TRACE_CPU("ZenStorageServer::Cleanup"); + ZEN_TRACE_CPU("ZenHubServer::Cleanup"); ZEN_INFO(ZEN_APP_NAME " cleaning up"); try { diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 7bf6126df..5fd35d9b4 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -617,6 +617,8 @@ ZenServerMain::Run() { ZEN_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); Sleep(500); + + m_LockFile.Create(LockFilePath, MakeLockData(false), Ec); if (Ec) { ZEN_WARN(ZEN_APP_NAME " exiting, unable to grab lock at '{}' (reason: '{}')", LockFilePath, Ec.message()); diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp index 14caa5abf..c3bdc59f0 100644 --- a/src/zenstore/gc.cpp +++ b/src/zenstore/gc.cpp @@ -1494,7 +1494,8 @@ GcManager::CollectGarbage(const GcSettings& Settings) GcReferenceValidatorStats& Stats = Result.ReferenceValidatorStats[It.second].second; try { - // Go through all the ReferenceCheckers to see if the list of Cids the collector selected are referenced or + // Go through all the ReferenceCheckers to see if the list of Cids the collector selected + // are referenced or not SCOPED_TIMER(Stats.ElapsedMS = std::chrono::milliseconds(Timer.GetElapsedTimeMs());); ReferenceValidator->Validate(Ctx, Stats); } @@ -1952,7 +1953,7 @@ 
GcScheduler::AppendGCLog(std::string_view Id, GcClock::TimePoint StartTime, cons Writer << "SingleThread"sv << Settings.SingleThread; Writer << "CompactBlockUsageThresholdPercent"sv << Settings.CompactBlockUsageThresholdPercent; Writer << "AttachmentRangeMin"sv << Settings.AttachmentRangeMin; - Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMin; + Writer << "AttachmentRangeMax"sv << Settings.AttachmentRangeMax; Writer << "ForceStoreCacheAttachmentMetaData"sv << Settings.StoreCacheAttachmentMetaData; Writer << "ForceStoreProjectAttachmentMetaData"sv << Settings.StoreProjectAttachmentMetaData; Writer << "EnableValidation"sv << Settings.EnableValidation; @@ -2893,7 +2894,7 @@ GcScheduler::CollectGarbage(const GcClock::TimePoint& CacheExpireTime, { m_LastFullGCV2Result = Result; m_LastFullAttachmentRangeMin = AttachmentRangeMin; - m_LastFullAttachmentRangeMin = AttachmentRangeMax; + m_LastFullAttachmentRangeMax = AttachmentRangeMax; } Diff.DiskSize = Result.CompactStoresStatSum.RemovedDisk; Diff.MemorySize = Result.ReferencerStatSum.RemoveExpiredDataStats.FreedMemory; diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h index 734d2e5a7..794f50d96 100644 --- a/src/zenstore/include/zenstore/gc.h +++ b/src/zenstore/include/zenstore/gc.h @@ -238,7 +238,7 @@ bool FilterReferences(GcCtx& Ctx, std::string_view Context, std::vector Date: Fri, 20 Feb 2026 09:07:00 +0100 Subject: fix MakeSafeAbsolutePathInPlace mis-spelling (#765) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (was MakeSafeAbsolutePathÍnPlace - note accent) Also fixed misleading comments on multiple functions in filesystem.h --- src/zen/authutils.cpp | 2 +- src/zen/cmds/builds_cmd.cpp | 44 ++++++++++++++++---------------- src/zen/cmds/print_cmd.cpp | 4 +-- src/zen/cmds/projectstore_cmd.cpp | 2 +- src/zen/cmds/wipe_cmd.cpp | 2 +- src/zen/cmds/workspaces_cmd.cpp | 2 +- src/zencore/filesystem.cpp | 4 +-- 
src/zencore/include/zencore/filesystem.h | 38 +++++++++++++-------------- 8 files changed, 49 insertions(+), 49 deletions(-) (limited to 'src') diff --git a/src/zen/authutils.cpp b/src/zen/authutils.cpp index 31db82efd..16427acf5 100644 --- a/src/zen/authutils.cpp +++ b/src/zen/authutils.cpp @@ -233,7 +233,7 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, } else if (!m_AccessTokenPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_AccessTokenPath); + MakeSafeAbsolutePathInPlace(m_AccessTokenPath); std::string ResolvedAccessToken = ReadAccessTokenFromJsonFile(m_AccessTokenPath); if (!ResolvedAccessToken.empty()) { diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 59b209384..8dfe1093f 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -2680,7 +2680,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_SystemRootDir = PickDefaultSystemRootDirectory(); } - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); }; ParseSystemOptions(); @@ -2729,7 +2729,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--host', '--url', '--override-host' or '--storage-path' is required", SubOption->help()); } - MakeSafeAbsolutePathÍnPlace(m_StoragePath); + MakeSafeAbsolutePathInPlace(m_StoragePath); }; auto ParseOutputOptions = [&]() { @@ -2947,7 +2947,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--local-path' is required", SubOption->help()); } - MakeSafeAbsolutePathÍnPlace(m_Path); + MakeSafeAbsolutePathInPlace(m_Path); }; auto ParseFileFilters = [&](std::vector& OutIncludeWildcards, std::vector& OutExcludeWildcards) { @@ -3004,7 +3004,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { throw OptionParseException("'--compare-path' is required", SubOption->help()); } - 
MakeSafeAbsolutePathÍnPlace(m_DiffPath); + MakeSafeAbsolutePathInPlace(m_DiffPath); }; auto ParseBlobHash = [&]() -> IoHash { @@ -3105,7 +3105,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!m_BuildMetadataPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_BuildMetadataPath); + MakeSafeAbsolutePathInPlace(m_BuildMetadataPath); IoBuffer MetaDataJson = ReadFile(m_BuildMetadataPath).Flatten(); std::string_view Json(reinterpret_cast(MetaDataJson.GetData()), MetaDataJson.GetSize()); std::string JsonError; @@ -3202,8 +3202,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (SubOption == &m_ListOptions) { - MakeSafeAbsolutePathÍnPlace(m_ListQueryPath); - MakeSafeAbsolutePathÍnPlace(m_ListResultPath); + MakeSafeAbsolutePathInPlace(m_ListQueryPath); + MakeSafeAbsolutePathInPlace(m_ListResultPath); if (!m_ListResultPath.empty()) { @@ -3255,7 +3255,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), m_ZenFolderPath); }); @@ -3294,7 +3294,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (SubOption == &m_ListBlocksOptions) { - MakeSafeAbsolutePathÍnPlace(m_ListResultPath); + MakeSafeAbsolutePathInPlace(m_ListResultPath); if (!m_ListResultPath.empty()) { @@ -3316,7 +3316,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this]() { CleanAndRemoveDirectory(GetSmallWorkerPool(EWorkloadType::Burst), 
m_ZenFolderPath); }); @@ -3387,8 +3387,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); - MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3532,7 +3532,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -3632,7 +3632,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); BuildStorageBase::Statistics StorageStats; BuildStorageCache::Statistics StorageCacheStats; @@ -3652,7 +3652,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::unique_ptr StructuredOutput; if (!m_LsResultPath.empty()) { - MakeSafeAbsolutePathÍnPlace(m_LsResultPath); + MakeSafeAbsolutePathInPlace(m_LsResultPath); StructuredOutput = std::make_unique(); } @@ -3696,7 +3696,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ParsePath(); ParseDiffPath(); - MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); std::vector ExcludeFolders = DefaultExcludeFolders; std::vector ExcludeExtensions = DefaultExcludeExtensions; @@ -3745,7 +3745,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = 
std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3828,7 +3828,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3883,7 +3883,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = std::filesystem::current_path() / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); CreateDirectories(m_ZenFolderPath); auto _ = MakeGuard([this, &Workers]() { CleanAndRemoveDirectory(Workers.GetIOWorkerPool(), m_ZenFolderPath); }); @@ -3933,7 +3933,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); EPartialBlockRequestMode PartialBlockRequestMode = ParseAllowPartialBlockRequests(); @@ -4083,8 +4083,8 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { m_ZenFolderPath = m_Path / ZenFolderName; } - MakeSafeAbsolutePathÍnPlace(m_ZenFolderPath); - MakeSafeAbsolutePathÍnPlace(m_ChunkingCachePath); + MakeSafeAbsolutePathInPlace(m_ZenFolderPath); + MakeSafeAbsolutePathInPlace(m_ChunkingCachePath); StorageInstance Storage = CreateBuildStorage(StorageStats, StorageCacheStats, diff --git a/src/zen/cmds/print_cmd.cpp b/src/zen/cmds/print_cmd.cpp index 030cc8b66..c6b250fdf 100644 --- 
a/src/zen/cmds/print_cmd.cpp +++ b/src/zen/cmds/print_cmd.cpp @@ -84,7 +84,7 @@ PrintCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } else { - MakeSafeAbsolutePathÍnPlace(m_Filename); + MakeSafeAbsolutePathInPlace(m_Filename); Fc = ReadFile(m_Filename); } @@ -244,7 +244,7 @@ PrintPackageCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar if (m_Filename.empty()) throw OptionParseException("'--source' is required", m_Options.help()); - MakeSafeAbsolutePathÍnPlace(m_Filename); + MakeSafeAbsolutePathInPlace(m_Filename); FileContents Fc = ReadFile(m_Filename); IoBuffer Data = Fc.Flatten(); CbPackage Package; diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index 4885fd363..4de6ad25c 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -2430,7 +2430,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a { m_SystemRootDir = PickDefaultSystemRootDirectory(); } - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); }; ParseSystemOptions(); diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp index adf0e61f0..a5029e1c5 100644 --- a/src/zen/cmds/wipe_cmd.cpp +++ b/src/zen/cmds/wipe_cmd.cpp @@ -549,7 +549,7 @@ WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ProgressMode = (IsVerbose || m_PlainProgress) ? 
ProgressBar::Mode::Plain : ProgressBar::Mode::Pretty; BoostWorkerThreads = m_BoostWorkerThreads; - MakeSafeAbsolutePathÍnPlace(m_Directory); + MakeSafeAbsolutePathInPlace(m_Directory); if (!IsDir(m_Directory)) { diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 6e6f5d863..2661ac9da 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -398,7 +398,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** } else { - MakeSafeAbsolutePathÍnPlace(m_SystemRootDir); + MakeSafeAbsolutePathInPlace(m_SystemRootDir); } std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 92a065707..1a4ee4b9b 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3069,7 +3069,7 @@ SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly) } void -MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path) +MakeSafeAbsolutePathInPlace(std::filesystem::path& Path) { if (!Path.empty()) { @@ -3091,7 +3091,7 @@ std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path) { std::filesystem::path Tmp(Path); - MakeSafeAbsolutePathÍnPlace(Tmp); + MakeSafeAbsolutePathInPlace(Tmp); return Tmp; } diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index f28863679..16e2b59f8 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -64,80 +64,80 @@ std::filesystem::path PathFromHandle(void* NativeHandle, std::error_code& Ec); */ std::filesystem::path CanonicalPath(std::filesystem::path InPath, std::error_code& Ec); -/** Query file size +/** Check if a path exists and is a regular file (throws) */ bool IsFile(const std::filesystem::path& Path); -/** Query file size +/** Check if a path exists and is a regular file (does not throw) */ bool IsFile(const std::filesystem::path& Path, 
std::error_code& Ec); -/** Query file size +/** Check if a path exists and is a directory (throws) */ bool IsDir(const std::filesystem::path& Path); -/** Query file size +/** Check if a path exists and is a directory (does not throw) */ bool IsDir(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Delete file at path, if it exists (throws) */ bool RemoveFile(const std::filesystem::path& Path); -/** Query file size +/** Delete file at path, if it exists (does not throw) */ bool RemoveFile(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Delete directory at path, if it exists (throws) */ bool RemoveDir(const std::filesystem::path& Path); -/** Query file size +/** Delete directory at path, if it exists (does not throw) */ bool RemoveDir(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size +/** Query file size (throws) */ uint64_t FileSizeFromPath(const std::filesystem::path& Path); -/** Query file size +/** Query file size (does not throw) */ uint64_t FileSizeFromPath(const std::filesystem::path& Path, std::error_code& Ec); -/** Query file size from native file handle +/** Query file size from native file handle (throws) */ uint64_t FileSizeFromHandle(void* NativeHandle); -/** Query file size from native file handle +/** Query file size from native file handle (does not throw) */ uint64_t FileSizeFromHandle(void* NativeHandle, std::error_code& Ec); /** Get a native time tick of last modification time */ -uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec); +uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename); /** Get a native time tick of last modification time */ -uint64_t GetModificationTickFromPath(const std::filesystem::path& Filename); +uint64_t GetModificationTickFromHandle(void* NativeHandle, std::error_code& Ec); bool TryGetFileProperties(const std::filesystem::path& Path, uint64_t& OutSize, uint64_t& 
OutModificationTick, uint32_t& OutNativeModeOrAttributes); -/** Move a file, if the files are not on the same drive the function will fail +/** Move/rename a file, if the files are not on the same drive the function will fail (throws) */ void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); -/** Move a file, if the files are not on the same drive the function will fail +/** Move/rename a file, if the files are not on the same drive the function will fail */ void RenameFile(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); -/** Move a directory, if the files are not on the same drive the function will fail +/** Move/rename a directory, if the files are not on the same drive the function will fail (throws) */ void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath); -/** Move a directory, if the files are not on the same drive the function will fail +/** Move/rename a directory, if the files are not on the same drive the function will fail */ void RenameDirectory(const std::filesystem::path& SourcePath, const std::filesystem::path& TargetPath, std::error_code& Ec); @@ -421,7 +421,7 @@ uint32_t MakeFileModeReadOnly(uint32_t FileMode, bool ReadOnly); bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly, std::error_code& Ec); bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly); -void MakeSafeAbsolutePathÍnPlace(std::filesystem::path& Path); +void MakeSafeAbsolutePathInPlace(std::filesystem::path& Path); [[nodiscard]] std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path); class SharedMemory -- cgit v1.2.3 From 17898ec8a7ce42c0da27ac50c5c65aeb447c6374 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 20 Feb 2026 10:29:42 +0100 Subject: fix plain progress bar (#768) * fix plain progress not updating current state --- src/zen/progressbar.cpp | 1 + 1 file 
changed, 1 insertion(+) (limited to 'src') diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 83606df67..1ee1d1e71 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -245,6 +245,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) const std::string Details = (!NewState.Details.empty()) ? fmt::format(": {}", NewState.Details) : ""; const std::string Output = fmt::format("{} {}% ({}){}\n", Task, PercentDone, NiceTimeSpanMs(ElapsedTimeMS), Details); OutputToConsoleRaw(Output); + m_State = NewState; } else if (m_Mode == Mode::Pretty) { -- cgit v1.2.3 From 80bc5a53fe9077bc20d287b912f6476db233110c Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 20 Feb 2026 10:31:31 +0100 Subject: fix builds download indexing timer (#769) * fix build download indexing timer log --- src/zen/cmds/builds_cmd.cpp | 7 +++++++ src/zenremotestore/builds/buildstorageoperations.cpp | 7 ------- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 8dfe1093f..849259013 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -1467,9 +1467,16 @@ namespace { ZEN_CONSOLE("Downloading build {}, parts:{} to '{}' ({})", BuildId, BuildPartString.ToView(), Path, NiceBytes(RawSize)); } + Stopwatch IndexTimer; + const ChunkedContentLookup LocalLookup = BuildChunkedContentLookup(LocalState.State.ChunkedContent); const ChunkedContentLookup RemoteLookup = BuildChunkedContentLookup(RemoteContent); + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); + } + ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Download, TaskSteps::StepCount); BuildsOperationUpdateFolder Updater( diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index ade431393..72e06767a 100644 --- 
a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -579,13 +579,6 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) CreateDirectories(m_TempDownloadFolderPath); CreateDirectories(m_TempBlockFolderPath); - Stopwatch IndexTimer; - - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Indexed local and remote content in {}", NiceTimeSpanMs(IndexTimer.GetElapsedTimeMs())); - } - Stopwatch CacheMappingTimer; std::vector> SequenceIndexChunksLeftToWriteCounters(m_RemoteContent.ChunkedContent.SequenceRawHashes.size()); -- cgit v1.2.3 From da4826d560a66b8a5f09158a93c83caa12348c7b Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 20 Feb 2026 10:32:32 +0100 Subject: move partial chunk block anailsys to chunkblock.h/cpp (#767) --- .../builds/buildstorageoperations.cpp | 882 ++++++--------------- src/zenremotestore/chunking/chunkblock.cpp | 540 ++++++++++++- .../zenremotestore/builds/buildstorageoperations.h | 50 +- .../include/zenremotestore/chunking/chunkblock.h | 94 +++ 4 files changed, 887 insertions(+), 679 deletions(-) (limited to 'src') diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 72e06767a..4f1b07c37 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -899,343 +899,213 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) CheckRequiredDiskSpace(RemotePathToRemoteIndex); + BlobsExistsResult ExistsResult; { - ZEN_TRACE_CPU("WriteChunks"); - - m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); + ChunkBlockAnalyser BlockAnalyser(m_LogOutput, + m_BlockDescriptions, + ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, .IsVerbose = m_Options.IsVerbose}); - Stopwatch WriteTimer; + std::vector NeededBlocks = 
BlockAnalyser.GetNeeded( + m_RemoteLookup.ChunkHashToChunkIndex, + [&](uint32_t RemoteChunkIndex) -> bool { return RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex]; }); - FilteredRate FilteredDownloadedBytesPerSecond; - FilteredRate FilteredWrittenBytesPerSecond; - - std::unique_ptr WriteProgressBarPtr( - m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? "Downloading" : "Writing")); - OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr); - ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); - - struct LooseChunkHashWorkData - { - std::vector ChunkTargetPtrs; - uint32_t RemoteChunkIndex = (uint32_t)-1; - }; - - std::vector LooseChunkHashWorks; - TotalPartWriteCount += CopyChunkDatas.size(); - TotalPartWriteCount += ScavengedSequenceCopyOperations.size(); + std::vector FetchBlockIndexes; + std::vector CachedChunkBlockIndexes; - for (const IoHash ChunkHash : m_LooseChunkHashes) { - auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); - ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); - const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) + ZEN_TRACE_CPU("BlockCacheFileExists"); + for (const ChunkBlockAnalyser::NeededBlock& NeededBlock : NeededBlocks) { - if (m_Options.IsVerbose) + if (m_Options.PrimeCacheOnly) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash); - } - continue; - } - bool NeedsCopy = true; - if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) - { - std::vector ChunkTargetPtrs = - GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - - if (ChunkTargetPtrs.empty()) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Skipping chunk {} due to cache reuse", ChunkHash); - } + 
FetchBlockIndexes.push_back(NeededBlock.BlockIndex); } else { - TotalRequestCount++; - TotalPartWriteCount++; - LooseChunkHashWorks.push_back( - LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); - } - } - } - - uint32_t BlockCount = gsl::narrow(m_BlockDescriptions.size()); - - std::vector ChunkIsPickedUpByBlock(m_RemoteContent.ChunkedContent.ChunkHashes.size(), false); - auto GetNeededChunkBlockIndexes = [this, &RemoteChunkIndexNeedsCopyFromSourceFlags, &ChunkIsPickedUpByBlock]( - const ChunkBlockDescription& BlockDescription) { - ZEN_TRACE_CPU("GetNeededChunkBlockIndexes"); - std::vector NeededBlockChunkIndexes; - for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) - { - const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; - if (auto It = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); It != m_RemoteLookup.ChunkHashToChunkIndex.end()) - { - const uint32_t RemoteChunkIndex = It->second; - if (!ChunkIsPickedUpByBlock[RemoteChunkIndex]) + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; + bool UsingCachedBlock = false; + if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) { - if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex]) + TotalPartWriteCount++; + + std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); + if (IsFile(BlockPath)) { - ChunkIsPickedUpByBlock[RemoteChunkIndex] = true; - NeededBlockChunkIndexes.push_back(ChunkBlockIndex); + CachedChunkBlockIndexes.push_back(NeededBlock.BlockIndex); + UsingCachedBlock = true; } } - } - else - { - ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash); + if (!UsingCachedBlock) + { + FetchBlockIndexes.push_back(NeededBlock.BlockIndex); + } } } - return NeededBlockChunkIndexes; - }; + } - std::vector 
CachedChunkBlockIndexes; - std::vector FetchBlockIndexes; - std::vector> AllBlockChunkIndexNeeded; + std::vector NeededLooseChunkIndexes; - for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; BlockIndex++) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - std::vector BlockChunkIndexNeeded = GetNeededChunkBlockIndexes(BlockDescription); - if (!BlockChunkIndexNeeded.empty()) + NeededLooseChunkIndexes.reserve(m_LooseChunkHashes.size()); + for (uint32_t LooseChunkIndex = 0; LooseChunkIndex < m_LooseChunkHashes.size(); LooseChunkIndex++) { - if (m_Options.PrimeCacheOnly) + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; + + if (RemoteChunkIndexNeedsCopyFromLocalFileFlags[RemoteChunkIndex]) { - FetchBlockIndexes.push_back(BlockIndex); + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Skipping chunk {} due to cache reuse", + m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); + } + continue; } - else + + bool NeedsCopy = true; + if (RemoteChunkIndexNeedsCopyFromSourceFlags[RemoteChunkIndex].compare_exchange_strong(NeedsCopy, false)) { - bool UsingCachedBlock = false; - if (auto It = CachedBlocksFound.find(BlockDescription.BlockHash); It != CachedBlocksFound.end()) + uint64_t WriteCount = GetChunkWriteCount(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); + if (WriteCount == 0) { - TotalPartWriteCount++; - - std::filesystem::path BlockPath = m_TempBlockFolderPath / BlockDescription.BlockHash.ToHexString(); - if (IsFile(BlockPath)) + if (m_Options.IsVerbose) { - CachedChunkBlockIndexes.push_back(BlockIndex); - UsingCachedBlock = true; + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Skipping chunk {} due to cache reuse", + 
m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]); } } - if (!UsingCachedBlock) + else { - FetchBlockIndexes.push_back(BlockIndex); + NeededLooseChunkIndexes.push_back(LooseChunkIndex); } } } - AllBlockChunkIndexNeeded.emplace_back(std::move(BlockChunkIndexNeeded)); } - BlobsExistsResult ExistsResult; - if (m_Storage.BuildCacheStorage) { ZEN_TRACE_CPU("BlobCacheExistCheck"); Stopwatch Timer; - tsl::robin_set BlobHashesSet; + std::vector BlobHashes; + BlobHashes.reserve(NeededLooseChunkIndexes.size() + FetchBlockIndexes.size()); - BlobHashesSet.reserve(LooseChunkHashWorks.size() + FetchBlockIndexes.size()); - for (LooseChunkHashWorkData& LooseChunkHashWork : LooseChunkHashWorks) + for (const uint32_t LooseChunkIndex : NeededLooseChunkIndexes) { - BlobHashesSet.insert(m_RemoteContent.ChunkedContent.ChunkHashes[LooseChunkHashWork.RemoteChunkIndex]); + BlobHashes.push_back(m_LooseChunkHashes[LooseChunkIndex]); } + for (uint32_t BlockIndex : FetchBlockIndexes) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - BlobHashesSet.insert(BlockDescription.BlockHash); + BlobHashes.push_back(m_BlockDescriptions[BlockIndex].BlockHash); } - if (!BlobHashesSet.empty()) - { - const std::vector BlobHashes(BlobHashesSet.begin(), BlobHashesSet.end()); - const std::vector CacheExistsResult = - m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); + const std::vector CacheExistsResult = + m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); - if (CacheExistsResult.size() == BlobHashes.size()) + if (CacheExistsResult.size() == BlobHashes.size()) + { + ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size()); + for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) { - ExistsResult.ExistingBlobs.reserve(CacheExistsResult.size()); - for (size_t BlobIndex = 0; BlobIndex < BlobHashes.size(); BlobIndex++) + if (CacheExistsResult[BlobIndex].HasBody) { - if (CacheExistsResult[BlobIndex].HasBody) - { - 
ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]); - } + ExistsResult.ExistingBlobs.insert(BlobHashes[BlobIndex]); } } - ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs(); - if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Remote cache : Found {} out of {} needed blobs in {}", - ExistsResult.ExistingBlobs.size(), - BlobHashes.size(), - NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); - } + } + ExistsResult.ElapsedTimeMs = Timer.GetElapsedTimeMs(); + if (!ExistsResult.ExistingBlobs.empty() && !m_Options.IsQuiet) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Remote cache : Found {} out of {} needed blobs in {}", + ExistsResult.ExistingBlobs.size(), + BlobHashes.size(), + NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); } } - std::vector BlockRangeWorks; - std::vector FullBlockWorks; + std::vector BlockPartialDownloadModes; + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off) { - Stopwatch Timer; - - std::vector PartialBlockIndexes; - - for (uint32_t BlockIndex : FetchBlockIndexes) + BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::All) + { + BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::On); + } + else + { + BlockPartialDownloadModes.reserve(m_BlockDescriptions.size()); + for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - - const std::vector BlockChunkIndexNeeded = std::move(AllBlockChunkIndexNeeded[BlockIndex]); - if (!BlockChunkIndexNeeded.empty()) + const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) { - bool 
WantsToDoPartialBlockDownload = BlockChunkIndexNeeded.size() < BlockDescription.ChunkRawHashes.size(); - bool CanDoPartialBlockDownload = - (BlockDescription.HeaderSize > 0) && - (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size()); - - bool AllowedToDoPartialRequest = false; - bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); - switch (m_Options.PartialBlockRequestMode) - { - case EPartialBlockRequestMode::Off: - break; - case EPartialBlockRequestMode::ZenCacheOnly: - AllowedToDoPartialRequest = BlockExistInCache; - break; - case EPartialBlockRequestMode::Mixed: - case EPartialBlockRequestMode::All: - AllowedToDoPartialRequest = true; - break; - default: - ZEN_ASSERT(false); - break; - } + BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::On + : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) + { + BlockPartialDownloadModes.push_back(BlockExistInCache ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::On + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); + } + } + } + ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size()); - const uint32_t ChunkStartOffsetInBlock = - gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + ChunkBlockAnalyser::BlockResult PartialBlocks = + BlockAnalyser.CalculatePartialBlockDownloads(NeededBlocks, BlockPartialDownloadModes); - const uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(ChunkStartOffsetInBlock)); + struct LooseChunkHashWorkData + { + std::vector ChunkTargetPtrs; + uint32_t RemoteChunkIndex = (uint32_t)-1; + }; - if (AllowedToDoPartialRequest && WantsToDoPartialBlockDownload && CanDoPartialBlockDownload) - { - ZEN_TRACE_CPU("PartialBlockAnalysis"); - - bool LimitToSingleRange = - BlockExistInCache ? false : m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed; - uint64_t TotalWantedChunksSize = 0; - std::optional> MaybeBlockRanges = - CalculateBlockRanges(BlockIndex, - BlockDescription, - BlockChunkIndexNeeded, - LimitToSingleRange, - ChunkStartOffsetInBlock, - TotalBlockSize, - TotalWantedChunksSize); - ZEN_ASSERT(TotalWantedChunksSize <= TotalBlockSize); - - if (MaybeBlockRanges.has_value()) - { - const std::vector& BlockRanges = MaybeBlockRanges.value(); - ZEN_ASSERT(!BlockRanges.empty()); - BlockRangeWorks.insert(BlockRangeWorks.end(), BlockRanges.begin(), BlockRanges.end()); - TotalRequestCount += BlockRanges.size(); - TotalPartWriteCount += BlockRanges.size(); - - uint64_t RequestedSize = std::accumulate( - BlockRanges.begin(), - BlockRanges.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - PartialBlockIndexes.push_back(BlockIndex); - - if (RequestedSize > TotalWantedChunksSize) - { - if 
(m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes {})", - BlockChunkIndexNeeded.size(), - NiceBytes(RequestedSize), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - BlockRanges.size(), - NiceBytes(RequestedSize - TotalWantedChunksSize)); - } - } - } - else - { - FullBlockWorks.push_back(BlockIndex); - TotalRequestCount++; - TotalPartWriteCount++; - } - } - else - { - FullBlockWorks.push_back(BlockIndex); - TotalRequestCount++; - TotalPartWriteCount++; - } - } - } + TotalRequestCount += NeededLooseChunkIndexes.size(); + TotalPartWriteCount += NeededLooseChunkIndexes.size(); + TotalRequestCount += PartialBlocks.BlockRanges.size(); + TotalPartWriteCount += PartialBlocks.BlockRanges.size(); + TotalRequestCount += PartialBlocks.FullBlockIndexes.size(); + TotalPartWriteCount += PartialBlocks.FullBlockIndexes.size(); - if (!PartialBlockIndexes.empty()) - { - uint64_t TotalFullBlockRequestBytes = 0; - for (uint32_t BlockIndex : FullBlockWorks) - { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - uint32_t CurrentOffset = - gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + std::vector LooseChunkHashWorks; + for (uint32_t LooseChunkIndex : NeededLooseChunkIndexes) + { + const IoHash& ChunkHash = m_LooseChunkHashes[LooseChunkIndex]; + auto RemoteChunkIndexIt = m_RemoteLookup.ChunkHashToChunkIndex.find(ChunkHash); + ZEN_ASSERT(RemoteChunkIndexIt != m_RemoteLookup.ChunkHashToChunkIndex.end()); + const uint32_t RemoteChunkIndex = RemoteChunkIndexIt->second; - TotalFullBlockRequestBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(CurrentOffset)); - } + std::vector ChunkTargetPtrs = + GetRemainingChunkTargets(SequenceIndexChunksLeftToWriteCounters, RemoteChunkIndex); - uint64_t TotalPartialBlockBytes = 0; - 
for (uint32_t BlockIndex : PartialBlockIndexes) - { - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - uint32_t CurrentOffset = - gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + ZEN_ASSERT(!ChunkTargetPtrs.empty()); + LooseChunkHashWorks.push_back( + LooseChunkHashWorkData{.ChunkTargetPtrs = ChunkTargetPtrs, .RemoteChunkIndex = RemoteChunkIndex}); + } - TotalPartialBlockBytes += std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), - BlockDescription.ChunkCompressedLengths.end(), - std::uint64_t(CurrentOffset)); - } + ZEN_TRACE_CPU("WriteChunks"); - uint64_t NonPartialTotalBlockBytes = TotalFullBlockRequestBytes + TotalPartialBlockBytes; + m_LogOutput.SetLogOperationProgress((uint32_t)TaskSteps::WriteChunks, (uint32_t)TaskSteps::StepCount); - const uint64_t TotalPartialBlockRequestBytes = - std::accumulate(BlockRangeWorks.begin(), - BlockRangeWorks.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - uint64_t TotalExtraPartialBlocksRequests = BlockRangeWorks.size() - PartialBlockIndexes.size(); + Stopwatch WriteTimer; - uint64_t TotalSavedBlocksSize = TotalPartialBlockBytes - TotalPartialBlockRequestBytes; - double SavedSizePercent = (TotalSavedBlocksSize * 100.0) / NonPartialTotalBlockBytes; + FilteredRate FilteredDownloadedBytesPerSecond; + FilteredRate FilteredWrittenBytesPerSecond; - if (!m_Options.IsQuiet) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Analysis of partial block requests saves download of {} out of {} ({:.1f}%) using {} extra " - "requests. Completed in {}", - NiceBytes(TotalSavedBlocksSize), - NiceBytes(NonPartialTotalBlockBytes), - SavedSizePercent, - TotalExtraPartialBlocksRequests, - NiceTimeSpanMs(ExistsResult.ElapsedTimeMs)); - } - } - } + std::unique_ptr WriteProgressBarPtr( + m_LogOutput.CreateProgressBar(m_Options.PrimeCacheOnly ? 
"Downloading" : "Writing")); + OperationLogOutput::ProgressBar& WriteProgressBar(*WriteProgressBarPtr); + ParallelWork Work(m_AbortFlag, m_PauseFlag, WorkerThreadPool::EMode::EnableBacklog); + + TotalPartWriteCount += CopyChunkDatas.size(); + TotalPartWriteCount += ScavengedSequenceCopyOperations.size(); BufferedWriteFileCache WriteCache; @@ -1465,13 +1335,23 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) }); } - for (size_t BlockRangeIndex = 0; BlockRangeIndex < BlockRangeWorks.size(); BlockRangeIndex++) + for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocks.BlockRanges.size();) { ZEN_ASSERT(!m_Options.PrimeCacheOnly); if (m_AbortFlag) { break; } + + size_t RangeCount = 1; + size_t RangesLeft = PartialBlocks.BlockRanges.size() - BlockRangeIndex; + const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; + while (RangeCount < RangesLeft && + CurrentBlockRange.BlockIndex == PartialBlocks.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) + { + RangeCount++; + } + Work.ScheduleWork( m_NetworkPool, [this, @@ -1485,119 +1365,127 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) TotalPartWriteCount, &FilteredWrittenBytesPerSecond, &Work, - &BlockRangeWorks, - BlockRangeIndex](std::atomic&) { + &PartialBlocks, + BlockRangeStartIndex = BlockRangeIndex, + RangeCount](std::atomic&) { if (!m_AbortFlag) { - ZEN_TRACE_CPU("Async_GetPartialBlock"); - - const BlockRangeDescriptor& BlockRange = BlockRangeWorks[BlockRangeIndex]; + ZEN_TRACE_CPU("Async_GetPartialBlockRanges"); FilteredDownloadedBytesPerSecond.Start(); - DownloadPartialBlock( - BlockRange, - ExistsResult, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalRequestCount, - TotalPartWriteCount, - &FilteredDownloadedBytesPerSecond, - &FilteredWrittenBytesPerSecond, - &BlockRange](IoBuffer&& 
InMemoryBuffer, const std::filesystem::path& OnDiskPath) { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - &BlockRange, - BlockChunkPath = std::filesystem::path(OnDiskPath), - BlockPartialBuffer = std::move(InMemoryBuffer)](std::atomic&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WritePartialBlock"); + for (size_t BlockRangeIndex = BlockRangeStartIndex; BlockRangeIndex < BlockRangeStartIndex + RangeCount; + BlockRangeIndex++) + { + ZEN_TRACE_CPU("GetPartialBlock"); - const uint32_t BlockIndex = BlockRange.BlockIndex; + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + DownloadPartialBlock( + BlockRange, + ExistsResult, + [this, + &RemoteChunkIndexNeedsCopyFromSourceFlags, + &SequenceIndexChunksLeftToWriteCounters, + &WritePartsComplete, + &WriteCache, + &Work, + TotalRequestCount, + TotalPartWriteCount, + &FilteredDownloadedBytesPerSecond, + &FilteredWrittenBytesPerSecond, + &BlockRange](IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath) { + if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } - if (BlockChunkPath.empty()) - { - ZEN_ASSERT(BlockPartialBuffer); - } - else + if (!m_AbortFlag) + { + Work.ScheduleWork( + m_IOWorkerPool, + [this, + &RemoteChunkIndexNeedsCopyFromSourceFlags, + &SequenceIndexChunksLeftToWriteCounters, + &WritePartsComplete, + &WriteCache, + &Work, + TotalPartWriteCount, + &FilteredWrittenBytesPerSecond, + &BlockRange, + BlockChunkPath = std::filesystem::path(OnDiskPath), + 
BlockPartialBuffer = std::move(InMemoryBuffer)](std::atomic&) mutable { + if (!m_AbortFlag) { - ZEN_ASSERT(!BlockPartialBuffer); - BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); - if (!BlockPartialBuffer) + ZEN_TRACE_CPU("Async_WritePartialBlock"); + + const uint32_t BlockIndex = BlockRange.BlockIndex; + + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + if (BlockChunkPath.empty()) { - throw std::runtime_error( - fmt::format("Could not open downloaded block {} from {}", - BlockDescription.BlockHash, - BlockChunkPath)); + ZEN_ASSERT(BlockPartialBuffer); + } + else + { + ZEN_ASSERT(!BlockPartialBuffer); + BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (!BlockPartialBuffer) + { + throw std::runtime_error( + fmt::format("Could not open downloaded block {} from {}", + BlockDescription.BlockHash, + BlockChunkPath)); + } } - } - - FilteredWrittenBytesPerSecond.Start(); - if (!WritePartialBlockChunksToCache( - BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockPartialBuffer)), - BlockRange.ChunkBlockIndexStart, - BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount - 1, - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) - { - std::error_code DummyEc; - RemoveFile(BlockChunkPath, DummyEc); - throw std::runtime_error( - fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); - } + FilteredWrittenBytesPerSecond.Start(); + + if (!WritePartialBlockChunksToCache( + BlockDescription, + SequenceIndexChunksLeftToWriteCounters, + Work, + CompositeBuffer(std::move(BlockPartialBuffer)), + BlockRange.ChunkBlockIndexStart, + BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount - 1, + RemoteChunkIndexNeedsCopyFromSourceFlags, + WriteCache)) + { + std::error_code DummyEc; + RemoveFile(BlockChunkPath, DummyEc); + throw std::runtime_error( + fmt::format("Partial block {} is malformed", 
BlockDescription.BlockHash)); + } - std::error_code Ec = TryRemoveFile(BlockChunkPath); - if (Ec) - { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockChunkPath, - Ec.value(), - Ec.message()); - } + std::error_code Ec = TryRemoveFile(BlockChunkPath); + if (Ec) + { + ZEN_OPERATION_LOG_DEBUG(m_LogOutput, + "Failed removing file '{}', reason: ({}) {}", + BlockChunkPath, + Ec.value(), + Ec.message()); + } - WritePartsComplete++; - if (WritePartsComplete == TotalPartWriteCount) - { - FilteredWrittenBytesPerSecond.Stop(); + WritePartsComplete++; + if (WritePartsComplete == TotalPartWriteCount) + { + FilteredWrittenBytesPerSecond.Stop(); + } } - } - }, - OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog - : WorkerThreadPool::EMode::EnableBacklog); - } - }); + }, + OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog + : WorkerThreadPool::EMode::EnableBacklog); + } + }); + } } }); + BlockRangeIndex += RangeCount; } - for (uint32_t BlockIndex : FullBlockWorks) + for (uint32_t BlockIndex : PartialBlocks.FullBlockIndexes) { if (m_AbortFlag) { @@ -3282,271 +3170,9 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde } } -BuildsOperationUpdateFolder::BlockRangeDescriptor -BuildsOperationUpdateFolder::MergeBlockRanges(std::span Ranges) -{ - ZEN_ASSERT(Ranges.size() > 1); - const BlockRangeDescriptor& First = Ranges.front(); - const BlockRangeDescriptor& Last = Ranges.back(); - - return BlockRangeDescriptor{.BlockIndex = First.BlockIndex, - .RangeStart = First.RangeStart, - .RangeLength = Last.RangeStart + Last.RangeLength - First.RangeStart, - .ChunkBlockIndexStart = First.ChunkBlockIndexStart, - .ChunkBlockIndexCount = Last.ChunkBlockIndexStart + Last.ChunkBlockIndexCount - First.ChunkBlockIndexStart}; -} - -std::optional> -BuildsOperationUpdateFolder::MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, const BlockRangeDescriptor& Range) -{ - if (Range.RangeLength == TotalBlockSize) - { - 
return {}; - } - else - { - return std::vector{Range}; - } -}; - -const BuildsOperationUpdateFolder::BlockRangeLimit* -BuildsOperationUpdateFolder::GetBlockRangeLimitForRange(std::span Limits, - uint64_t TotalBlockSize, - std::span Ranges) -{ - if (Ranges.size() > 1) - { - const std::uint64_t WantedSize = - std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { - return Current + Range.RangeLength; - }); - - const double RangeRequestedPercent = (WantedSize * 100.0) / TotalBlockSize; - - for (const BlockRangeLimit& Limit : Limits) - { - if (RangeRequestedPercent >= Limit.SizePercent && Ranges.size() > Limit.MaxRangeCount) - { - return &Limit; - } - } - } - return nullptr; -}; - -std::vector -BuildsOperationUpdateFolder::CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std::span BlockRanges) -{ - ZEN_ASSERT(BlockRanges.size() > 1); - std::vector CollapsedBlockRanges; - - auto BlockRangesIt = BlockRanges.begin(); - CollapsedBlockRanges.push_back(*BlockRangesIt++); - for (; BlockRangesIt != BlockRanges.end(); BlockRangesIt++) - { - BlockRangeDescriptor& LastRange = CollapsedBlockRanges.back(); - - const uint64_t BothRangeSize = BlockRangesIt->RangeLength + LastRange.RangeLength; - - const uint64_t Gap = BlockRangesIt->RangeStart - (LastRange.RangeStart + LastRange.RangeLength); - if (Gap <= Max(BothRangeSize / 16, AlwaysAcceptableGap)) - { - LastRange.ChunkBlockIndexCount = - (BlockRangesIt->ChunkBlockIndexStart + BlockRangesIt->ChunkBlockIndexCount) - LastRange.ChunkBlockIndexStart; - LastRange.RangeLength = (BlockRangesIt->RangeStart + BlockRangesIt->RangeLength) - LastRange.RangeStart; - } - else - { - CollapsedBlockRanges.push_back(*BlockRangesIt); - } - } - - return CollapsedBlockRanges; -}; - -uint64_t -BuildsOperationUpdateFolder::CalculateNextGap(std::span BlockRanges) -{ - ZEN_ASSERT(BlockRanges.size() > 1); - uint64_t AcceptableGap = (uint64_t)-1; - for (size_t RangeIndex = 0; RangeIndex < 
BlockRanges.size() - 1; RangeIndex++) - { - const BlockRangeDescriptor& Range = BlockRanges[RangeIndex]; - const BlockRangeDescriptor& NextRange = BlockRanges[RangeIndex + 1]; - - const uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength); - AcceptableGap = Min(Gap, AcceptableGap); - } - AcceptableGap = RoundUp(AcceptableGap, 16u * 1024u); - return AcceptableGap; -}; - -std::optional> -BuildsOperationUpdateFolder::CalculateBlockRanges(uint32_t BlockIndex, - const ChunkBlockDescription& BlockDescription, - std::span BlockChunkIndexNeeded, - bool LimitToSingleRange, - const uint64_t ChunkStartOffsetInBlock, - const uint64_t TotalBlockSize, - uint64_t& OutTotalWantedChunksSize) -{ - ZEN_TRACE_CPU("CalculateBlockRanges"); - - std::vector BlockRanges; - { - uint64_t CurrentOffset = ChunkStartOffsetInBlock; - uint32_t ChunkBlockIndex = 0; - uint32_t NeedBlockChunkIndexOffset = 0; - BlockRangeDescriptor NextRange{.BlockIndex = BlockIndex}; - while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size()) - { - const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; - if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) - { - if (NextRange.RangeLength > 0) - { - BlockRanges.push_back(NextRange); - NextRange = {.BlockIndex = BlockIndex}; - } - ChunkBlockIndex++; - CurrentOffset += ChunkCompressedLength; - } - else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) - { - if (NextRange.RangeLength == 0) - { - NextRange.RangeStart = CurrentOffset; - NextRange.ChunkBlockIndexStart = ChunkBlockIndex; - } - NextRange.RangeLength += ChunkCompressedLength; - NextRange.ChunkBlockIndexCount++; - ChunkBlockIndex++; - CurrentOffset += ChunkCompressedLength; - NeedBlockChunkIndexOffset++; - } - else - { - ZEN_ASSERT(false); - } - } - if (NextRange.RangeLength > 0) - { - BlockRanges.push_back(NextRange); - } - } - 
ZEN_ASSERT(!BlockRanges.empty()); - - OutTotalWantedChunksSize = - std::accumulate(BlockRanges.begin(), BlockRanges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { - return Current + Range.RangeLength; - }); - - double RangeWantedPercent = (OutTotalWantedChunksSize * 100.0) / TotalBlockSize; - - if (BlockRanges.size() == 1) - { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Range request of {} ({:.2f}%) using single range from block {} ({}) as is", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize)); - } - return BlockRanges; - } - - if (LimitToSingleRange) - { - const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); - if (m_Options.IsVerbose) - { - const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; - const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; - - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) limited to single block range {} ({:.2f}%) wasting " - "{:.2f}% ({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - NiceBytes(MergedRange.RangeLength), - RangeRequestedPercent, - WastedPercent, - NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); - } - return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); - } - - if (RangeWantedPercent > FullBlockRangePercentLimit) - { - const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); - if (m_Options.IsVerbose) - { - const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; - const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; - - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} 
({:.2f}%) using {} ranges from block {} ({}) exceeds {}%. Merged to single block range {} " - "({:.2f}%) wasting {:.2f}% ({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - FullBlockRangePercentLimit, - NiceBytes(MergedRange.RangeLength), - RangeRequestedPercent, - WastedPercent, - NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); - } - return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); - } - - std::vector CollapsedBlockRanges = CollapseBlockRanges(16u * 1024u, BlockRanges); - while (GetBlockRangeLimitForRange(ForceMergeLimits, TotalBlockSize, CollapsedBlockRanges)) - { - CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(CollapsedBlockRanges), CollapsedBlockRanges); - } - - const std::uint64_t WantedCollapsedSize = - std::accumulate(CollapsedBlockRanges.begin(), - CollapsedBlockRanges.end(), - uint64_t(0), - [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - - const double CollapsedRangeRequestedPercent = (WantedCollapsedSize * 100.0) / TotalBlockSize; - - if (m_Options.IsVerbose) - { - const double WastedPercent = ((WantedCollapsedSize - OutTotalWantedChunksSize) * 100.0) / WantedCollapsedSize; - - ZEN_OPERATION_LOG_INFO( - m_LogOutput, - "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) collapsed to {} {:.2f}% using {} ranges wasting {:.2f}% " - "({})", - NiceBytes(OutTotalWantedChunksSize), - RangeWantedPercent, - BlockRanges.size(), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - NiceBytes(WantedCollapsedSize), - CollapsedRangeRequestedPercent, - CollapsedBlockRanges.size(), - WastedPercent, - NiceBytes(WantedCollapsedSize - OutTotalWantedChunksSize)); - } - return CollapsedBlockRanges; -} - void BuildsOperationUpdateFolder::DownloadPartialBlock( - const BlockRangeDescriptor BlockRange, + const ChunkBlockAnalyser::BlockRangeDescriptor BlockRange, 
const BlobsExistsResult& ExistsResult, std::function&& OnDownloaded) { diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index c4d8653f4..06cedae3f 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -10,18 +10,17 @@ #include +#include #include ZEN_THIRD_PARTY_INCLUDES_START #include +#include ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS # include # include - -# include -# include #endif // ZEN_WITH_TESTS namespace zen { @@ -455,6 +454,537 @@ FindReuseBlocks(OperationLogOutput& Output, return FilteredReuseBlockIndexes; } +ChunkBlockAnalyser::ChunkBlockAnalyser(OperationLogOutput& LogOutput, + std::span BlockDescriptions, + const Options& Options) +: m_LogOutput(LogOutput) +, m_BlockDescriptions(BlockDescriptions) +, m_Options(Options) +{ +} + +std::vector +ChunkBlockAnalyser::GetNeeded(const tsl::robin_map& ChunkHashToChunkIndex, + std::function&& NeedsBlockChunk) +{ + ZEN_TRACE_CPU("ChunkBlockAnalyser::GetNeeded"); + + std::vector Result; + + std::vector ChunkIsNeeded(ChunkHashToChunkIndex.size()); + for (uint32_t ChunkIndex = 0; ChunkIndex < ChunkHashToChunkIndex.size(); ChunkIndex++) + { + ChunkIsNeeded[ChunkIndex] = NeedsBlockChunk(ChunkIndex); + } + + std::vector BlockSlack(m_BlockDescriptions.size(), 0u); + for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + uint64_t BlockUsedSize = 0; + uint64_t BlockSize = 0; + + for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) + { + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end()) + { + const uint32_t RemoteChunkIndex = It->second; + if (ChunkIsNeeded[RemoteChunkIndex]) + { + BlockUsedSize += 
BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + } + } + BlockSize += BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + } + BlockSlack[BlockIndex] = BlockSize - BlockUsedSize; + } + + std::vector BlockOrder(m_BlockDescriptions.size()); + std::iota(BlockOrder.begin(), BlockOrder.end(), 0); + + std::sort(BlockOrder.begin(), BlockOrder.end(), [&BlockSlack](uint32_t Lhs, uint32_t Rhs) { + return BlockSlack[Lhs] < BlockSlack[Rhs]; + }); + + std::vector ChunkIsPickedUp(ChunkHashToChunkIndex.size(), false); + + for (uint32_t BlockIndex : BlockOrder) + { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + + std::vector BlockChunkIndexNeeded; + + for (uint32_t ChunkBlockIndex = 0; ChunkBlockIndex < BlockDescription.ChunkRawHashes.size(); ChunkBlockIndex++) + { + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + if (auto It = ChunkHashToChunkIndex.find(ChunkHash); It != ChunkHashToChunkIndex.end()) + { + const uint32_t RemoteChunkIndex = It->second; + if (ChunkIsNeeded[RemoteChunkIndex]) + { + if (!ChunkIsPickedUp[RemoteChunkIndex]) + { + ChunkIsPickedUp[RemoteChunkIndex] = true; + BlockChunkIndexNeeded.push_back(ChunkBlockIndex); + } + } + } + else + { + ZEN_DEBUG("Chunk {} not found in block {}", ChunkHash, BlockDescription.BlockHash); + } + } + + if (!BlockChunkIndexNeeded.empty()) + { + Result.push_back(NeededBlock{.BlockIndex = BlockIndex, .ChunkIndexes = std::move(BlockChunkIndexNeeded)}); + } + } + return Result; +} + +ChunkBlockAnalyser::BlockResult +ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span NeededBlocks, + std::span BlockPartialDownloadModes) +{ + ZEN_TRACE_CPU("ChunkBlockAnalyser::CalculatePartialBlockDownloads"); + + Stopwatch PartialAnalisysTimer; + + ChunkBlockAnalyser::BlockResult Result; + + uint64_t IdealDownloadTotalSize = 0; + uint64_t AllBlocksTotalBlocksSize = 0; + + for (const NeededBlock& NeededBlock : NeededBlocks) + { + const ChunkBlockDescription& 
BlockDescription = m_BlockDescriptions[NeededBlock.BlockIndex]; + + std::span BlockChunkIndexNeeded(NeededBlock.ChunkIndexes); + if (!NeededBlock.ChunkIndexes.empty()) + { + bool WantsToDoPartialBlockDownload = NeededBlock.ChunkIndexes.size() < BlockDescription.ChunkRawHashes.size(); + bool CanDoPartialBlockDownload = (BlockDescription.HeaderSize > 0) && + (BlockDescription.ChunkCompressedLengths.size() == BlockDescription.ChunkRawHashes.size()); + + EPartialBlockDownloadMode PartialBlockDownloadMode = BlockPartialDownloadModes[NeededBlock.BlockIndex]; + + const uint32_t ChunkStartOffsetInBlock = + gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + + const uint64_t TotalBlockSize = std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), + BlockDescription.ChunkCompressedLengths.end(), + std::uint64_t(ChunkStartOffsetInBlock)); + + AllBlocksTotalBlocksSize += TotalBlockSize; + + if ((PartialBlockDownloadMode != EPartialBlockDownloadMode::Off) && WantsToDoPartialBlockDownload && CanDoPartialBlockDownload) + { + ZEN_TRACE_CPU("PartialBlockAnalysis"); + + uint64_t TotalWantedChunksSize = 0; + std::optional> MaybeBlockRanges = CalculateBlockRanges(NeededBlock.BlockIndex, + BlockDescription, + NeededBlock.ChunkIndexes, + PartialBlockDownloadMode, + ChunkStartOffsetInBlock, + TotalBlockSize, + TotalWantedChunksSize); + ZEN_ASSERT(TotalWantedChunksSize <= TotalBlockSize); + IdealDownloadTotalSize += TotalWantedChunksSize; + + if (MaybeBlockRanges.has_value()) + { + const std::vector& BlockRanges = MaybeBlockRanges.value(); + ZEN_ASSERT(!BlockRanges.empty()); + + uint64_t RequestedSize = + std::accumulate(BlockRanges.begin(), + BlockRanges.end(), + uint64_t(0), + [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); + + if ((PartialBlockDownloadMode != EPartialBlockDownloadMode::Exact) && ((RequestedSize * 100) / TotalBlockSize) >= 200) + { + if (m_Options.IsVerbose) + { + 
ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Requesting {} chunks ({}) from block {} ({}) using full block request (extra bytes {})", + NeededBlock.ChunkIndexes.size(), + NiceBytes(RequestedSize), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + NiceBytes(TotalBlockSize - TotalWantedChunksSize)); + } + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + } + else + { + Result.BlockRanges.insert(Result.BlockRanges.end(), BlockRanges.begin(), BlockRanges.end()); + + if (RequestedSize > TotalWantedChunksSize) + { + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes {})", + NeededBlock.ChunkIndexes.size(), + NiceBytes(RequestedSize), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + BlockRanges.size(), + NiceBytes(RequestedSize - TotalWantedChunksSize)); + } + } + } + } + else + { + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + } + } + else + { + Result.FullBlockIndexes.push_back(NeededBlock.BlockIndex); + IdealDownloadTotalSize += TotalBlockSize; + } + } + } + + if (!Result.BlockRanges.empty() && !m_Options.IsQuiet) + { + tsl::robin_set PartialBlockIndexes; + uint64_t PartialBlocksTotalSize = std::accumulate(Result.BlockRanges.begin(), + Result.BlockRanges.end(), + uint64_t(0u), + [&](uint64_t Current, const BlockRangeDescriptor& Range) { + PartialBlockIndexes.insert(Range.BlockIndex); + return Current + Range.RangeLength; + }); + + uint64_t FullBlocksTotalSize = + std::accumulate(Result.FullBlockIndexes.begin(), + Result.FullBlockIndexes.end(), + uint64_t(0u), + [&](uint64_t Current, uint32_t BlockIndex) { + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + uint32_t CurrentOffset = + gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + + return Current + std::accumulate(BlockDescription.ChunkCompressedLengths.begin(), + 
BlockDescription.ChunkCompressedLengths.end(), + std::uint64_t(CurrentOffset)); + }); + + uint64_t PartialBlockRequestCount = Result.BlockRanges.size(); + uint64_t PartialBlockCount = PartialBlockIndexes.size(); + + uint64_t TotalExtraPartialBlocksRequestCount = PartialBlockRequestCount - PartialBlockCount; + uint64_t ActualPartialDownloadTotalSize = FullBlocksTotalSize + PartialBlocksTotalSize; + + uint64_t IdealSkippedSize = AllBlocksTotalBlocksSize - IdealDownloadTotalSize; + uint64_t ActualSkippedSize = AllBlocksTotalBlocksSize - ActualPartialDownloadTotalSize; + + double PercentOfIdealPartialSkippedSize = (ActualSkippedSize * 100.0) / IdealSkippedSize; + + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Analysis of partial block requests saves download of {} out of {}, {:.1f}% of possible {} using {} extra " + "requests. Completed in {}", + NiceBytes(ActualSkippedSize), + NiceBytes(AllBlocksTotalBlocksSize), + PercentOfIdealPartialSkippedSize, + NiceBytes(IdealSkippedSize), + TotalExtraPartialBlocksRequestCount, + NiceTimeSpanMs(PartialAnalisysTimer.GetElapsedTimeMs())); + } + + return Result; +} + +ChunkBlockAnalyser::BlockRangeDescriptor +ChunkBlockAnalyser::MergeBlockRanges(std::span Ranges) +{ + ZEN_ASSERT(Ranges.size() > 1); + const BlockRangeDescriptor& First = Ranges.front(); + const BlockRangeDescriptor& Last = Ranges.back(); + + return BlockRangeDescriptor{.BlockIndex = First.BlockIndex, + .RangeStart = First.RangeStart, + .RangeLength = Last.RangeStart + Last.RangeLength - First.RangeStart, + .ChunkBlockIndexStart = First.ChunkBlockIndexStart, + .ChunkBlockIndexCount = Last.ChunkBlockIndexStart + Last.ChunkBlockIndexCount - First.ChunkBlockIndexStart}; +} + +std::optional> +ChunkBlockAnalyser::MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, const BlockRangeDescriptor& Range) +{ + if (Range.RangeLength == TotalBlockSize) + { + return {}; + } + else + { + return std::vector{Range}; + } +}; + +const ChunkBlockAnalyser::BlockRangeLimit* 
+ChunkBlockAnalyser::GetBlockRangeLimitForRange(std::span Limits, + uint64_t TotalBlockSize, + std::span Ranges) +{ + if (Ranges.size() > 1) + { + const std::uint64_t WantedSize = + std::accumulate(Ranges.begin(), Ranges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { + return Current + Range.RangeLength; + }); + + const double RangeRequestedPercent = (WantedSize * 100.0) / TotalBlockSize; + + for (const BlockRangeLimit& Limit : Limits) + { + if (RangeRequestedPercent >= Limit.SizePercent && Ranges.size() > Limit.MaxRangeCount) + { + return &Limit; + } + } + } + return nullptr; +}; + +std::vector +ChunkBlockAnalyser::CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std::span BlockRanges) +{ + ZEN_ASSERT(BlockRanges.size() > 1); + std::vector CollapsedBlockRanges; + + auto BlockRangesIt = BlockRanges.begin(); + CollapsedBlockRanges.push_back(*BlockRangesIt++); + for (; BlockRangesIt != BlockRanges.end(); BlockRangesIt++) + { + BlockRangeDescriptor& LastRange = CollapsedBlockRanges.back(); + + const uint64_t BothRangeSize = BlockRangesIt->RangeLength + LastRange.RangeLength; + + const uint64_t Gap = BlockRangesIt->RangeStart - (LastRange.RangeStart + LastRange.RangeLength); + if (Gap <= Max(BothRangeSize / 16, AlwaysAcceptableGap)) + { + LastRange.ChunkBlockIndexCount = + (BlockRangesIt->ChunkBlockIndexStart + BlockRangesIt->ChunkBlockIndexCount) - LastRange.ChunkBlockIndexStart; + LastRange.RangeLength = (BlockRangesIt->RangeStart + BlockRangesIt->RangeLength) - LastRange.RangeStart; + } + else + { + CollapsedBlockRanges.push_back(*BlockRangesIt); + } + } + + return CollapsedBlockRanges; +}; + +uint64_t +ChunkBlockAnalyser::CalculateNextGap(std::span BlockRanges) +{ + ZEN_ASSERT(BlockRanges.size() > 1); + uint64_t AcceptableGap = (uint64_t)-1; + for (size_t RangeIndex = 0; RangeIndex < BlockRanges.size() - 1; RangeIndex++) + { + const BlockRangeDescriptor& Range = BlockRanges[RangeIndex]; + const BlockRangeDescriptor& NextRange = 
BlockRanges[RangeIndex + 1]; + + const uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength); + AcceptableGap = Min(Gap, AcceptableGap); + } + AcceptableGap = RoundUp(AcceptableGap, 16u * 1024u); + return AcceptableGap; +}; + +std::optional> +ChunkBlockAnalyser::CalculateBlockRanges(uint32_t BlockIndex, + const ChunkBlockDescription& BlockDescription, + std::span BlockChunkIndexNeeded, + EPartialBlockDownloadMode PartialBlockDownloadMode, + const uint64_t ChunkStartOffsetInBlock, + const uint64_t TotalBlockSize, + uint64_t& OutTotalWantedChunksSize) +{ + ZEN_TRACE_CPU("CalculateBlockRanges"); + + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Off) + { + return {}; + } + + std::vector BlockRanges; + { + uint64_t CurrentOffset = ChunkStartOffsetInBlock; + uint32_t ChunkBlockIndex = 0; + uint32_t NeedBlockChunkIndexOffset = 0; + BlockRangeDescriptor NextRange{.BlockIndex = BlockIndex}; + while (NeedBlockChunkIndexOffset < BlockChunkIndexNeeded.size() && ChunkBlockIndex < BlockDescription.ChunkRawHashes.size()) + { + const uint32_t ChunkCompressedLength = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + if (ChunkBlockIndex < BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) + { + if (NextRange.RangeLength > 0) + { + BlockRanges.push_back(NextRange); + NextRange = {.BlockIndex = BlockIndex}; + } + ChunkBlockIndex++; + CurrentOffset += ChunkCompressedLength; + } + else if (ChunkBlockIndex == BlockChunkIndexNeeded[NeedBlockChunkIndexOffset]) + { + if (NextRange.RangeLength == 0) + { + NextRange.RangeStart = CurrentOffset; + NextRange.ChunkBlockIndexStart = ChunkBlockIndex; + } + NextRange.RangeLength += ChunkCompressedLength; + NextRange.ChunkBlockIndexCount++; + ChunkBlockIndex++; + CurrentOffset += ChunkCompressedLength; + NeedBlockChunkIndexOffset++; + } + else + { + ZEN_ASSERT(false); + } + } + if (NextRange.RangeLength > 0) + { + BlockRanges.push_back(NextRange); + } + } + ZEN_ASSERT(!BlockRanges.empty()); + + 
OutTotalWantedChunksSize = + std::accumulate(BlockRanges.begin(), BlockRanges.end(), uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { + return Current + Range.RangeLength; + }); + + double RangeWantedPercent = (OutTotalWantedChunksSize * 100.0) / TotalBlockSize; + + if (BlockRanges.size() == 1) + { + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Range request of {} ({:.2f}%) using single range from block {} ({}) as is", + NiceBytes(OutTotalWantedChunksSize), + RangeWantedPercent, + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize)); + } + return BlockRanges; + } + + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::Exact) + { + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Range request of {} ({:.2f}%) using {} ranges from block {} ({})", + NiceBytes(OutTotalWantedChunksSize), + RangeWantedPercent, + BlockRanges.size(), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize)); + } + return BlockRanges; + } + + if (PartialBlockDownloadMode == EPartialBlockDownloadMode::SingleRange) + { + const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); + if (m_Options.IsVerbose) + { + const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; + const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; + + ZEN_OPERATION_LOG_INFO( + m_LogOutput, + "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) limited to single block range {} ({:.2f}%) wasting " + "{:.2f}% ({})", + NiceBytes(OutTotalWantedChunksSize), + RangeWantedPercent, + BlockRanges.size(), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + NiceBytes(MergedRange.RangeLength), + RangeRequestedPercent, + WastedPercent, + NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); + } + return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); + } + + if (RangeWantedPercent > 
FullBlockRangePercentLimit) + { + const BlockRangeDescriptor MergedRange = MergeBlockRanges(BlockRanges); + if (m_Options.IsVerbose) + { + const double RangeRequestedPercent = (MergedRange.RangeLength * 100.0) / TotalBlockSize; + const double WastedPercent = ((MergedRange.RangeLength - OutTotalWantedChunksSize) * 100.0) / MergedRange.RangeLength; + + ZEN_OPERATION_LOG_INFO( + m_LogOutput, + "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) exceeds {}%. Merged to single block range {} " + "({:.2f}%) wasting {:.2f}% ({})", + NiceBytes(OutTotalWantedChunksSize), + RangeWantedPercent, + BlockRanges.size(), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + FullBlockRangePercentLimit, + NiceBytes(MergedRange.RangeLength), + RangeRequestedPercent, + WastedPercent, + NiceBytes(MergedRange.RangeLength - OutTotalWantedChunksSize)); + } + return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); + } + + std::vector CollapsedBlockRanges = CollapseBlockRanges(16u * 1024u, BlockRanges); + while (GetBlockRangeLimitForRange(ForceMergeLimits, TotalBlockSize, CollapsedBlockRanges)) + { + CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(CollapsedBlockRanges), CollapsedBlockRanges); + } + + const std::uint64_t WantedCollapsedSize = + std::accumulate(CollapsedBlockRanges.begin(), + CollapsedBlockRanges.end(), + uint64_t(0), + [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); + + const double CollapsedRangeRequestedPercent = (WantedCollapsedSize * 100.0) / TotalBlockSize; + + if (m_Options.IsVerbose) + { + const double WastedPercent = ((WantedCollapsedSize - OutTotalWantedChunksSize) * 100.0) / WantedCollapsedSize; + + ZEN_OPERATION_LOG_INFO( + m_LogOutput, + "Range request of {} ({:.2f}%) using {} ranges from block {} ({}) collapsed to {} {:.2f}% using {} ranges wasting {:.2f}% " + "({})", + NiceBytes(OutTotalWantedChunksSize), + RangeWantedPercent, + BlockRanges.size(), + 
BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + NiceBytes(WantedCollapsedSize), + CollapsedRangeRequestedPercent, + CollapsedBlockRanges.size(), + WastedPercent, + NiceBytes(WantedCollapsedSize - OutTotalWantedChunksSize)); + } + return CollapsedBlockRanges; +} + #if ZEN_WITH_TESTS namespace testutils { @@ -476,7 +1006,7 @@ namespace testutils { } // namespace testutils -TEST_CASE("project.store.block") +TEST_CASE("chunkblock.block") { using namespace std::literals; using namespace testutils; @@ -504,7 +1034,7 @@ TEST_CASE("project.store.block") HeaderSize)); } -TEST_CASE("project.store.reuseblocks") +TEST_CASE("chunkblock.reuseblocks") { using namespace std::literals; using namespace testutils; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 9e5bf8d91..6800444e0 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -218,33 +219,6 @@ private: uint64_t ElapsedTimeMs = 0; }; - struct BlockRangeDescriptor - { - uint32_t BlockIndex = (uint32_t)-1; - uint64_t RangeStart = 0; - uint64_t RangeLength = 0; - uint32_t ChunkBlockIndexStart = 0; - uint32_t ChunkBlockIndexCount = 0; - }; - - struct BlockRangeLimit - { - uint16_t SizePercent; - uint16_t MaxRangeCount; - }; - - static constexpr uint16_t FullBlockRangePercentLimit = 95; - - static constexpr BuildsOperationUpdateFolder::BlockRangeLimit ForceMergeLimits[] = { - {.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1}, - {.SizePercent = 90, .MaxRangeCount = 2}, - {.SizePercent = 85, .MaxRangeCount = 8}, - {.SizePercent = 80, .MaxRangeCount = 16}, - {.SizePercent = 70, .MaxRangeCount = 32}, - {.SizePercent = 60, .MaxRangeCount = 48}, - {.SizePercent = 2, .MaxRangeCount = 56}, - {.SizePercent = 0, 
.MaxRangeCount = 64}}; - void ScanCacheFolder(tsl::robin_map& OutCachedChunkHashesFound, tsl::robin_map& OutCachedSequenceHashesFound); void ScanTempBlocksFolder(tsl::robin_map& OutCachedBlocksFound); @@ -299,25 +273,9 @@ private: ParallelWork& Work, std::function&& OnDownloaded); - BlockRangeDescriptor MergeBlockRanges(std::span Ranges); - std::optional> MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, - const BlockRangeDescriptor& Range); - const BlockRangeLimit* GetBlockRangeLimitForRange(std::span Limits, - uint64_t TotalBlockSize, - std::span Ranges); - std::vector CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, - std::span BlockRanges); - uint64_t CalculateNextGap(std::span BlockRanges); - std::optional> CalculateBlockRanges(uint32_t BlockIndex, - const ChunkBlockDescription& BlockDescription, - std::span BlockChunkIndexNeeded, - bool LimitToSingleRange, - const uint64_t ChunkStartOffsetInBlock, - const uint64_t TotalBlockSize, - uint64_t& OutTotalWantedChunksSize); - void DownloadPartialBlock(const BlockRangeDescriptor BlockRange, - const BlobsExistsResult& ExistsResult, - std::function&& OnDownloaded); + void DownloadPartialBlock(const ChunkBlockAnalyser::BlockRangeDescriptor BlockRange, + const BlobsExistsResult& ExistsResult, + std::function&& OnDownloaded); std::vector WriteLocalChunkToCache(CloneQueryInterface* CloneQuery, const CopyChunkData& CopyData, diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index d339b0f94..57710fcf5 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -7,6 +7,10 @@ #include #include +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + #include #include @@ -73,6 +77,96 @@ std::vector FindReuseBlocks(OperationLogOutput& Output, std::span ChunkIndexes, std::vector& OutUnusedChunkIndexes); +class 
ChunkBlockAnalyser +{ +public: + struct Options + { + bool IsQuiet = false; + bool IsVerbose = false; + }; + + ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span BlockDescriptions, const Options& Options); + + struct BlockRangeDescriptor + { + uint32_t BlockIndex = (uint32_t)-1; + uint64_t RangeStart = 0; + uint64_t RangeLength = 0; + uint32_t ChunkBlockIndexStart = 0; + uint32_t ChunkBlockIndexCount = 0; + }; + + struct NeededBlock + { + uint32_t BlockIndex; + std::vector ChunkIndexes; + }; + + std::vector GetNeeded(const tsl::robin_map& ChunkHashToChunkIndex, + std::function&& NeedsBlockChunk); + + enum EPartialBlockDownloadMode + { + Off, + SingleRange, + On, + Exact + }; + + struct BlockResult + { + std::vector BlockRanges; + std::vector FullBlockIndexes; + }; + + BlockResult CalculatePartialBlockDownloads(std::span NeededBlocks, + std::span BlockPartialDownloadModes); + +private: + struct BlockRangeLimit + { + uint16_t SizePercent; + uint16_t MaxRangeCount; + }; + + static constexpr uint16_t FullBlockRangePercentLimit = 95; + + static constexpr BlockRangeLimit ForceMergeLimits[] = {{.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1}, + {.SizePercent = 90, .MaxRangeCount = 2}, + {.SizePercent = 85, .MaxRangeCount = 8}, + {.SizePercent = 80, .MaxRangeCount = 16}, + {.SizePercent = 75, .MaxRangeCount = 32}, + {.SizePercent = 70, .MaxRangeCount = 48}, + {.SizePercent = 4, .MaxRangeCount = 82}, + {.SizePercent = 0, .MaxRangeCount = 96}}; + + BlockRangeDescriptor MergeBlockRanges(std::span Ranges); + std::optional> MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, + const BlockRangeDescriptor& Range); + const BlockRangeLimit* GetBlockRangeLimitForRange(std::span Limits, + uint64_t TotalBlockSize, + std::span Ranges); + std::vector CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, + std::span BlockRanges); + uint64_t CalculateNextGap(std::span BlockRanges); + std::optional> CalculateBlockRanges(uint32_t BlockIndex, + const 
ChunkBlockDescription& BlockDescription, + std::span BlockChunkIndexNeeded, + EPartialBlockDownloadMode PartialBlockDownloadMode, + const uint64_t ChunkStartOffsetInBlock, + const uint64_t TotalBlockSize, + uint64_t& OutTotalWantedChunksSize); + + OperationLogOutput& m_LogOutput; + const std::span m_BlockDescriptions; + const Options m_Options; +}; + +#if ZEN_WITH_TESTS + +class CbWriter; void chunkblock_forcelink(); +#endif // ZEN_WITH_TESTS } // namespace zen -- cgit v1.2.3 From 73f3eb4feedf3bca0ddc832a89a05f09813c6858 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 23 Feb 2026 11:08:24 +0100 Subject: implemented base64 decoding (#777) Co-authored-by: Stefan Boberg --- src/zencore/base64.cpp | 192 ++++++++++++++++++++++++++++++++++- src/zencore/include/zencore/base64.h | 4 + 2 files changed, 194 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp index 1f56ee6c3..fdf5f2d66 100644 --- a/src/zencore/base64.cpp +++ b/src/zencore/base64.cpp @@ -1,6 +1,10 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
#include +#include +#include + +#include namespace zen { @@ -11,7 +15,6 @@ static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; /** The table used to convert an ascii character into a 6 bit value */ -#if 0 static const uint8_t DecodingAlphabet[256] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f @@ -30,7 +33,6 @@ static const uint8_t DecodingAlphabet[256] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff }; -#endif // 0 template uint32_t @@ -104,4 +106,190 @@ Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest) template uint32_t Base64::Encode(const uint8_t* Source, uint32_t Length, char* Dest); template uint32_t Base64::Encode(const uint8_t* Source, uint32_t Length, wchar_t* Dest); +template +bool +Base64::Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength) +{ + // Length must be a multiple of 4 + if (Length % 4 != 0) + { + OutLength = 0; + return false; + } + + uint8_t* DecodedBytes = Dest; + + // Process 4 encoded characters at a time, producing 3 decoded bytes + while (Length > 0) + { + // Count padding characters at the end + uint32_t PadCount = 0; + if (Source[3] == '=') + { + PadCount++; + if (Source[2] == '=') + { + PadCount++; + } + } + + // Look up each character in the decoding table + uint8_t A = DecodingAlphabet[static_cast(Source[0])]; + uint8_t B = DecodingAlphabet[static_cast(Source[1])]; + uint8_t C = (PadCount >= 2) ? 0 : DecodingAlphabet[static_cast(Source[2])]; + uint8_t D = (PadCount >= 1) ? 
0 : DecodingAlphabet[static_cast(Source[3])]; + + // Check for invalid characters (0xFF means not in the base64 alphabet) + if (A == 0xFF || B == 0xFF || C == 0xFF || D == 0xFF) + { + OutLength = 0; + return false; + } + + // Reconstruct the 24-bit value from 4 6-bit chunks + uint32_t ByteTriplet = (A << 18) | (B << 12) | (C << 6) | D; + + // Extract the 3 bytes + *DecodedBytes++ = static_cast(ByteTriplet >> 16); + if (PadCount < 2) + { + *DecodedBytes++ = static_cast((ByteTriplet >> 8) & 0xFF); + } + if (PadCount < 1) + { + *DecodedBytes++ = static_cast(ByteTriplet & 0xFF); + } + + Source += 4; + Length -= 4; + } + + OutLength = uint32_t(DecodedBytes - Dest); + return true; +} + +template bool Base64::Decode(const char* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); +template bool Base64::Decode(const wchar_t* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... 
+// + +#if ZEN_WITH_TESTS + +using namespace std::string_literals; + +TEST_CASE("Base64") +{ + auto EncodeString = [](std::string_view Input) -> std::string { + std::string Result; + Result.resize(Base64::GetEncodedDataSize(uint32_t(Input.size()))); + Base64::Encode(reinterpret_cast(Input.data()), uint32_t(Input.size()), Result.data()); + return Result; + }; + + auto DecodeString = [](std::string_view Input) -> std::string { + std::string Result; + Result.resize(Base64::GetMaxDecodedDataSize(uint32_t(Input.size()))); + uint32_t DecodedLength = 0; + bool Success = Base64::Decode(Input.data(), uint32_t(Input.size()), reinterpret_cast(Result.data()), DecodedLength); + CHECK(Success); + Result.resize(DecodedLength); + return Result; + }; + + SUBCASE("Encode") + { + CHECK(EncodeString("") == ""s); + CHECK(EncodeString("f") == "Zg=="s); + CHECK(EncodeString("fo") == "Zm8="s); + CHECK(EncodeString("foo") == "Zm9v"s); + CHECK(EncodeString("foob") == "Zm9vYg=="s); + CHECK(EncodeString("fooba") == "Zm9vYmE="s); + CHECK(EncodeString("foobar") == "Zm9vYmFy"s); + } + + SUBCASE("Decode") + { + CHECK(DecodeString("") == ""s); + CHECK(DecodeString("Zg==") == "f"s); + CHECK(DecodeString("Zm8=") == "fo"s); + CHECK(DecodeString("Zm9v") == "foo"s); + CHECK(DecodeString("Zm9vYg==") == "foob"s); + CHECK(DecodeString("Zm9vYmE=") == "fooba"s); + CHECK(DecodeString("Zm9vYmFy") == "foobar"s); + } + + SUBCASE("RoundTrip") + { + auto RoundTrip = [&](const std::string& Input) { + std::string Encoded = EncodeString(Input); + std::string Decoded = DecodeString(Encoded); + CHECK(Decoded == Input); + }; + + RoundTrip("Hello, World!"); + RoundTrip("Base64 encoding test with various lengths"); + RoundTrip("A"); + RoundTrip("AB"); + RoundTrip("ABC"); + RoundTrip("ABCD"); + RoundTrip("\x00\x01\x02\xff\xfe\xfd"s); + } + + SUBCASE("BinaryRoundTrip") + { + // Test with all byte values 0-255 + uint8_t AllBytes[256]; + for (int i = 0; i < 256; ++i) + { + AllBytes[i] = static_cast(i); + } + + char 
Encoded[Base64::GetEncodedDataSize(256) + 1]; + Base64::Encode(AllBytes, 256, Encoded); + + uint8_t Decoded[256]; + uint32_t DecodedLength = 0; + bool Success = Base64::Decode(Encoded, uint32_t(strlen(Encoded)), Decoded, DecodedLength); + CHECK(Success); + CHECK(DecodedLength == 256); + CHECK(memcmp(AllBytes, Decoded, 256) == 0); + } + + SUBCASE("DecodeInvalidInput") + { + uint8_t Dest[64]; + uint32_t OutLength = 0; + + // Length not a multiple of 4 + CHECK_FALSE(Base64::Decode("abc", 3u, Dest, OutLength)); + + // Invalid character + CHECK_FALSE(Base64::Decode("ab!d", 4u, Dest, OutLength)); + } + + SUBCASE("EncodedDataSize") + { + CHECK(Base64::GetEncodedDataSize(0) == 0); + CHECK(Base64::GetEncodedDataSize(1) == 4); + CHECK(Base64::GetEncodedDataSize(2) == 4); + CHECK(Base64::GetEncodedDataSize(3) == 4); + CHECK(Base64::GetEncodedDataSize(4) == 8); + CHECK(Base64::GetEncodedDataSize(5) == 8); + CHECK(Base64::GetEncodedDataSize(6) == 8); + } + + SUBCASE("MaxDecodedDataSize") + { + CHECK(Base64::GetMaxDecodedDataSize(0) == 0); + CHECK(Base64::GetMaxDecodedDataSize(4) == 3); + CHECK(Base64::GetMaxDecodedDataSize(8) == 6); + CHECK(Base64::GetMaxDecodedDataSize(12) == 9); + } +} + +#endif + } // namespace zen diff --git a/src/zencore/include/zencore/base64.h b/src/zencore/include/zencore/base64.h index 4d78b085f..08d9f3043 100644 --- a/src/zencore/include/zencore/base64.h +++ b/src/zencore/include/zencore/base64.h @@ -11,7 +11,11 @@ struct Base64 template static uint32_t Encode(const uint8_t* Source, uint32_t Length, CharType* Dest); + template + static bool Decode(const CharType* Source, uint32_t Length, uint8_t* Dest, uint32_t& OutLength); + static inline constexpr int32_t GetEncodedDataSize(uint32_t Size) { return ((Size + 2) / 3) * 4; } + static inline constexpr int32_t GetMaxDecodedDataSize(uint32_t Length) { return (Length / 4) * 3; } }; } // namespace zen -- cgit v1.2.3 From 01445315564ab527566ec200e0182d8968a80d6f Mon Sep 17 00:00:00 2001 From: Stefan Boberg 
Date: Mon, 23 Feb 2026 11:09:11 +0100 Subject: changed command names and descriptions to use class members instead of string literals in zen.cpp (#776) --- src/zen/cmds/admin_cmd.h | 40 +++++++++++++++---- src/zen/cmds/bench_cmd.h | 5 ++- src/zen/cmds/cache_cmd.h | 20 ++++++++-- src/zen/cmds/copy_cmd.h | 5 ++- src/zen/cmds/dedup_cmd.h | 5 ++- src/zen/cmds/info_cmd.h | 5 ++- src/zen/cmds/print_cmd.h | 10 ++++- src/zen/cmds/projectstore_cmd.h | 58 ++++++++++++++++++++------- src/zen/cmds/rpcreplay_cmd.h | 15 +++++-- src/zen/cmds/run_cmd.h | 5 ++- src/zen/cmds/serve_cmd.h | 5 ++- src/zen/cmds/status_cmd.h | 5 ++- src/zen/cmds/top_cmd.h | 10 ++++- src/zen/cmds/trace_cmd.h | 7 ++-- src/zen/cmds/up_cmd.h | 15 +++++-- src/zen/cmds/vfs_cmd.h | 5 ++- src/zen/zen.cpp | 86 ++++++++++++++++++++--------------------- 17 files changed, 211 insertions(+), 90 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/admin_cmd.h b/src/zen/cmds/admin_cmd.h index 87ef8091b..83bcf8893 100644 --- a/src/zen/cmds/admin_cmd.h +++ b/src/zen/cmds/admin_cmd.h @@ -13,6 +13,9 @@ namespace zen { class ScrubCommand : public StorageCommand { public: + static constexpr char Name[] = "scrub"; + static constexpr char Description[] = "Scrub zen storage (verify data integrity)"; + ScrubCommand(); ~ScrubCommand(); @@ -20,7 +23,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"scrub", "Scrub zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_DryRun = false; bool m_NoGc = false; @@ -33,6 +36,9 @@ private: class GcCommand : public StorageCommand { public: + static constexpr char Name[] = "gc"; + static constexpr char Description[] = "Garbage collect zen storage"; + GcCommand(); ~GcCommand(); @@ -40,7 +46,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc", "Garbage collect zen storage"}; + cxxopts::Options m_Options{Name, 
Description}; std::string m_HostName; bool m_SmallObjects{false}; bool m_SkipCid{false}; @@ -62,6 +68,9 @@ private: class GcStatusCommand : public StorageCommand { public: + static constexpr char Name[] = "gc-status"; + static constexpr char Description[] = "Garbage collect zen storage status check"; + GcStatusCommand(); ~GcStatusCommand(); @@ -69,7 +78,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc-status", "Garbage collect zen storage status check"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Details = false; }; @@ -77,6 +86,9 @@ private: class GcStopCommand : public StorageCommand { public: + static constexpr char Name[] = "gc-stop"; + static constexpr char Description[] = "Request cancel of running garbage collection in zen storage"; + GcStopCommand(); ~GcStopCommand(); @@ -84,7 +96,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"gc-stop", "Request cancel of running garbage collection in zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; @@ -93,6 +105,9 @@ private: class JobCommand : public ZenCmdBase { public: + static constexpr char Name[] = "jobs"; + static constexpr char Description[] = "Show/cancel zen background jobs"; + JobCommand(); ~JobCommand(); @@ -100,7 +115,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"jobs", "Show/cancel zen background jobs"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::uint64_t m_JobId = 0; bool m_Cancel = 0; @@ -111,6 +126,9 @@ private: class LoggingCommand : public ZenCmdBase { public: + static constexpr char Name[] = "logs"; + static constexpr char Description[] = "Show/control zen logging"; + LoggingCommand(); ~LoggingCommand(); @@ -118,7 +136,7 @@ public: virtual cxxopts::Options& Options() override { 
return m_Options; } private: - cxxopts::Options m_Options{"logs", "Show/control zen logging"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_CacheWriteLog; std::string m_CacheAccessLog; @@ -133,6 +151,9 @@ private: class FlushCommand : public StorageCommand { public: + static constexpr char Name[] = "flush"; + static constexpr char Description[] = "Flush storage"; + FlushCommand(); ~FlushCommand(); @@ -140,7 +161,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"flush", "Flush zen storage"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; @@ -149,6 +170,9 @@ private: class CopyStateCommand : public StorageCommand { public: + static constexpr char Name[] = "copy-state"; + static constexpr char Description[] = "Copy zen server disk state"; + CopyStateCommand(); ~CopyStateCommand(); @@ -156,7 +180,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"copy-state", "Copy zen server disk state"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_DataPath; std::filesystem::path m_TargetPath; bool m_SkipLogs = false; diff --git a/src/zen/cmds/bench_cmd.h b/src/zen/cmds/bench_cmd.h index ed123be75..7fbf85340 100644 --- a/src/zen/cmds/bench_cmd.h +++ b/src/zen/cmds/bench_cmd.h @@ -9,6 +9,9 @@ namespace zen { class BenchCommand : public ZenCmdBase { public: + static constexpr char Name[] = "bench"; + static constexpr char Description[] = "Utility command for benchmarking"; + BenchCommand(); ~BenchCommand(); @@ -17,7 +20,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"bench", "Benchmarking utility command"}; + cxxopts::Options m_Options{Name, Description}; bool m_PurgeStandbyLists = false; bool m_SingleProcess = false; }; diff --git a/src/zen/cmds/cache_cmd.h 
b/src/zen/cmds/cache_cmd.h index 4dc05bbdc..4f5b90f4d 100644 --- a/src/zen/cmds/cache_cmd.h +++ b/src/zen/cmds/cache_cmd.h @@ -9,6 +9,9 @@ namespace zen { class DropCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "drop"; + static constexpr char Description[] = "Drop cache namespace or bucket"; + DropCommand(); ~DropCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"drop", "Drop cache namespace or bucket"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_NamespaceName; std::string m_BucketName; @@ -25,13 +28,16 @@ private: class CacheInfoCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-info"; + static constexpr char Description[] = "Info on cache, namespace or bucket"; + CacheInfoCommand(); ~CacheInfoCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-info", "Info on cache, namespace or bucket"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_NamespaceName; std::string m_SizeInfoBucketNames; @@ -42,26 +48,32 @@ private: class CacheStatsCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-stats"; + static constexpr char Description[] = "Stats on cache"; + CacheStatsCommand(); ~CacheStatsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-stats", "Stats info on cache"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class CacheDetailsCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "cache-details"; + static constexpr char Description[] = 
"Details on cache"; + CacheDetailsCommand(); ~CacheDetailsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"cache-details", "Detailed info on cache"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_CSV = false; bool m_Details = false; diff --git a/src/zen/cmds/copy_cmd.h b/src/zen/cmds/copy_cmd.h index e1a5dcb82..757a8e691 100644 --- a/src/zen/cmds/copy_cmd.h +++ b/src/zen/cmds/copy_cmd.h @@ -11,6 +11,9 @@ namespace zen { class CopyCommand : public ZenCmdBase { public: + static constexpr char Name[] = "copy"; + static constexpr char Description[] = "Copy file(s)"; + CopyCommand(); ~CopyCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"copy", "Copy files efficiently"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_CopySource; std::filesystem::path m_CopyTarget; bool m_NoClone = false; diff --git a/src/zen/cmds/dedup_cmd.h b/src/zen/cmds/dedup_cmd.h index 5b8387dd2..835b35e92 100644 --- a/src/zen/cmds/dedup_cmd.h +++ b/src/zen/cmds/dedup_cmd.h @@ -11,6 +11,9 @@ namespace zen { class DedupCommand : public ZenCmdBase { public: + static constexpr char Name[] = "dedup"; + static constexpr char Description[] = "Dedup files"; + DedupCommand(); ~DedupCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"dedup", "Deduplicate files"}; + cxxopts::Options m_Options{Name, Description}; std::vector m_Positional; std::filesystem::path m_DedupSource; std::filesystem::path m_DedupTarget; diff --git a/src/zen/cmds/info_cmd.h b/src/zen/cmds/info_cmd.h index 231565bfd..dc108b8a2 100644 --- a/src/zen/cmds/info_cmd.h +++ b/src/zen/cmds/info_cmd.h @@ -9,6 
+9,9 @@ namespace zen { class InfoCommand : public ZenCmdBase { public: + static constexpr char Name[] = "info"; + static constexpr char Description[] = "Show high level Zen server information"; + InfoCommand(); ~InfoCommand(); @@ -17,7 +20,7 @@ public: // virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"info", "Show high level zen store information"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; diff --git a/src/zen/cmds/print_cmd.h b/src/zen/cmds/print_cmd.h index 6c1529b7c..f4a97e218 100644 --- a/src/zen/cmds/print_cmd.h +++ b/src/zen/cmds/print_cmd.h @@ -11,6 +11,9 @@ namespace zen { class PrintCommand : public ZenCmdBase { public: + static constexpr char Name[] = "print"; + static constexpr char Description[] = "Print compact binary object"; + PrintCommand(); ~PrintCommand(); @@ -19,7 +22,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"print", "Print compact binary object"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_Filename; bool m_ShowCbObjectTypeInfo = false; }; @@ -29,6 +32,9 @@ private: class PrintPackageCommand : public ZenCmdBase { public: + static constexpr char Name[] = "printpackage"; + static constexpr char Description[] = "Print compact binary package"; + PrintPackageCommand(); ~PrintPackageCommand(); @@ -37,7 +43,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"printpkg", "Print compact binary package"}; + cxxopts::Options m_Options{Name, Description}; std::filesystem::path m_Filename; bool m_ShowCbObjectTypeInfo = false; }; diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h index 56ef858f5..e415b41b7 100644 --- a/src/zen/cmds/projectstore_cmd.h +++ b/src/zen/cmds/projectstore_cmd.h @@ -16,6 +16,9 @@ 
class ProjectStoreCommand : public ZenCmdBase class DropProjectCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-drop"; + static constexpr char Description[] = "Drop project or project oplog"; + DropProjectCommand(); ~DropProjectCommand(); @@ -23,7 +26,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-drop", "Drop project or project oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -33,13 +36,16 @@ private: class ProjectInfoCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-info"; + static constexpr char Description[] = "Info on project or project oplog"; + ProjectInfoCommand(); ~ProjectInfoCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-info", "Info on project or project oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -48,6 +54,9 @@ private: class CreateProjectCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-create"; + static constexpr char Description[] = "Create a project"; + CreateProjectCommand(); ~CreateProjectCommand(); @@ -55,7 +64,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-create", "Create project, the project must not already exist."}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectId; std::string m_RootDir; @@ -68,6 +77,9 @@ private: class CreateOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-create"; + static constexpr char Description[] = "Create a project 
oplog"; + CreateOplogCommand(); ~CreateOplogCommand(); @@ -75,7 +87,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-create", "Create oplog in an existing project, the oplog must not already exist."}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectId; std::string m_OplogId; @@ -86,6 +98,9 @@ private: class ExportOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-export"; + static constexpr char Description[] = "Export project store oplog"; + ExportOplogCommand(); ~ExportOplogCommand(); @@ -93,8 +108,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-export", - "Export project store oplog to cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -145,6 +159,9 @@ private: class ImportOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-import"; + static constexpr char Description[] = "Import project store oplog"; + ImportOplogCommand(); ~ImportOplogCommand(); @@ -152,8 +169,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-import", - "Import project store oplog from cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -198,14 +214,16 @@ private: class SnapshotOplogCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-snapshot"; + static constexpr char Description[] = "Snapshot project store oplog"; + SnapshotOplogCommand(); ~SnapshotOplogCommand(); - virtual void Run(const ZenCliOptions& GlobalOptions, 
int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-snapshot", "Snapshot external file references in project store oplog into zen"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -214,26 +232,32 @@ private: class ProjectStatsCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-stats"; + static constexpr char Description[] = "Stats on project store"; + ProjectStatsCommand(); ~ProjectStatsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-stats", "Stats info on project store"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class ProjectOpDetailsCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "project-op-details"; + static constexpr char Description[] = "Detail info on ops inside a project store oplog"; + ProjectOpDetailsCommand(); ~ProjectOpDetailsCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"project-op-details", "Detail info on ops inside a project store oplog"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Details = false; bool m_OpDetails = false; @@ -247,13 +271,16 @@ private: class OplogMirrorCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-mirror"; + static constexpr char Description[] = "Mirror project store oplog to file system"; + OplogMirrorCommand(); ~OplogMirrorCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { 
return m_Options; } private: - cxxopts::Options m_Options{"oplog-mirror", "Mirror oplog to file system"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; @@ -268,13 +295,16 @@ private: class OplogValidateCommand : public ProjectStoreCommand { public: + static constexpr char Name[] = "oplog-validate"; + static constexpr char Description[] = "Validate oplog for missing references"; + OplogValidateCommand(); ~OplogValidateCommand(); virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"oplog-validate", "Validate oplog for missing references"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; diff --git a/src/zen/cmds/rpcreplay_cmd.h b/src/zen/cmds/rpcreplay_cmd.h index a6363b614..332a3126c 100644 --- a/src/zen/cmds/rpcreplay_cmd.h +++ b/src/zen/cmds/rpcreplay_cmd.h @@ -9,6 +9,9 @@ namespace zen { class RpcStartRecordingCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-start"; + static constexpr char Description[] = "Starts recording of cache rpc requests on a host"; + RpcStartRecordingCommand(); ~RpcStartRecordingCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-start", "Starts recording of cache rpc requests on a host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_RecordingPath; }; @@ -24,6 +27,9 @@ private: class RpcStopRecordingCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-stop"; + static constexpr char Description[] = "Stops recording of cache rpc requests on a host"; + RpcStopRecordingCommand(); ~RpcStopRecordingCommand(); @@ -31,13 +37,16 @@ public: 
virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-stop", "Stops recording of cache rpc requests on a host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; }; class RpcReplayCommand : public CacheStoreCommand { public: + static constexpr char Name[] = "rpc-record-replay"; + static constexpr char Description[] = "Replays a previously recorded session of rpc requests"; + RpcReplayCommand(); ~RpcReplayCommand(); @@ -45,7 +54,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"rpc-record-replay", "Replays a previously recorded session of cache rpc requests to a target host"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_RecordingPath; bool m_OnHost = false; diff --git a/src/zen/cmds/run_cmd.h b/src/zen/cmds/run_cmd.h index 570a2e63a..300c08c5b 100644 --- a/src/zen/cmds/run_cmd.h +++ b/src/zen/cmds/run_cmd.h @@ -9,6 +9,9 @@ namespace zen { class RunCommand : public ZenCmdBase { public: + static constexpr char Name[] = "run"; + static constexpr char Description[] = "Run command with special options"; + RunCommand(); ~RunCommand(); @@ -17,7 +20,7 @@ public: virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: - cxxopts::Options m_Options{"run", "Run executable"}; + cxxopts::Options m_Options{Name, Description}; int m_RunCount = 0; int m_RunTime = -1; std::string m_BaseDirectory; diff --git a/src/zen/cmds/serve_cmd.h b/src/zen/cmds/serve_cmd.h index ac74981f2..22f430948 100644 --- a/src/zen/cmds/serve_cmd.h +++ b/src/zen/cmds/serve_cmd.h @@ -11,6 +11,9 @@ namespace zen { class ServeCommand : public ZenCmdBase { public: + static constexpr char Name[] = "serve"; + static constexpr char Description[] = "Serve files from a directory"; + ServeCommand(); ~ServeCommand(); @@ -18,7 +21,7 @@ public: virtual cxxopts::Options& Options() 
override { return m_Options; } private: - cxxopts::Options m_Options{"serve", "Serve files from a tree"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; std::string m_ProjectName; std::string m_OplogName; diff --git a/src/zen/cmds/status_cmd.h b/src/zen/cmds/status_cmd.h index dc103a196..df5df3066 100644 --- a/src/zen/cmds/status_cmd.h +++ b/src/zen/cmds/status_cmd.h @@ -11,6 +11,9 @@ namespace zen { class StatusCommand : public ZenCmdBase { public: + static constexpr char Name[] = "status"; + static constexpr char Description[] = "Show zen status"; + StatusCommand(); ~StatusCommand(); @@ -20,7 +23,7 @@ public: private: int GetLockFileEffectivePort() const; - cxxopts::Options m_Options{"status", "Show zen status"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; std::filesystem::path m_DataDir; }; diff --git a/src/zen/cmds/top_cmd.h b/src/zen/cmds/top_cmd.h index 74167ecfd..aeb196558 100644 --- a/src/zen/cmds/top_cmd.h +++ b/src/zen/cmds/top_cmd.h @@ -9,6 +9,9 @@ namespace zen { class TopCommand : public ZenCmdBase { public: + static constexpr char Name[] = "top"; + static constexpr char Description[] = "Monitor zen server activity"; + TopCommand(); ~TopCommand(); @@ -16,12 +19,15 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"top", "Show dev UI"}; + cxxopts::Options m_Options{Name, Description}; }; class PsCommand : public ZenCmdBase { public: + static constexpr char Name[] = "ps"; + static constexpr char Description[] = "Enumerate running zen server instances"; + PsCommand(); ~PsCommand(); @@ -29,7 +35,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"ps", "Enumerate running Zen server instances"}; + cxxopts::Options m_Options{Name, Description}; }; } // namespace zen diff --git a/src/zen/cmds/trace_cmd.h b/src/zen/cmds/trace_cmd.h index a6c9742b7..6eb0ba22b 100644 --- 
a/src/zen/cmds/trace_cmd.h +++ b/src/zen/cmds/trace_cmd.h @@ -6,11 +6,12 @@ namespace zen { -/** Scrub storage - */ class TraceCommand : public ZenCmdBase { public: + static constexpr char Name[] = "trace"; + static constexpr char Description[] = "Control zen realtime tracing"; + TraceCommand(); ~TraceCommand(); @@ -18,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"trace", "Control zen realtime tracing"}; + cxxopts::Options m_Options{Name, Description}; std::string m_HostName; bool m_Stop = false; std::string m_TraceHost; diff --git a/src/zen/cmds/up_cmd.h b/src/zen/cmds/up_cmd.h index 2e822d5fc..270db7f88 100644 --- a/src/zen/cmds/up_cmd.h +++ b/src/zen/cmds/up_cmd.h @@ -11,6 +11,9 @@ namespace zen { class UpCommand : public ZenCmdBase { public: + static constexpr char Name[] = "up"; + static constexpr char Description[] = "Bring zen server up"; + UpCommand(); ~UpCommand(); @@ -18,7 +21,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"up", "Bring up zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; bool m_ShowConsole = false; bool m_ShowLog = false; @@ -28,6 +31,9 @@ private: class AttachCommand : public ZenCmdBase { public: + static constexpr char Name[] = "attach"; + static constexpr char Description[] = "Add a sponsor process to a running zen service"; + AttachCommand(); ~AttachCommand(); @@ -35,7 +41,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"attach", "Add a sponsor process to a running zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; int m_OwnerPid = 0; std::filesystem::path m_DataDir; @@ -44,6 +50,9 @@ private: class DownCommand : public ZenCmdBase { public: + static constexpr char Name[] = "down"; + static constexpr char Description[] = "Bring zen server down"; + 
DownCommand(); ~DownCommand(); @@ -51,7 +60,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"down", "Bring down zen service"}; + cxxopts::Options m_Options{Name, Description}; uint16_t m_Port = 0; bool m_ForceTerminate = false; std::filesystem::path m_ProgramBaseDir; diff --git a/src/zen/cmds/vfs_cmd.h b/src/zen/cmds/vfs_cmd.h index 5deaa02fa..9009c774b 100644 --- a/src/zen/cmds/vfs_cmd.h +++ b/src/zen/cmds/vfs_cmd.h @@ -9,6 +9,9 @@ namespace zen { class VfsCommand : public StorageCommand { public: + static constexpr char Name[] = "vfs"; + static constexpr char Description[] = "Manage virtual file system"; + VfsCommand(); ~VfsCommand(); @@ -16,7 +19,7 @@ public: virtual cxxopts::Options& Options() override { return m_Options; } private: - cxxopts::Options m_Options{"vfs", "Manage virtual file system"}; + cxxopts::Options m_Options{Name, Description}; std::string m_Verb; std::string m_HostName; diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 018f77738..bdc2b4003 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -379,56 +379,56 @@ main(int argc, char** argv) const char* CmdSummary; } Commands[] = { // clang-format off - {"attach", &AttachCmd, "Add a sponsor process to a running zen service"}, - {"bench", &BenchCmd, "Utility command for benchmarking"}, - {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description}, - {"cache-details", &CacheDetailsCmd, "Details on cache"}, - {"cache-info", &CacheInfoCmd, "Info on cache, namespace or bucket"}, + {AttachCommand::Name, &AttachCmd, AttachCommand::Description}, + {BenchCommand::Name, &BenchCmd, BenchCommand::Description}, + {BuildsCommand::Name, &BuildsCmd, BuildsCommand::Description}, + {CacheDetailsCommand::Name, &CacheDetailsCmd, CacheDetailsCommand::Description}, + {CacheInfoCommand::Name, &CacheInfoCmd, CacheInfoCommand::Description}, {CacheGetCommand::Name, &CacheGetCmd, CacheGetCommand::Description}, {CacheGenerateCommand::Name, 
&CacheGenerateCmd, CacheGenerateCommand::Description}, - {"cache-stats", &CacheStatsCmd, "Stats on cache"}, - {"copy", &CopyCmd, "Copy file(s)"}, - {"copy-state", &CopyStateCmd, "Copy zen server disk state"}, - {"dedup", &DedupCmd, "Dedup files"}, - {"down", &DownCmd, "Bring zen server down"}, - {"drop", &DropCmd, "Drop cache namespace or bucket"}, + {CacheStatsCommand::Name, &CacheStatsCmd, CacheStatsCommand::Description}, + {CopyCommand::Name, &CopyCmd, CopyCommand::Description}, + {CopyStateCommand::Name, &CopyStateCmd, CopyStateCommand::Description}, + {DedupCommand::Name, &DedupCmd, DedupCommand::Description}, + {DownCommand::Name, &DownCmd, DownCommand::Description}, + {DropCommand::Name, &DropCmd, DropCommand::Description}, #if ZEN_WITH_COMPUTE_SERVICES {ExecCommand::Name, &ExecCmd, ExecCommand::Description}, #endif - {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"}, - {"gc-stop", &GcStopCmd, "Request cancel of running garbage collection in zen storage"}, - {"gc", &GcCmd, "Garbage collect zen storage"}, - {"info", &InfoCmd, "Show high level Zen server information"}, - {"jobs", &JobCmd, "Show/cancel zen background jobs"}, - {"logs", &LoggingCmd, "Show/control zen logging"}, - {"oplog-create", &CreateOplogCmd, "Create a project oplog"}, - {"oplog-export", &ExportOplogCmd, "Export project store oplog"}, - {"oplog-import", &ImportOplogCmd, "Import project store oplog"}, - {"oplog-mirror", &OplogMirrorCmd, "Mirror project store oplog to file system"}, - {"oplog-snapshot", &SnapshotOplogCmd, "Snapshot project store oplog"}, + {GcStatusCommand::Name, &GcStatusCmd, GcStatusCommand::Description}, + {GcStopCommand::Name, &GcStopCmd, GcStopCommand::Description}, + {GcCommand::Name, &GcCmd, GcCommand::Description}, + {InfoCommand::Name, &InfoCmd, InfoCommand::Description}, + {JobCommand::Name, &JobCmd, JobCommand::Description}, + {LoggingCommand::Name, &LoggingCmd, LoggingCommand::Description}, + {CreateOplogCommand::Name, &CreateOplogCmd, 
CreateOplogCommand::Description}, + {ExportOplogCommand::Name, &ExportOplogCmd, ExportOplogCommand::Description}, + {ImportOplogCommand::Name, &ImportOplogCmd, ImportOplogCommand::Description}, + {OplogMirrorCommand::Name, &OplogMirrorCmd, OplogMirrorCommand::Description}, + {SnapshotOplogCommand::Name, &SnapshotOplogCmd, SnapshotOplogCommand::Description}, {OplogDownloadCommand::Name, &OplogDownload, OplogDownloadCommand::Description}, - {"oplog-validate", &OplogValidateCmd, "Validate oplog for missing references"}, - {"print", &PrintCmd, "Print compact binary object"}, - {"printpackage", &PrintPkgCmd, "Print compact binary package"}, - {"project-create", &CreateProjectCmd, "Create a project"}, - {"project-op-details", &ProjectOpDetailsCmd, "Detail info on ops inside a project store oplog"}, - {"project-drop", &ProjectDropCmd, "Drop project or project oplog"}, - {"project-info", &ProjectInfoCmd, "Info on project or project oplog"}, - {"project-stats", &ProjectStatsCmd, "Stats on project store"}, - {"ps", &PsCmd, "Enumerate running zen server instances"}, - {"rpc-record-replay", &RpcReplayCmd, "Replays a previously recorded session of rpc requests"}, - {"rpc-record-start", &RpcStartRecordingCmd, "Starts recording of cache rpc requests on a host"}, - {"rpc-record-stop", &RpcStopRecordingCmd, "Stops recording of cache rpc requests on a host"}, - {"run", &RunCmd, "Run command with special options"}, - {"scrub", &ScrubCmd, "Scrub zen storage (verify data integrity)"}, - {"serve", &ServeCmd, "Serve files from a directory"}, - {"status", &StatusCmd, "Show zen status"}, - {"top", &TopCmd, "Monitor zen server activity"}, - {"trace", &TraceCmd, "Control zen realtime tracing"}, - {"up", &UpCmd, "Bring zen server up"}, + {OplogValidateCommand::Name, &OplogValidateCmd, OplogValidateCommand::Description}, + {PrintCommand::Name, &PrintCmd, PrintCommand::Description}, + {PrintPackageCommand::Name, &PrintPkgCmd, PrintPackageCommand::Description}, + {CreateProjectCommand::Name, 
&CreateProjectCmd, CreateProjectCommand::Description}, + {ProjectOpDetailsCommand::Name, &ProjectOpDetailsCmd, ProjectOpDetailsCommand::Description}, + {DropProjectCommand::Name, &ProjectDropCmd, DropProjectCommand::Description}, + {ProjectInfoCommand::Name, &ProjectInfoCmd, ProjectInfoCommand::Description}, + {ProjectStatsCommand::Name, &ProjectStatsCmd, ProjectStatsCommand::Description}, + {PsCommand::Name, &PsCmd, PsCommand::Description}, + {RpcReplayCommand::Name, &RpcReplayCmd, RpcReplayCommand::Description}, + {RpcStartRecordingCommand::Name, &RpcStartRecordingCmd, RpcStartRecordingCommand::Description}, + {RpcStopRecordingCommand::Name, &RpcStopRecordingCmd, RpcStopRecordingCommand::Description}, + {RunCommand::Name, &RunCmd, RunCommand::Description}, + {ScrubCommand::Name, &ScrubCmd, ScrubCommand::Description}, + {ServeCommand::Name, &ServeCmd, ServeCommand::Description}, + {StatusCommand::Name, &StatusCmd, StatusCommand::Description}, + {TopCommand::Name, &TopCmd, TopCommand::Description}, + {TraceCommand::Name, &TraceCmd, TraceCommand::Description}, + {UpCommand::Name, &UpCmd, UpCommand::Description}, {VersionCommand::Name, &VersionCmd, VersionCommand::Description}, - {"vfs", &VfsCmd, "Manage virtual file system"}, - {"flush", &FlushCmd, "Flush storage"}, + {VfsCommand::Name, &VfsCmd, VfsCommand::Description}, + {FlushCommand::Name, &FlushCmd, FlushCommand::Description}, {WipeCommand::Name, &WipeCmd, WipeCommand::Description}, {WorkspaceCommand::Name, &WorkspaceCmd, WorkspaceCommand::Description}, {WorkspaceShareCommand::Name, &WorkspaceShareCmd, WorkspaceShareCommand::Description}, -- cgit v1.2.3 From 9aac0fd369b87e965fb34b5168646387de7ea1cd Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 23 Feb 2026 11:19:52 +0100 Subject: implement yaml generation (#774) this implements a yaml generation strategy similar to the JSON generation where we just build a string instead of building a ryml tree. 
This also removes the dependency on ryml for reduced binary/build times. --- src/zencore/compactbinaryyaml.cpp | 427 +++++++++++++++++++++++++++----------- src/zencore/xmake.lua | 2 - 2 files changed, 302 insertions(+), 127 deletions(-) (limited to 'src') diff --git a/src/zencore/compactbinaryyaml.cpp b/src/zencore/compactbinaryyaml.cpp index 5122e952a..b308af418 100644 --- a/src/zencore/compactbinaryyaml.cpp +++ b/src/zencore/compactbinaryyaml.cpp @@ -14,11 +14,6 @@ #include #include -ZEN_THIRD_PARTY_INCLUDES_START -#include -#include -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen { ////////////////////////////////////////////////////////////////////////// @@ -26,193 +21,349 @@ namespace zen { class CbYamlWriter { public: - explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_StrBuilder(InBuilder) { m_NodeStack.push_back(m_Tree.rootref()); } + explicit CbYamlWriter(StringBuilderBase& InBuilder) : m_Builder(InBuilder) {} void WriteField(CbFieldView Field) { - ryml::NodeRef Node; + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); - if (m_IsFirst) + switch (Type) { - Node = Top(); + case CbFieldType::Object: + case CbFieldType::UniformObject: + WriteMapEntries(Field, 0); + break; + case CbFieldType::Array: + case CbFieldType::UniformArray: + WriteSeqEntries(Field, 0); + break; + default: + WriteScalarValue(Field); + m_Builder << '\n'; + break; + } + } + + void WriteMapEntry(CbFieldView Field, int32_t Indent) + { + WriteIndent(Indent); + WriteMapEntryContent(Field, Indent); + } + + void WriteSeqEntry(CbFieldView Field, int32_t Indent) + { + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); - m_IsFirst = false; + if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject) + { + bool First = true; + for (CbFieldView MapChild : Field) + { + if (First) + { + WriteIndent(Indent); + m_Builder << "- "; + First = false; + } + else + { + WriteIndent(Indent + 1); + } + WriteMapEntryContent(MapChild, Indent + 
1); + } + } + else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray) + { + WriteIndent(Indent); + m_Builder << "-\n"; + WriteSeqEntries(Field, Indent + 1); } else { - Node = Top().append_child(); + WriteIndent(Indent); + m_Builder << "- "; + WriteScalarValue(Field); + m_Builder << '\n'; } + } - if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) +private: + void WriteMapEntries(CbFieldView MapField, int32_t Indent) + { + for (CbFieldView Child : MapField) { - Node.set_key_serialized(ryml::csubstr((const char*)Name.data(), Name.size())); + WriteIndent(Indent); + WriteMapEntryContent(Child, Indent); } + } + + void WriteMapEntryContent(CbFieldView Field, int32_t Indent) + { + std::u8string_view Name = Field.GetU8Name(); + m_Builder << std::string_view(reinterpret_cast(Name.data()), Name.size()); - switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); + + if (IsContainer(Type)) { - case CbFieldType::Null: - Node.set_val("null"); - break; - case CbFieldType::Object: - case CbFieldType::UniformObject: - Node |= ryml::MAP; - m_NodeStack.push_back(Node); - for (CbFieldView It : Field) + m_Builder << ":\n"; + WriteFieldValue(Field, Indent + 1); + } + else + { + m_Builder << ": "; + WriteScalarValue(Field); + m_Builder << '\n'; + } + } + + void WriteSeqEntries(CbFieldView SeqField, int32_t Indent) + { + for (CbFieldView Child : SeqField) + { + CbValue Accessor = Child.GetValue(); + CbFieldType Type = Accessor.GetType(); + + if (Type == CbFieldType::Object || Type == CbFieldType::UniformObject) + { + bool First = true; + for (CbFieldView MapChild : Child) { - WriteField(It); + if (First) + { + WriteIndent(Indent); + m_Builder << "- "; + First = false; + } + else + { + WriteIndent(Indent + 1); + } + WriteMapEntryContent(MapChild, Indent + 1); } - m_NodeStack.pop_back(); + } + else if (Type == CbFieldType::Array || Type == CbFieldType::UniformArray) + { + 
WriteIndent(Indent); + m_Builder << "-\n"; + WriteSeqEntries(Child, Indent + 1); + } + else + { + WriteIndent(Indent); + m_Builder << "- "; + WriteScalarValue(Child); + m_Builder << '\n'; + } + } + } + + void WriteFieldValue(CbFieldView Field, int32_t Indent) + { + CbValue Accessor = Field.GetValue(); + CbFieldType Type = Accessor.GetType(); + + switch (Type) + { + case CbFieldType::Object: + case CbFieldType::UniformObject: + WriteMapEntries(Field, Indent); break; case CbFieldType::Array: case CbFieldType::UniformArray: - Node |= ryml::SEQ; - m_NodeStack.push_back(Node); - for (CbFieldView It : Field) - { - WriteField(It); - } - m_NodeStack.pop_back(); + WriteSeqEntries(Field, Indent); break; - case CbFieldType::Binary: - { - ExtendableStringBuilder<256> Builder; - const MemoryView Value = Accessor.AsBinary(); - ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); - - Node.set_key_serialized(Builder.c_str()); - } + case CbFieldType::CustomById: + WriteCustomById(Field.GetValue().AsCustomById(), Indent); break; - case CbFieldType::String: - { - const std::u8string_view U8String = Accessor.AsU8String(); - Node.set_val(ryml::csubstr((const char*)U8String.data(), U8String.size())); - } + case CbFieldType::CustomByName: + WriteCustomByName(Field.GetValue().AsCustomByName(), Indent); + break; + default: + WriteScalarValue(Field); + m_Builder << '\n'; + break; + } + } + + void WriteScalarValue(CbFieldView Field) + { + CbValue Accessor = Field.GetValue(); + switch (Accessor.GetType()) + { + case CbFieldType::Null: + m_Builder << "null"; + break; + case CbFieldType::BoolFalse: + m_Builder << "false"; + break; + case CbFieldType::BoolTrue: + m_Builder << "true"; break; case CbFieldType::IntegerPositive: - Node << 
Accessor.AsIntegerPositive(); + m_Builder << Accessor.AsIntegerPositive(); break; case CbFieldType::IntegerNegative: - Node << Accessor.AsIntegerNegative(); + m_Builder << Accessor.AsIntegerNegative(); break; case CbFieldType::Float32: if (const float Value = Accessor.AsFloat32(); std::isfinite(Value)) - { - Node << Value; - } + m_Builder.Append(fmt::format("{}", Value)); else - { - Node << "null"; - } + m_Builder << "null"; break; case CbFieldType::Float64: if (const double Value = Accessor.AsFloat64(); std::isfinite(Value)) - { - Node << Value; - } + m_Builder.Append(fmt::format("{}", Value)); else + m_Builder << "null"; + break; + case CbFieldType::String: { - Node << "null"; + const std::u8string_view U8String = Accessor.AsU8String(); + WriteString(std::string_view(reinterpret_cast(U8String.data()), U8String.size())); } break; - case CbFieldType::BoolFalse: - Node << "false"; - break; - case CbFieldType::BoolTrue: - Node << "true"; + case CbFieldType::Hash: + WriteString(Accessor.AsHash().ToHexString()); break; case CbFieldType::ObjectAttachment: case CbFieldType::BinaryAttachment: - Node << Accessor.AsAttachment().ToHexString(); - break; - case CbFieldType::Hash: - Node << Accessor.AsHash().ToHexString(); + WriteString(Accessor.AsAttachment().ToHexString()); break; case CbFieldType::Uuid: - Node << fmt::format("{}", Accessor.AsUuid()); + WriteString(fmt::format("{}", Accessor.AsUuid())); break; case CbFieldType::DateTime: - Node << DateTime(Accessor.AsDateTimeTicks()).ToIso8601(); + WriteString(DateTime(Accessor.AsDateTimeTicks()).ToIso8601()); break; case CbFieldType::TimeSpan: if (const TimeSpan Span(Accessor.AsTimeSpanTicks()); Span.GetDays() == 0) - { - Node << Span.ToString("%h:%m:%s.%n"); - } + WriteString(Span.ToString("%h:%m:%s.%n")); else - { - Node << Span.ToString("%d.%h:%m:%s.%n"); - } + WriteString(Span.ToString("%d.%h:%m:%s.%n")); break; case CbFieldType::ObjectId: - Node << fmt::format("{}", Accessor.AsObjectId()); + 
WriteString(fmt::format("{}", Accessor.AsObjectId())); break; - case CbFieldType::CustomById: - { - CbCustomById Custom = Accessor.AsCustomById(); + case CbFieldType::Binary: + WriteBase64(Accessor.AsBinary()); + break; + default: + ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); + break; + } + } - Node |= ryml::MAP; + void WriteCustomById(CbCustomById Custom, int32_t Indent) + { + WriteIndent(Indent); + m_Builder << "Id: "; + m_Builder.Append(fmt::format("{}", Custom.Id)); + m_Builder << '\n'; + + WriteIndent(Indent); + m_Builder << "Data: "; + WriteBase64(Custom.Data); + m_Builder << '\n'; + } - ryml::NodeRef IdNode = Node.append_child(); - IdNode.set_key("Id"); - IdNode.set_val_serialized(fmt::format("{}", Custom.Id)); + void WriteCustomByName(CbCustomByName Custom, int32_t Indent) + { + WriteIndent(Indent); + m_Builder << "Name: "; + WriteString(std::string_view(reinterpret_cast(Custom.Name.data()), Custom.Name.size())); + m_Builder << '\n'; + + WriteIndent(Indent); + m_Builder << "Data: "; + WriteBase64(Custom.Data); + m_Builder << '\n'; + } - ryml::NodeRef DataNode = Node.append_child(); - DataNode.set_key("Data"); + void WriteBase64(MemoryView Value) + { + ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); + ExtendableStringBuilder<256> Buf; + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Buf.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast(Value.GetData()), uint32_t(Value.GetSize()), Buf.Data() + EncodedIndex); + WriteString(Buf.ToView()); + } - ExtendableStringBuilder<256> Builder; - const MemoryView& Value = Custom.Data; - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + void WriteString(std::string_view Str) + { + if 
(NeedsQuoting(Str)) + { + m_Builder << '\''; + for (char C : Str) + { + if (C == '\'') + m_Builder << "''"; + else + m_Builder << C; + } + m_Builder << '\''; + } + else + { + m_Builder << Str; + } + } - DataNode.set_val_serialized(Builder.c_str()); - } - break; - case CbFieldType::CustomByName: - { - CbCustomByName Custom = Accessor.AsCustomByName(); + void WriteIndent(int32_t Indent) + { + for (int32_t I = 0; I < Indent; ++I) + m_Builder << " "; + } - Node |= ryml::MAP; + static bool NeedsQuoting(std::string_view Str) + { + if (Str.empty()) + return false; - ryml::NodeRef NameNode = Node.append_child(); - NameNode.set_key("Name"); - std::string_view Name = std::string_view((const char*)Custom.Name.data(), Custom.Name.size()); - NameNode.set_val_serialized(std::string(Name)); + char First = Str[0]; + if (First == ' ' || First == '\n' || First == '\t' || First == '\r' || First == '*' || First == '&' || First == '%' || + First == '@' || First == '`') + return true; - ryml::NodeRef DataNode = Node.append_child(); - DataNode.set_key("Data"); + if (Str.size() >= 2 && Str[0] == '<' && Str[1] == '<') + return true; - ExtendableStringBuilder<256> Builder; - const MemoryView& Value = Custom.Data; - const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); - const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); - Base64::Encode(static_cast(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + char Last = Str.back(); + if (Last == ' ' || Last == '\n' || Last == '\t' || Last == '\r') + return true; - DataNode.set_val_serialized(Builder.c_str()); - } - break; - default: - ZEN_ASSERT_FORMAT(false, "invalid field type: {}", uint8_t(Accessor.GetType())); - break; + for (char C : Str) + { + if (C == '#' || C == ':' || C == '-' || C == '?' 
|| C == ',' || C == '\n' || C == '{' || C == '}' || C == '[' || C == ']' || + C == '\'' || C == '"') + return true; } - if (m_NodeStack.size() == 1) + return false; + } + + static bool IsContainer(CbFieldType Type) + { + switch (Type) { - std::string Yaml = ryml::emitrs_yaml(m_Tree); - m_StrBuilder << Yaml; + case CbFieldType::Object: + case CbFieldType::UniformObject: + case CbFieldType::Array: + case CbFieldType::UniformArray: + case CbFieldType::CustomById: + case CbFieldType::CustomByName: + return true; + default: + return false; } } -private: - StringBuilderBase& m_StrBuilder; - bool m_IsFirst = true; - - ryml::Tree m_Tree; - std::vector m_NodeStack; - ryml::NodeRef& Top() { return m_NodeStack.back(); } + StringBuilderBase& m_Builder; }; void @@ -229,6 +380,32 @@ CompactBinaryToYaml(const CbArrayView& Array, StringBuilderBase& Builder) Writer.WriteField(Array.AsFieldView()); } +void +CompactBinaryToYaml(MemoryView Data, StringBuilderBase& InBuilder) +{ + std::vector Fields = ReadCompactBinaryStream(Data); + if (Fields.empty()) + return; + + CbYamlWriter Writer(InBuilder); + if (Fields.size() == 1) + { + Writer.WriteField(Fields[0]); + return; + } + + if (Fields[0].HasName()) + { + for (const CbFieldView& Field : Fields) + Writer.WriteMapEntry(Field, 0); + } + else + { + for (const CbFieldView& Field : Fields) + Writer.WriteSeqEntry(Field, 0); + } +} + #if ZEN_WITH_TESTS void cbyaml_forcelink() diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index a3fd4dacb..9a67175a0 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -33,8 +33,6 @@ target('zencore') add_deps("timesinceprocessstart") add_deps("doctest") add_deps("fmt") - add_deps("ryml") - add_packages("json11") if is_plat("linux", "macosx") then -- cgit v1.2.3 From 3c89c486338890ce39ddebe5be4722a09e85701a Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 24 Feb 2026 13:23:52 +0100 Subject: Fix correctness and concurrency bugs found during code review zenstore fixes: - cas.cpp: 
GetFileCasResults Results param passed by value instead of reference (large chunk results were silently lost) - structuredcachestore.cpp: MissCount unconditionally incremented (counted hits as misses) - cacherpc.cpp: Wrong boolean in Incomplete response array (all entries marked incomplete) - cachedisklayer.cpp: sizeof(sizeof(...)) in two validation checks computed sizeof(size_t) instead of struct size - buildstore.cpp: Wrong hash tracked in GC key list (BlobHash pushed twice instead of MetadataHash) - buildstore.cpp: Removed duplicate m_LastAccessTimeUpdateCount increment in PutBlob zenserver fixes: - httpbuildstore.cpp: Reversed subtraction in HTTP range calculation (unsigned underflow) - hubservice.cpp: Deadlock in Provision() calling Wake() while holding m_Lock (extracted WakeLocked helper) - zipfs.cpp: Data race in GetFile() lazy initialization (added RwLock with shared/exclusive paths) Co-Authored-By: Claude Opus 4.6 --- src/zenserver/frontend/zipfs.cpp | 20 ++++++++++++++++---- src/zenserver/frontend/zipfs.h | 2 ++ src/zenserver/hub/hubservice.cpp | 12 +++++++++--- src/zenserver/storage/buildstore/httpbuildstore.cpp | 2 +- src/zenstore/buildstore/buildstore.cpp | 3 +-- src/zenstore/cache/cachedisklayer.cpp | 4 ++-- src/zenstore/cache/cacherpc.cpp | 2 +- src/zenstore/cache/structuredcachestore.cpp | 5 ++++- src/zenstore/cas.cpp | 12 ++++++------ 9 files changed, 42 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp index f9c2bc8ff..42df0520f 100644 --- a/src/zenserver/frontend/zipfs.cpp +++ b/src/zenserver/frontend/zipfs.cpp @@ -149,13 +149,25 @@ ZipFs::ZipFs(IoBuffer&& Buffer) IoBuffer ZipFs::GetFile(const std::string_view& FileName) const { - FileMap::iterator Iter = m_Files.find(FileName); - if (Iter == m_Files.end()) { - return {}; + RwLock::SharedLockScope _(m_FilesLock); + + FileMap::const_iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; 
+ } + + const FileItem& Item = Iter->second; + if (Item.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); + } } - FileItem& Item = Iter->second; + RwLock::ExclusiveLockScope _(m_FilesLock); + + FileItem& Item = m_Files.find(FileName)->second; if (Item.GetSize() > 0) { return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h index 1fa7da451..19f96567c 100644 --- a/src/zenserver/frontend/zipfs.h +++ b/src/zenserver/frontend/zipfs.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include @@ -20,6 +21,7 @@ public: private: using FileItem = MemoryView; using FileMap = std::unordered_map; + mutable RwLock m_FilesLock; FileMap mutable m_Files; IoBuffer m_Buffer; }; diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index 4d9da3a57..a00446a75 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -151,6 +151,7 @@ struct StorageServerInstance inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); } private: + void WakeLocked(); RwLock m_Lock; std::string m_ModuleId; std::atomic m_IsProvisioned{false}; @@ -211,7 +212,7 @@ StorageServerInstance::Provision() if (m_IsHibernated) { - Wake(); + WakeLocked(); } else { @@ -294,9 +295,14 @@ StorageServerInstance::Hibernate() void StorageServerInstance::Wake() { - // Start server in-place using existing data - RwLock::ExclusiveLockScope _(m_Lock); + WakeLocked(); +} + +void +StorageServerInstance::WakeLocked() +{ + // Start server in-place using existing data if (!m_IsHibernated) { diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index f5ba30616..bf7afcc02 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -185,7 +185,7 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) { 
const HttpRange& Range = Ranges.front(); const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = Range.Start < BlobSize ? Range.Start - BlobSize : 0; + const uint64_t MaxBlobSize = Range.Start < BlobSize ? BlobSize - Range.Start : 0; const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize); if (Range.Start + RangeSize > BlobSize) { diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index 04a0781d3..aa37e75fe 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -266,13 +266,12 @@ BuildStore::PutBlob(const IoHash& BlobHash, const IoBuffer& Payload) m_BlobLookup.insert({BlobHash, NewBlobIndex}); } - m_LastAccessTimeUpdateCount++; if (m_TrackedBlobKeys) { m_TrackedBlobKeys->push_back(BlobHash); if (MetadataHash != IoHash::Zero) { - m_TrackedBlobKeys->push_back(BlobHash); + m_TrackedBlobKeys->push_back(MetadataHash); } } } diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index ead7e4f3a..b73b3e6fb 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -626,7 +626,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B return false; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(BucketMetaHeader))) / sizeof(ManifestData); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(BucketMetaHeader)) / sizeof(ManifestData); if (Header.EntryCount > ExpectedEntryCount) { ZEN_WARN( @@ -1057,7 +1057,7 @@ ZenCacheDiskLayer::CacheBucket::ReadIndexFile(RwLock::ExclusiveLockScope&, const return 0; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(cache::impl::CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(cache::impl::CacheBucketIndexHeader)) / sizeof(DiskIndexEntry); if (Header.EntryCount > ExpectedEntryCount) { return 0; diff --git 
a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index 94abcf547..e1fd0a3e6 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -966,7 +966,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb } else { - ResponseObject.AddBool(true); + ResponseObject.AddBool(false); } } ResponseObject.EndArray(); diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 52b494e45..4e8475293 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -608,7 +608,10 @@ ZenCacheStore::GetBatch::Commit() m_CacheStore.m_HitCount++; OpScope.SetBytes(Result.Value.GetSize()); } - m_CacheStore.m_MissCount++; + else + { + m_CacheStore.m_MissCount++; + } } } } diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp index ed017988f..7402d92d3 100644 --- a/src/zenstore/cas.cpp +++ b/src/zenstore/cas.cpp @@ -300,12 +300,12 @@ GetCompactCasResults(CasContainerStrategy& Strategy, }; static void -GetFileCasResults(FileCasStrategy& Strategy, - CasStore::InsertMode Mode, - std::span Data, - std::span ChunkHashes, - std::span Indexes, - std::vector Results) +GetFileCasResults(FileCasStrategy& Strategy, + CasStore::InsertMode Mode, + std::span Data, + std::span ChunkHashes, + std::span Indexes, + std::vector& Results) { for (size_t Index : Indexes) { -- cgit v1.2.3 From 075bac3ca870a1297e9f62230d56e63aec13a77d Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 24 Feb 2026 13:36:44 +0100 Subject: Revert "Fix correctness and concurrency bugs found during code review" This reverts commit 3c89c486338890ce39ddebe5be4722a09e85701a. 
--- src/zenserver/frontend/zipfs.cpp | 20 ++++---------------- src/zenserver/frontend/zipfs.h | 2 -- src/zenserver/hub/hubservice.cpp | 12 +++--------- src/zenserver/storage/buildstore/httpbuildstore.cpp | 2 +- src/zenstore/buildstore/buildstore.cpp | 3 ++- src/zenstore/cache/cachedisklayer.cpp | 4 ++-- src/zenstore/cache/cacherpc.cpp | 2 +- src/zenstore/cache/structuredcachestore.cpp | 5 +---- src/zenstore/cas.cpp | 12 ++++++------ 9 files changed, 20 insertions(+), 42 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp index 42df0520f..f9c2bc8ff 100644 --- a/src/zenserver/frontend/zipfs.cpp +++ b/src/zenserver/frontend/zipfs.cpp @@ -149,25 +149,13 @@ ZipFs::ZipFs(IoBuffer&& Buffer) IoBuffer ZipFs::GetFile(const std::string_view& FileName) const { + FileMap::iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) { - RwLock::SharedLockScope _(m_FilesLock); - - FileMap::const_iterator Iter = m_Files.find(FileName); - if (Iter == m_Files.end()) - { - return {}; - } - - const FileItem& Item = Iter->second; - if (Item.GetSize() > 0) - { - return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); - } + return {}; } - RwLock::ExclusiveLockScope _(m_FilesLock); - - FileItem& Item = m_Files.find(FileName)->second; + FileItem& Item = Iter->second; if (Item.GetSize() > 0) { return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h index 19f96567c..1fa7da451 100644 --- a/src/zenserver/frontend/zipfs.h +++ b/src/zenserver/frontend/zipfs.h @@ -3,7 +3,6 @@ #pragma once #include -#include #include @@ -21,7 +20,6 @@ public: private: using FileItem = MemoryView; using FileMap = std::unordered_map; - mutable RwLock m_FilesLock; FileMap mutable m_Files; IoBuffer m_Buffer; }; diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index a00446a75..4d9da3a57 100644 --- 
a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -151,7 +151,6 @@ struct StorageServerInstance inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); } private: - void WakeLocked(); RwLock m_Lock; std::string m_ModuleId; std::atomic m_IsProvisioned{false}; @@ -212,7 +211,7 @@ StorageServerInstance::Provision() if (m_IsHibernated) { - WakeLocked(); + Wake(); } else { @@ -294,16 +293,11 @@ StorageServerInstance::Hibernate() void StorageServerInstance::Wake() -{ - RwLock::ExclusiveLockScope _(m_Lock); - WakeLocked(); -} - -void -StorageServerInstance::WakeLocked() { // Start server in-place using existing data + RwLock::ExclusiveLockScope _(m_Lock); + if (!m_IsHibernated) { ZEN_WARN("Attempted to wake storage server instance for module '{}' which is not hibernated", m_ModuleId); diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index bf7afcc02..f5ba30616 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -185,7 +185,7 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) { const HttpRange& Range = Ranges.front(); const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = Range.Start < BlobSize ? BlobSize - Range.Start : 0; + const uint64_t MaxBlobSize = Range.Start < BlobSize ? 
Range.Start - BlobSize : 0; const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize); if (Range.Start + RangeSize > BlobSize) { diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index aa37e75fe..04a0781d3 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -266,12 +266,13 @@ BuildStore::PutBlob(const IoHash& BlobHash, const IoBuffer& Payload) m_BlobLookup.insert({BlobHash, NewBlobIndex}); } + m_LastAccessTimeUpdateCount++; if (m_TrackedBlobKeys) { m_TrackedBlobKeys->push_back(BlobHash); if (MetadataHash != IoHash::Zero) { - m_TrackedBlobKeys->push_back(MetadataHash); + m_TrackedBlobKeys->push_back(BlobHash); } } } diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index b73b3e6fb..ead7e4f3a 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -626,7 +626,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B return false; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(BucketMetaHeader)) / sizeof(ManifestData); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(BucketMetaHeader))) / sizeof(ManifestData); if (Header.EntryCount > ExpectedEntryCount) { ZEN_WARN( @@ -1057,7 +1057,7 @@ ZenCacheDiskLayer::CacheBucket::ReadIndexFile(RwLock::ExclusiveLockScope&, const return 0; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(cache::impl::CacheBucketIndexHeader)) / sizeof(DiskIndexEntry); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(cache::impl::CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); if (Header.EntryCount > ExpectedEntryCount) { return 0; diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index e1fd0a3e6..94abcf547 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -966,7 +966,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const 
CacheRequestContext& Context, Cb } else { - ResponseObject.AddBool(false); + ResponseObject.AddBool(true); } } ResponseObject.EndArray(); diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 4e8475293..52b494e45 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -608,10 +608,7 @@ ZenCacheStore::GetBatch::Commit() m_CacheStore.m_HitCount++; OpScope.SetBytes(Result.Value.GetSize()); } - else - { - m_CacheStore.m_MissCount++; - } + m_CacheStore.m_MissCount++; } } } diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp index 7402d92d3..ed017988f 100644 --- a/src/zenstore/cas.cpp +++ b/src/zenstore/cas.cpp @@ -300,12 +300,12 @@ GetCompactCasResults(CasContainerStrategy& Strategy, }; static void -GetFileCasResults(FileCasStrategy& Strategy, - CasStore::InsertMode Mode, - std::span Data, - std::span ChunkHashes, - std::span Indexes, - std::vector& Results) +GetFileCasResults(FileCasStrategy& Strategy, + CasStore::InsertMode Mode, + std::span Data, + std::span ChunkHashes, + std::span Indexes, + std::vector Results) { for (size_t Index : Indexes) { -- cgit v1.2.3 From 5c5e12d1f02bb7cc1f42750e47a2735dc933c194 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 24 Feb 2026 14:56:57 +0100 Subject: Various bug fixes (#778) zencore fixes: - filesystem.cpp: ReadFile error reporting logic - compactbinaryvalue.h: CbValue::As*String error reporting logic zenhttp fixes: - httpasio BindAcceptor would `return 0;` in a function returning `std::string` (UB) - httpsys async workpool initialization race zenstore fixes: - cas.cpp: GetFileCasResults Results param passed by value instead of reference (large chunk results were silently lost) - structuredcachestore.cpp: MissCount unconditionally incremented (counted hits as misses) - cacherpc.cpp: Wrong boolean in Incomplete response array (all entries marked incomplete) - cachedisklayer.cpp: sizeof(sizeof(...)) in two 
validation checks computed sizeof(size_t) instead of struct size - buildstore.cpp: Wrong hash tracked in GC key list (BlobHash pushed twice instead of MetadataHash) - buildstore.cpp: Removed duplicate m_LastAccessTimeUpdateCount increment in PutBlob zenserver fixes: - httpbuildstore.cpp: Reversed subtraction in HTTP range calculation (unsigned underflow) - hubservice.cpp: Deadlock in Provision() calling Wake() while holding m_Lock (extracted WakeLocked helper) - zipfs.cpp: Data race in GetFile() lazy initialization (added RwLock with shared/exclusive paths) --- src/zencompute-test/xmake.lua | 1 - src/zencompute/xmake.lua | 2 -- src/zencore/filesystem.cpp | 16 +++------------ src/zencore/include/zencore/compactbinaryvalue.h | 24 ++++++++++++++-------- src/zencore/memtrack/callstacktrace.cpp | 8 ++++---- src/zencore/string.cpp | 4 ++++ src/zenhttp/servers/httpasio.cpp | 4 ++-- src/zenhttp/servers/httpsys.cpp | 17 ++++++++------- src/zenserver/frontend/frontend.cpp | 9 +++++--- src/zenserver/frontend/frontend.h | 7 ++++--- src/zenserver/frontend/zipfs.cpp | 20 ++++++++++++++---- src/zenserver/frontend/zipfs.h | 8 ++++---- src/zenserver/hub/hubservice.cpp | 12 ++++++++--- .../storage/buildstore/httpbuildstore.cpp | 2 +- src/zenstore/buildstore/buildstore.cpp | 3 +-- src/zenstore/cache/cachedisklayer.cpp | 4 ++-- src/zenstore/cache/cacherpc.cpp | 2 +- src/zenstore/cache/structuredcachestore.cpp | 5 ++++- src/zenstore/cas.cpp | 12 +++++------ src/zentest-appstub/xmake.lua | 2 -- 20 files changed, 93 insertions(+), 69 deletions(-) (limited to 'src') diff --git a/src/zencompute-test/xmake.lua b/src/zencompute-test/xmake.lua index 64a3c7703..1207bdefd 100644 --- a/src/zencompute-test/xmake.lua +++ b/src/zencompute-test/xmake.lua @@ -6,4 +6,3 @@ target("zencompute-test") add_headerfiles("**.h") add_files("*.cpp") add_deps("zencompute", "zencore") - add_packages("vcpkg::doctest") diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua index c710b662d..50877508c 
100644 --- a/src/zencompute/xmake.lua +++ b/src/zencompute/xmake.lua @@ -7,5 +7,3 @@ target('zencompute') add_files("**.cpp") add_includedirs("include", {public=true}) add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") - add_packages("vcpkg::gsl-lite") - add_packages("vcpkg::spdlog", "vcpkg::cxxopts") diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 1a4ee4b9b..553897407 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -1326,11 +1326,6 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin { BytesRead = size_t(dwNumberOfBytesRead); } - else if ((BytesRead != NumberOfBytesToRead)) - { - Ec = MakeErrorCode(ERROR_HANDLE_EOF); - return; - } else { Ec = MakeErrorCodeFromLastError(); @@ -1344,20 +1339,15 @@ ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uin { BytesRead = size_t(ReadResult); } - else if ((BytesRead != NumberOfBytesToRead)) - { - Ec = MakeErrorCode(EIO); - return; - } else { Ec = MakeErrorCodeFromLastError(); return; } #endif - Size -= NumberOfBytesToRead; - FileOffset += NumberOfBytesToRead; - Data = reinterpret_cast(Data) + NumberOfBytesToRead; + Size -= BytesRead; + FileOffset += BytesRead; + Data = reinterpret_cast(Data) + BytesRead; } } diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h index aa2d2821d..4ce8009b8 100644 --- a/src/zencore/include/zencore/compactbinaryvalue.h +++ b/src/zencore/include/zencore/compactbinaryvalue.h @@ -128,17 +128,21 @@ CbValue::AsString(CbFieldError* OutError, std::string_view Default) const uint32_t ValueSizeByteCount; const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); - if (OutError) + if (ValueSize >= (uint64_t(1) << 31)) { - if (ValueSize >= (uint64_t(1) << 31)) + if (OutError) { *OutError = CbFieldError::RangeError; - return Default; } + return Default; + } + + if (OutError) + { *OutError = CbFieldError::None; } - 
return std::string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); + return std::string_view(Chars + ValueSizeByteCount, size_t(ValueSize)); } inline std::u8string_view @@ -148,17 +152,21 @@ CbValue::AsU8String(CbFieldError* OutError, std::u8string_view Default) const uint32_t ValueSizeByteCount; const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); - if (OutError) + if (ValueSize >= (uint64_t(1) << 31)) { - if (ValueSize >= (uint64_t(1) << 31)) + if (OutError) { *OutError = CbFieldError::RangeError; - return Default; } + return Default; + } + + if (OutError) + { *OutError = CbFieldError::None; } - return std::u8string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); + return std::u8string_view(Chars + ValueSizeByteCount, size_t(ValueSize)); } inline uint64_t diff --git a/src/zencore/memtrack/callstacktrace.cpp b/src/zencore/memtrack/callstacktrace.cpp index a5b7fede6..4a7068568 100644 --- a/src/zencore/memtrack/callstacktrace.cpp +++ b/src/zencore/memtrack/callstacktrace.cpp @@ -169,13 +169,13 @@ private: std::atomic_uint64_t Key; std::atomic_uint32_t Value; - inline uint64 GetKey() const { return Key.load(std::memory_order_relaxed); } + inline uint64 GetKey() const { return Key.load(std::memory_order_acquire); } inline uint32_t GetValue() const { return Value.load(std::memory_order_relaxed); } - inline bool IsEmpty() const { return Key.load(std::memory_order_relaxed) == 0; } + inline bool IsEmpty() const { return Key.load(std::memory_order_acquire) == 0; } inline void SetKeyValue(uint64_t InKey, uint32_t InValue) { - Value.store(InValue, std::memory_order_release); - Key.store(InKey, std::memory_order_relaxed); + Value.store(InValue, std::memory_order_relaxed); + Key.store(InKey, std::memory_order_release); } static inline uint32_t KeyHash(uint64_t Key) { return static_cast(Key); } static inline void ClearEntries(FEncounteredCallstackSetEntry* Entries, int32_t EntryCount) diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index 
0ee863b74..a9aed6309 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -24,6 +24,10 @@ utf16to8_impl(u16bit_iterator StartIt, u16bit_iterator EndIt, ::zen::StringBuild // Take care of surrogate pairs first if (utf8::internal::is_lead_surrogate(cp)) { + if (StartIt == EndIt) + { + break; + } uint32_t trail_surrogate = utf8::internal::mask16(*StartIt++); cp = (cp << 10) + trail_surrogate + utf8::internal::SURROGATE_OFFSET; } diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 1c0ebef90..fbc7fe401 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -89,10 +89,10 @@ IsIPv6AvailableSysctl(void) char buf[16]; if (fgets(buf, sizeof(buf), f)) { - fclose(f); // 0 means IPv6 enabled, 1 means disabled val = atoi(buf); } + fclose(f); } return val == 0; @@ -1544,7 +1544,7 @@ struct HttpAcceptor { ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - return 0; + return {}; } if (EffectivePort != BasePort) diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index c640ba90b..6995ffca9 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -25,6 +25,7 @@ # include # include "iothreadpool.h" +# include # include namespace zen { @@ -129,8 +130,8 @@ private: std::unique_ptr m_IoThreadPool; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic m_AsyncWorkPool = nullptr; std::vector m_BaseUris; // eg: http://*:nnnn/ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; @@ -1032,8 +1033,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1323,17 +1326,17 @@ HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if 
(!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { - m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp index 2b157581f..1cf451e91 100644 --- a/src/zenserver/frontend/frontend.cpp +++ b/src/zenserver/frontend/frontend.cpp @@ -38,7 +38,7 @@ HttpFrontendService::HttpFrontendService(std::filesystem::path Directory, HttpSt #if ZEN_EMBED_HTML_ZIP // Load an embedded Zip archive IoBuffer HtmlZipDataBuffer(IoBuffer::Wrap, gHtmlZipData, sizeof(gHtmlZipData) - 1); - m_ZipFs = ZipFs(std::move(HtmlZipDataBuffer)); + m_ZipFs = std::make_unique(std::move(HtmlZipDataBuffer)); #endif if (m_Directory.empty() && !m_ZipFs) @@ -157,9 +157,12 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) } } - if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri)) + if (m_ZipFs) { - return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + if (IoBuffer FileBuffer = m_ZipFs->GetFile(Uri)) + { + return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + } } Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h index 84ffaac42..6d8585b72 100644 --- a/src/zenserver/frontend/frontend.h +++ b/src/zenserver/frontend/frontend.h @@ -7,6 +7,7 @@ #include "zipfs.h" #include +#include namespace zen { @@ -20,9 +21,9 @@ public: virtual void HandleStatusRequest(HttpServerRequest& Request) override; private: - ZipFs m_ZipFs; - 
std::filesystem::path m_Directory; - HttpStatusService& m_StatusService; + std::unique_ptr m_ZipFs; + std::filesystem::path m_Directory; + HttpStatusService& m_StatusService; }; } // namespace zen diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp index f9c2bc8ff..42df0520f 100644 --- a/src/zenserver/frontend/zipfs.cpp +++ b/src/zenserver/frontend/zipfs.cpp @@ -149,13 +149,25 @@ ZipFs::ZipFs(IoBuffer&& Buffer) IoBuffer ZipFs::GetFile(const std::string_view& FileName) const { - FileMap::iterator Iter = m_Files.find(FileName); - if (Iter == m_Files.end()) { - return {}; + RwLock::SharedLockScope _(m_FilesLock); + + FileMap::const_iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; + } + + const FileItem& Item = Iter->second; + if (Item.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); + } } - FileItem& Item = Iter->second; + RwLock::ExclusiveLockScope _(m_FilesLock); + + FileItem& Item = m_Files.find(FileName)->second; if (Item.GetSize() > 0) { return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h index 1fa7da451..645121693 100644 --- a/src/zenserver/frontend/zipfs.h +++ b/src/zenserver/frontend/zipfs.h @@ -3,23 +3,23 @@ #pragma once #include +#include #include namespace zen { -////////////////////////////////////////////////////////////////////////// class ZipFs { public: - ZipFs() = default; - ZipFs(IoBuffer&& Buffer); + explicit ZipFs(IoBuffer&& Buffer); + IoBuffer GetFile(const std::string_view& FileName) const; - inline operator bool() const { return !m_Files.empty(); } private: using FileItem = MemoryView; using FileMap = std::unordered_map; + mutable RwLock m_FilesLock; FileMap mutable m_Files; IoBuffer m_Buffer; }; diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index 4d9da3a57..a00446a75 100644 --- a/src/zenserver/hub/hubservice.cpp +++ 
b/src/zenserver/hub/hubservice.cpp @@ -151,6 +151,7 @@ struct StorageServerInstance inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); } private: + void WakeLocked(); RwLock m_Lock; std::string m_ModuleId; std::atomic m_IsProvisioned{false}; @@ -211,7 +212,7 @@ StorageServerInstance::Provision() if (m_IsHibernated) { - Wake(); + WakeLocked(); } else { @@ -294,9 +295,14 @@ StorageServerInstance::Hibernate() void StorageServerInstance::Wake() { - // Start server in-place using existing data - RwLock::ExclusiveLockScope _(m_Lock); + WakeLocked(); +} + +void +StorageServerInstance::WakeLocked() +{ + // Start server in-place using existing data if (!m_IsHibernated) { diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index f5ba30616..bf7afcc02 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -185,7 +185,7 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) { const HttpRange& Range = Ranges.front(); const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = Range.Start < BlobSize ? Range.Start - BlobSize : 0; + const uint64_t MaxBlobSize = Range.Start < BlobSize ? 
BlobSize - Range.Start : 0; const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize); if (Range.Start + RangeSize > BlobSize) { diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index 04a0781d3..aa37e75fe 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -266,13 +266,12 @@ BuildStore::PutBlob(const IoHash& BlobHash, const IoBuffer& Payload) m_BlobLookup.insert({BlobHash, NewBlobIndex}); } - m_LastAccessTimeUpdateCount++; if (m_TrackedBlobKeys) { m_TrackedBlobKeys->push_back(BlobHash); if (MetadataHash != IoHash::Zero) { - m_TrackedBlobKeys->push_back(BlobHash); + m_TrackedBlobKeys->push_back(MetadataHash); } } } diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index ead7e4f3a..b73b3e6fb 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -626,7 +626,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B return false; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(BucketMetaHeader))) / sizeof(ManifestData); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(BucketMetaHeader)) / sizeof(ManifestData); if (Header.EntryCount > ExpectedEntryCount) { ZEN_WARN( @@ -1057,7 +1057,7 @@ ZenCacheDiskLayer::CacheBucket::ReadIndexFile(RwLock::ExclusiveLockScope&, const return 0; } - const uint64_t ExpectedEntryCount = (FileSize - sizeof(sizeof(cache::impl::CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); + const uint64_t ExpectedEntryCount = (FileSize - sizeof(cache::impl::CacheBucketIndexHeader)) / sizeof(DiskIndexEntry); if (Header.EntryCount > ExpectedEntryCount) { return 0; diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index 94abcf547..e1fd0a3e6 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -966,7 +966,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const 
CacheRequestContext& Context, Cb } else { - ResponseObject.AddBool(true); + ResponseObject.AddBool(false); } } ResponseObject.EndArray(); diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 52b494e45..4e8475293 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -608,7 +608,10 @@ ZenCacheStore::GetBatch::Commit() m_CacheStore.m_HitCount++; OpScope.SetBytes(Result.Value.GetSize()); } - m_CacheStore.m_MissCount++; + else + { + m_CacheStore.m_MissCount++; + } } } } diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp index ed017988f..7402d92d3 100644 --- a/src/zenstore/cas.cpp +++ b/src/zenstore/cas.cpp @@ -300,12 +300,12 @@ GetCompactCasResults(CasContainerStrategy& Strategy, }; static void -GetFileCasResults(FileCasStrategy& Strategy, - CasStore::InsertMode Mode, - std::span Data, - std::span ChunkHashes, - std::span Indexes, - std::vector Results) +GetFileCasResults(FileCasStrategy& Strategy, + CasStore::InsertMode Mode, + std::span Data, + std::span ChunkHashes, + std::span Indexes, + std::vector& Results) { for (size_t Index : Indexes) { diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua index db3ff2e2d..844ba82ef 100644 --- a/src/zentest-appstub/xmake.lua +++ b/src/zentest-appstub/xmake.lua @@ -6,8 +6,6 @@ target("zentest-appstub") add_headerfiles("**.h") add_files("*.cpp") add_deps("zencore") - add_packages("vcpkg::gsl-lite") -- this should ideally be propagated by the zencore dependency - add_packages("vcpkg::mimalloc") if is_os("linux") then add_syslinks("pthread") -- cgit v1.2.3 From 3cfc1b18f6b86b9830730f0055b8e3b955b77c95 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 24 Feb 2026 15:36:59 +0100 Subject: Add `zen ui` command (#779) Allows user to automate launching of zenserver dashboard, including when multiple instances are running. 
If multiple instances are running you can open all dashboards with `--all`, and also using the in-terminal chooser which also allows you to open a specific instance. Also includes a fix to `zen exec` when using offset/stride/limit --- src/zen/cmds/exec_cmd.cpp | 9 +- src/zen/cmds/ui_cmd.cpp | 236 +++++++++++++++ src/zen/cmds/ui_cmd.h | 32 ++ src/zen/progressbar.cpp | 55 +--- src/zen/progressbar.h | 1 - src/zen/zen.cpp | 6 +- src/zencore/include/zencore/process.h | 1 + src/zencore/process.cpp | 226 +++++++++++++++ src/zenutil/consoletui.cpp | 483 +++++++++++++++++++++++++++++++ src/zenutil/include/zenutil/consoletui.h | 59 ++++ 10 files changed, 1054 insertions(+), 54 deletions(-) create mode 100644 src/zen/cmds/ui_cmd.cpp create mode 100644 src/zen/cmds/ui_cmd.h create mode 100644 src/zenutil/consoletui.cpp create mode 100644 src/zenutil/include/zenutil/consoletui.h (limited to 'src') diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 2d9d0d12e..407f42ee3 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -360,6 +360,13 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess return false; }; + int TargetParallelism = 8; + + if (OffsetCounter || StrideCounter || m_Limit) + { + TargetParallelism = 1; + } + m_RecordingReader->IterateActions( [&](CbObject ActionObject, const IoHash& ActionId) { // Enqueue job @@ -444,7 +451,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess DrainCompletedJobs(); }, - 8); + TargetParallelism); // Wait until all pending work is complete diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp new file mode 100644 index 000000000..da06ce305 --- /dev/null +++ b/src/zen/cmds/ui_cmd.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "ui_cmd.h" + +#include +#include +#include +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# include +# include +#endif + +namespace zen { + +namespace { + + struct RunningServerInfo + { + uint16_t Port; + uint32_t Pid; + std::string SessionId; + std::string CmdLine; + }; + + static std::vector CollectRunningServers() + { + std::vector Servers; + ZenServerState State; + if (!State.InitializeReadOnly()) + return Servers; + + State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { + StringBuilder<25> SessionSB; + Entry.GetSessionId().ToString(SessionSB); + std::error_code CmdLineEc; + std::string CmdLine = GetProcessCommandLine(static_cast(Entry.Pid.load()), CmdLineEc); + Servers.push_back({Entry.EffectiveListenPort.load(), Entry.Pid.load(), std::string(SessionSB.c_str()), std::move(CmdLine)}); + }); + + return Servers; + } + +} // namespace + +UiCommand::UiCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_options()("a,all", "Open dashboard for all running instances", cxxopts::value(m_All)->default_value("false")); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), ""); + m_Options.add_option("", + "p", + "path", + "Dashboard path (default: /dashboard/)", + cxxopts::value(m_DashboardPath)->default_value("/dashboard/"), + ""); + m_Options.parse_positional("path"); +} + +UiCommand::~UiCommand() +{ +} + +void +UiCommand::OpenBrowser(std::string_view HostName) +{ + // Allow shortcuts for specifying dashboard path, and ensure it is in a format we expect + // (leading slash, trailing slash if no file extension) + + if (!m_DashboardPath.empty()) + { + if (m_DashboardPath[0] != '/') + { + m_DashboardPath = "/dashboard/" + m_DashboardPath; + } + + if (m_DashboardPath.find_last_of('.') == std::string::npos && m_DashboardPath.back() != '/') + { + m_DashboardPath += '/'; + } + } + + bool Success = false; + + ExtendableStringBuilder<256> FullUrl; + FullUrl << HostName 
<< m_DashboardPath; + +#if ZEN_PLATFORM_WINDOWS + HINSTANCE Result = ShellExecuteA(nullptr, "open", FullUrl.c_str(), nullptr, nullptr, SW_SHOWNORMAL); + Success = reinterpret_cast(Result) > 32; +#else + // Validate URL doesn't contain shell metacharacters that could lead to command injection + std::string_view FullUrlView = FullUrl; + constexpr std::string_view DangerousChars = ";|&$`\\\"'<>(){}[]!#*?~\n\r"; + if (FullUrlView.find_first_of(DangerousChars) != std::string_view::npos) + { + throw OptionParseException(fmt::format("URL contains invalid characters: '{}'", FullUrl), m_Options.help()); + } + +# if ZEN_PLATFORM_MAC + std::string Command = fmt::format("open \"{}\"", FullUrl); +# elif ZEN_PLATFORM_LINUX + std::string Command = fmt::format("xdg-open \"{}\"", FullUrl); +# else + ZEN_NOT_IMPLEMENTED("Browser launching not implemented on this platform"); +# endif + + Success = system(Command.c_str()) == 0; +#endif + + if (!Success) + { + throw zen::runtime_error("Failed to launch browser for '{}'", FullUrl); + } + + ZEN_CONSOLE("Web browser launched for '{}' successfully", FullUrl); +} + +void +UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + using namespace std::literals; + + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return; + } + + // Resolve target server + uint16_t ServerPort = 0; + + if (m_HostName.empty()) + { + // Auto-discover running instances. + std::vector Servers = CollectRunningServers(); + + if (m_All) + { + if (Servers.empty()) + { + throw OptionParseException("No running Zen server instances found", m_Options.help()); + } + + for (const auto& Server : Servers) + { + OpenBrowser(fmt::format("http://localhost:{}", Server.Port)); + } + return; + } + + // If multiple are found and we have an interactive terminal, present a picker + // instead of silently using the first one. 
+ if (Servers.size() > 1 && IsTuiAvailable()) + { + std::vector Labels; + Labels.reserve(Servers.size() + 1); + Labels.push_back(fmt::format("(all {} instances)", Servers.size())); + + const int32_t Cols = static_cast(TuiConsoleColumns()); + constexpr int32_t kIndicator = 3; // " ▶ " or " " prefix + constexpr int32_t kSeparator = 2; // " " before cmdline + constexpr int32_t kEllipsis = 3; // "..." + + for (const auto& Server : Servers) + { + std::string Label = fmt::format("port {:<5} pid {:<7} session {}", Server.Port, Server.Pid, Server.SessionId); + + if (!Server.CmdLine.empty()) + { + int32_t Available = Cols - kIndicator - kSeparator - static_cast(Label.size()); + if (Available > kEllipsis) + { + Label += " "; + if (static_cast(Server.CmdLine.size()) <= Available) + { + Label += Server.CmdLine; + } + else + { + Label.append(Server.CmdLine, 0, static_cast(Available - kEllipsis)); + Label += "..."; + } + } + } + + Labels.push_back(std::move(Label)); + } + + int SelectedIdx = TuiPickOne("Multiple Zen server instances found. 
Select one to open:", Labels); + if (SelectedIdx < 0) + return; // User cancelled + + if (SelectedIdx == 0) + { + // "All" selected + for (const auto& Server : Servers) + { + OpenBrowser(fmt::format("http://localhost:{}", Server.Port)); + } + return; + } + + ServerPort = Servers[SelectedIdx - 1].Port; + m_HostName = fmt::format("http://localhost:{}", ServerPort); + } + + if (m_HostName.empty()) + { + // Single or zero instances, or not an interactive terminal: + // fall back to default resolution (picks first instance or returns empty) + m_HostName = ResolveTargetHostSpec("", ServerPort); + } + } + else + { + if (m_All) + { + throw OptionParseException("--all cannot be used together with --hosturl", m_Options.help()); + } + m_HostName = ResolveTargetHostSpec(m_HostName, ServerPort); + } + + if (m_HostName.empty()) + { + throw OptionParseException("Unable to resolve server specification", m_Options.help()); + } + + OpenBrowser(m_HostName); +} + +} // namespace zen diff --git a/src/zen/cmds/ui_cmd.h b/src/zen/cmds/ui_cmd.h new file mode 100644 index 000000000..c74cdbbd0 --- /dev/null +++ b/src/zen/cmds/ui_cmd.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "../zen.h" + +#include + +namespace zen { + +class UiCommand : public ZenCmdBase +{ +public: + UiCommand(); + ~UiCommand(); + + static constexpr char Name[] = "ui"; + static constexpr char Description[] = "Launch web browser with zen server UI"; + + virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + void OpenBrowser(std::string_view HostName); + + cxxopts::Options m_Options{Name, Description}; + std::string m_HostName; + std::string m_DashboardPath = "/dashboard/"; + bool m_All = false; +}; + +} // namespace zen diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 1ee1d1e71..732f16e81 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -8,16 +8,12 @@ #include #include #include +#include ZEN_THIRD_PARTY_INCLUDES_START #include ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -# include -# include -#endif - ////////////////////////////////////////////////////////////////////////// namespace zen { @@ -31,35 +27,12 @@ GetConsoleHandle() } #endif -static bool -CheckStdoutTty() -{ -#if ZEN_PLATFORM_WINDOWS - HANDLE hStdOut = GetConsoleHandle(); - DWORD dwMode = 0; - static bool IsConsole = ::GetConsoleMode(hStdOut, &dwMode); - return IsConsole; -#else - return isatty(fileno(stdout)); -#endif -} - -static bool -IsStdoutTty() -{ - static bool StdoutIsTty = CheckStdoutTty(); - return StdoutIsTty; -} - static void OutputToConsoleRaw(const char* String, size_t Length) { #if ZEN_PLATFORM_WINDOWS HANDLE hStdOut = GetConsoleHandle(); -#endif - -#if ZEN_PLATFORM_WINDOWS - if (IsStdoutTty()) + if (TuiIsStdoutTty()) { WriteConsoleA(hStdOut, String, (DWORD)Length, 0, 0); } @@ -84,26 +57,6 @@ OutputToConsoleRaw(const StringBuilderBase& SB) OutputToConsoleRaw(SB.c_str(), SB.Size()); } -uint32_t -GetConsoleColumns(uint32_t Default) -{ -#if ZEN_PLATFORM_WINDOWS - HANDLE hStdOut = 
GetConsoleHandle(); - CONSOLE_SCREEN_BUFFER_INFO csbi; - if (GetConsoleScreenBufferInfo(hStdOut, &csbi) == TRUE) - { - return (uint32_t)(csbi.srWindow.Right - csbi.srWindow.Left + 1); - } -#else - struct winsize w; - if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) == 0) - { - return (uint32_t)w.ws_col; - } -#endif - return Default; -} - uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode) { @@ -165,7 +118,7 @@ ProgressBar::PopLogOperation(Mode InMode) } ProgressBar::ProgressBar(Mode InMode, std::string_view InSubTask) -: m_Mode((!IsStdoutTty() && InMode == Mode::Pretty) ? Mode::Plain : InMode) +: m_Mode((!TuiIsStdoutTty() && InMode == Mode::Pretty) ? Mode::Plain : InMode) , m_LastUpdateMS((uint64_t)-1) , m_PausedMS(0) , m_SubTask(InSubTask) @@ -257,7 +210,7 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) uint64_t ETAMS = (NewState.Status == State::EStatus::Running) && (PercentDone > 5) ? (ETAElapsedMS * NewState.RemainingCount) / Completed : 0; - uint32_t ConsoleColumns = GetConsoleColumns(1024); + uint32_t ConsoleColumns = TuiConsoleColumns(1024); const std::string PercentString = fmt::format("{:#3}%", PercentDone); diff --git a/src/zen/progressbar.h b/src/zen/progressbar.h index bbdb008d4..cb1c7023b 100644 --- a/src/zen/progressbar.h +++ b/src/zen/progressbar.h @@ -76,7 +76,6 @@ private: }; uint32_t GetUpdateDelayMS(ProgressBar::Mode InMode); -uint32_t GetConsoleColumns(uint32_t Default); OperationLogOutput* CreateConsoleLogOutput(ProgressBar::Mode InMode); diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index bdc2b4003..dc37cb56b 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -22,6 +22,7 @@ #include "cmds/status_cmd.h" #include "cmds/top_cmd.h" #include "cmds/trace_cmd.h" +#include "cmds/ui_cmd.h" #include "cmds/up_cmd.h" #include "cmds/version_cmd.h" #include "cmds/vfs_cmd.h" @@ -41,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -123,7 +125,7 @@ ZenCmdBase::ParseOptions(int argc, char** argv) bool 
ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv) { - CmdOptions.set_width(GetConsoleColumns(80)); + CmdOptions.set_width(TuiConsoleColumns(80)); cxxopts::ParseResult Result; @@ -364,6 +366,7 @@ main(int argc, char** argv) LoggingCommand LoggingCmd; TopCommand TopCmd; TraceCommand TraceCmd; + UiCommand UiCmd; UpCommand UpCmd; VersionCommand VersionCmd; VfsCommand VfsCmd; @@ -425,6 +428,7 @@ main(int argc, char** argv) {StatusCommand::Name, &StatusCmd, StatusCommand::Description}, {TopCommand::Name, &TopCmd, TopCommand::Description}, {TraceCommand::Name, &TraceCmd, TraceCommand::Description}, + {UiCommand::Name, &UiCmd, UiCommand::Description}, {UpCommand::Name, &UpCmd, UpCommand::Description}, {VersionCommand::Name, &VersionCmd, VersionCommand::Description}, {VfsCommand::Name, &VfsCmd, VfsCommand::Description}, diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index e3b7a70d7..c51163a68 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -105,6 +105,7 @@ int GetCurrentProcessId(); int GetProcessId(CreateProcResult ProcId); std::filesystem::path GetProcessExecutablePath(int Pid, std::error_code& OutEc); +std::string GetProcessCommandLine(int Pid, std::error_code& OutEc); std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf = true); /** Wait for all threads in the current process to exit (except the calling thread) diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 56849a10d..4a2668912 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -1001,6 +1001,232 @@ GetProcessExecutablePath(int Pid, std::error_code& OutEc) #endif // ZEN_PLATFORM_LINUX } +std::string +GetProcessCommandLine(int Pid, std::error_code& OutEc) +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE hProcess = OpenProcess(PROCESS_QUERY_INFORMATION, FALSE, static_cast(Pid)); + if (!hProcess) + { + OutEc = 
MakeErrorCodeFromLastError(); + return {}; + } + auto _ = MakeGuard([hProcess] { CloseHandle(hProcess); }); + + // NtQueryInformationProcess is an undocumented NT API; load it dynamically. + // Info class 60 = ProcessCommandLine, available since Windows 8.1. + using PFN_NtQIP = LONG(WINAPI*)(HANDLE, UINT, PVOID, ULONG, PULONG); + static const PFN_NtQIP s_NtQIP = + reinterpret_cast(GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "NtQueryInformationProcess")); + if (!s_NtQIP) + { + return {}; + } + + constexpr UINT ProcessCommandLineClass = 60; + constexpr LONG StatusInfoLengthMismatch = static_cast(0xC0000004L); + + ULONG ReturnLength = 0; + LONG Status = s_NtQIP(hProcess, ProcessCommandLineClass, nullptr, 0, &ReturnLength); + if (Status != StatusInfoLengthMismatch || ReturnLength == 0) + { + return {}; + } + + std::vector Buf(ReturnLength); + Status = s_NtQIP(hProcess, ProcessCommandLineClass, Buf.data(), ReturnLength, &ReturnLength); + if (Status < 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + // Output: UNICODE_STRING header immediately followed by the UTF-16 string data. + // The UNICODE_STRING.Buffer field points into our Buf. 
+ struct LocalUnicodeString + { + USHORT Length; + USHORT MaximumLength; + WCHAR* Buffer; + }; + if (ReturnLength < sizeof(LocalUnicodeString)) + { + return {}; + } + const auto* Us = reinterpret_cast(Buf.data()); + if (Us->Length == 0 || Us->Buffer == nullptr) + { + return {}; + } + + // Skip argv[0]: may be a quoted path ("C:\...\exe.exe") or a bare path + const WCHAR* p = Us->Buffer; + const WCHAR* End = Us->Buffer + Us->Length / sizeof(WCHAR); + if (p < End && *p == L'"') + { + ++p; + while (p < End && *p != L'"') + { + ++p; + } + if (p < End) + { + ++p; // skip closing quote + } + } + else + { + while (p < End && *p != L' ') + { + ++p; + } + } + while (p < End && *p == L' ') + { + ++p; + } + if (p >= End) + { + return {}; + } + + int Utf8Size = WideCharToMultiByte(CP_UTF8, 0, p, static_cast(End - p), nullptr, 0, nullptr, nullptr); + if (Utf8Size <= 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + std::string Result(Utf8Size, '\0'); + WideCharToMultiByte(CP_UTF8, 0, p, static_cast(End - p), Result.data(), Utf8Size, nullptr, nullptr); + return Result; + +#elif ZEN_PLATFORM_LINUX + std::string CmdlinePath = fmt::format("/proc/{}/cmdline", Pid); + FILE* F = fopen(CmdlinePath.c_str(), "rb"); + if (!F) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + auto FGuard = MakeGuard([F] { fclose(F); }); + + // /proc/{pid}/cmdline contains null-separated argv entries; read it all + std::string Raw; + char Chunk[4096]; + size_t BytesRead; + while ((BytesRead = fread(Chunk, 1, sizeof(Chunk), F)) > 0) + { + Raw.append(Chunk, BytesRead); + } + if (Raw.empty()) + { + return {}; + } + + // Skip argv[0] (first null-terminated entry) + const char* p = Raw.data(); + const char* End = Raw.data() + Raw.size(); + while (p < End && *p != '\0') + { + ++p; + } + if (p < End) + { + ++p; // skip null terminator of argv[0] + } + + // Build result: remaining entries joined by spaces (inter-arg nulls → spaces) + std::string Result; + Result.reserve(static_cast(End 
- p)); + for (const char* q = p; q < End; ++q) + { + Result += (*q == '\0') ? ' ' : *q; + } + while (!Result.empty() && Result.back() == ' ') + { + Result.pop_back(); + } + return Result; + +#elif ZEN_PLATFORM_MAC + int Mib[3] = {CTL_KERN, KERN_PROCARGS2, Pid}; + size_t BufSize = 0; + if (sysctl(Mib, 3, nullptr, &BufSize, nullptr, 0) != 0 || BufSize == 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + std::vector Buf(BufSize); + if (sysctl(Mib, 3, Buf.data(), &BufSize, nullptr, 0) != 0) + { + OutEc = MakeErrorCodeFromLastError(); + return {}; + } + + // Layout: [int argc][exec_path\0][null padding][argv[0]\0][argv[1]\0]...[envp\0]... + if (BufSize < sizeof(int)) + { + return {}; + } + int Argc = 0; + memcpy(&Argc, Buf.data(), sizeof(int)); + if (Argc <= 1) + { + return {}; + } + + const char* p = Buf.data() + sizeof(int); + const char* End = Buf.data() + BufSize; + + // Skip exec_path and any trailing null padding that follows it + while (p < End && *p != '\0') + { + ++p; + } + while (p < End && *p == '\0') + { + ++p; + } + + // Skip argv[0] + while (p < End && *p != '\0') + { + ++p; + } + if (p < End) + { + ++p; + } + + // Collect argv[1..Argc-1] + std::string Result; + for (int i = 1; i < Argc && p < End; ++i) + { + if (i > 1) + { + Result += ' '; + } + const char* ArgStart = p; + while (p < End && *p != '\0') + { + ++p; + } + Result.append(ArgStart, p); + if (p < End) + { + ++p; + } + } + return Result; + +#else + ZEN_UNUSED(Pid); + ZEN_UNUSED(OutEc); + return {}; +#endif +} + std::error_code FindProcess(const std::filesystem::path& ExecutableImage, ProcessHandle& OutHandle, bool IncludeSelf) { diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp new file mode 100644 index 000000000..4410d463d --- /dev/null +++ b/src/zenutil/consoletui.cpp @@ -0,0 +1,483 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include + +#include + +#if ZEN_PLATFORM_WINDOWS +# include +#else +# include +# include +# include +# include +#endif + +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// Platform-specific terminal helpers + +#if ZEN_PLATFORM_WINDOWS + +static bool +CheckIsInteractiveTerminal() +{ + DWORD dwMode = 0; + return GetConsoleMode(GetStdHandle(STD_INPUT_HANDLE), &dwMode) && GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode); +} + +static void +EnableVirtualTerminal() +{ + HANDLE hStdOut = GetStdHandle(STD_OUTPUT_HANDLE); + DWORD dwMode = 0; + if (GetConsoleMode(hStdOut, &dwMode)) + { + SetConsoleMode(hStdOut, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); + } +} + +// RAII guard: sets the console output code page for the lifetime of the object and +// restores the original on destruction. Required for UTF-8 glyphs to render correctly +// via printf/fflush since the default console code page is not UTF-8. +class ConsoleCodePageGuard +{ +public: + explicit ConsoleCodePageGuard(UINT NewCP) : m_OldCP(GetConsoleOutputCP()) { SetConsoleOutputCP(NewCP); } + ~ConsoleCodePageGuard() { SetConsoleOutputCP(m_OldCP); } + +private: + UINT m_OldCP; +}; + +enum class ConsoleKey +{ + Unknown, + ArrowUp, + ArrowDown, + Enter, + Escape, +}; + +static ConsoleKey +ReadKey() +{ + HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); + INPUT_RECORD Record{}; + DWORD dwRead = 0; + while (true) + { + if (!ReadConsoleInputA(hStdin, &Record, 1, &dwRead)) + { + return ConsoleKey::Escape; // treat read error as cancel + } + if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown) + { + switch (Record.Event.KeyEvent.wVirtualKeyCode) + { + case VK_UP: + return ConsoleKey::ArrowUp; + case VK_DOWN: + return ConsoleKey::ArrowDown; + case VK_RETURN: + return ConsoleKey::Enter; + case VK_ESCAPE: + return ConsoleKey::Escape; + default: + break; + } + } + } +} + +#else // POSIX + +static bool +CheckIsInteractiveTerminal() +{ + return 
isatty(STDIN_FILENO) && isatty(STDOUT_FILENO); +} + +static void +EnableVirtualTerminal() +{ + // ANSI escape codes are native on POSIX terminals; nothing to do +} + +// RAII guard: switches the terminal to raw/unbuffered input mode and restores +// the original attributes on destruction. +class RawModeGuard +{ +public: + RawModeGuard() + { + if (tcgetattr(STDIN_FILENO, &m_OldAttrs) != 0) + { + return; + } + + struct termios Raw = m_OldAttrs; + Raw.c_iflag &= ~static_cast(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + Raw.c_cflag |= CS8; + Raw.c_lflag &= ~static_cast(ECHO | ICANON | IEXTEN | ISIG); + Raw.c_cc[VMIN] = 1; + Raw.c_cc[VTIME] = 0; + if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0) + { + m_Valid = true; + } + } + + ~RawModeGuard() + { + if (m_Valid) + { + tcsetattr(STDIN_FILENO, TCSANOW, &m_OldAttrs); + } + } + + bool IsValid() const { return m_Valid; } + +private: + struct termios m_OldAttrs = {}; + bool m_Valid = false; +}; + +static int +ReadByteWithTimeout(int TimeoutMs) +{ + struct pollfd Pfd + { + STDIN_FILENO, POLLIN, 0 + }; + if (poll(&Pfd, 1, TimeoutMs) > 0 && (Pfd.revents & POLLIN)) + { + unsigned char c = 0; + if (read(STDIN_FILENO, &c, 1) == 1) + { + return static_cast(c); + } + } + return -1; +} + +// State for fullscreen live mode (alternate screen + raw input) +static struct termios s_SavedAttrs = {}; +static bool s_InLiveMode = false; + +enum class ConsoleKey +{ + Unknown, + ArrowUp, + ArrowDown, + Enter, + Escape, +}; + +static ConsoleKey +ReadKey() +{ + unsigned char c = 0; + if (read(STDIN_FILENO, &c, 1) != 1) + { + return ConsoleKey::Escape; // treat read error as cancel + } + + if (c == 27) // ESC byte or start of an escape sequence + { + int Next = ReadByteWithTimeout(50); + if (Next == '[') + { + int Final = ReadByteWithTimeout(50); + if (Final == 'A') + { + return ConsoleKey::ArrowUp; + } + if (Final == 'B') + { + return ConsoleKey::ArrowDown; + } + } + return ConsoleKey::Escape; + } + + if (c == '\r' || c == '\n') + { + return 
ConsoleKey::Enter; + } + + return ConsoleKey::Unknown; +} + +#endif // ZEN_PLATFORM_WINDOWS / POSIX + +////////////////////////////////////////////////////////////////////////// +// Public API + +uint32_t +TuiConsoleColumns(uint32_t Default) +{ +#if ZEN_PLATFORM_WINDOWS + CONSOLE_SCREEN_BUFFER_INFO Csbi = {}; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &Csbi)) + { + return static_cast(Csbi.dwSize.X); + } +#else + struct winsize Ws = {}; + if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &Ws) == 0 && Ws.ws_col > 0) + { + return static_cast(Ws.ws_col); + } +#endif + return Default; +} + +void +TuiEnableOutput() +{ + EnableVirtualTerminal(); +#if ZEN_PLATFORM_WINDOWS + SetConsoleOutputCP(CP_UTF8); +#endif +} + +bool +TuiIsStdoutTty() +{ +#if ZEN_PLATFORM_WINDOWS + static bool Cached = [] { + DWORD dwMode = 0; + return GetConsoleMode(GetStdHandle(STD_OUTPUT_HANDLE), &dwMode) != 0; + }(); + return Cached; +#else + static bool Cached = isatty(STDOUT_FILENO) != 0; + return Cached; +#endif +} + +bool +IsTuiAvailable() +{ + static bool Cached = CheckIsInteractiveTerminal(); + return Cached; +} + +int +TuiPickOne(std::string_view Title, std::span Items) +{ + EnableVirtualTerminal(); + +#if ZEN_PLATFORM_WINDOWS + ConsoleCodePageGuard CodePageGuard(CP_UTF8); +#else + RawModeGuard RawMode; + if (!RawMode.IsValid()) + { + return -1; + } +#endif + + const int Count = static_cast(Items.size()); + int SelectedIndex = 0; + + printf("\n%.*s\n\n", static_cast(Title.size()), Title.data()); + + // Hide cursor during interaction + printf("\033[?25l"); + + // Renders the full entry list and hint footer. + // On subsequent calls, moves the cursor back up first to overwrite the previous output. 
+ bool FirstRender = true; + auto RenderAll = [&] { + if (!FirstRender) + { + printf("\033[%dA", Count + 2); // move up: entries + blank line + hint line + } + FirstRender = false; + + for (int i = 0; i < Count; ++i) + { + bool IsSelected = (i == SelectedIndex); + + printf("\r\033[K"); // erase line + + if (IsSelected) + { + printf("\033[1;7m"); // bold + reverse video + } + + // \xe2\x96\xb6 = U+25B6 BLACK RIGHT-POINTING TRIANGLE (▶) + const char* Indicator = IsSelected ? " \xe2\x96\xb6 " : " "; + + printf("%s%s", Indicator, Items[i].c_str()); + + if (IsSelected) + { + printf("\033[0m"); // reset attributes + } + + printf("\n"); + } + + // Blank separator line + printf("\r\033[K\n"); + + // Hint footer + // \xe2\x86\x91 = U+2191 ↑ \xe2\x86\x93 = U+2193 ↓ + printf( + "\r\033[K \033[2m\xe2\x86\x91/\xe2\x86\x93\033[0m navigate " + "\033[2mEnter\033[0m confirm " + "\033[2mEsc\033[0m cancel\n"); + + fflush(stdout); + }; + + RenderAll(); + + int Result = -1; + bool Done = false; + while (!Done) + { + ConsoleKey Key = ReadKey(); + switch (Key) + { + case ConsoleKey::ArrowUp: + SelectedIndex = (SelectedIndex - 1 + Count) % Count; + RenderAll(); + break; + + case ConsoleKey::ArrowDown: + SelectedIndex = (SelectedIndex + 1) % Count; + RenderAll(); + break; + + case ConsoleKey::Enter: + Result = SelectedIndex; + Done = true; + break; + + case ConsoleKey::Escape: + Done = true; + break; + + default: + break; + } + } + + // Restore cursor and add a blank line for visual separation + printf("\033[?25h\n"); + fflush(stdout); + + return Result; +} + +void +TuiEnterAlternateScreen() +{ + EnableVirtualTerminal(); +#if ZEN_PLATFORM_WINDOWS + SetConsoleOutputCP(CP_UTF8); +#endif + + printf("\033[?1049h"); // Enter alternate screen buffer + printf("\033[?25l"); // Hide cursor + fflush(stdout); + +#if !ZEN_PLATFORM_WINDOWS + if (tcgetattr(STDIN_FILENO, &s_SavedAttrs) == 0) + { + struct termios Raw = s_SavedAttrs; + Raw.c_iflag &= ~static_cast(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + 
Raw.c_cflag |= CS8; + Raw.c_lflag &= ~static_cast(ECHO | ICANON | IEXTEN | ISIG); + Raw.c_cc[VMIN] = 1; + Raw.c_cc[VTIME] = 0; + if (tcsetattr(STDIN_FILENO, TCSANOW, &Raw) == 0) + { + s_InLiveMode = true; + } + } +#endif +} + +void +TuiExitAlternateScreen() +{ + printf("\033[?25h"); // Show cursor + printf("\033[?1049l"); // Exit alternate screen buffer + fflush(stdout); + +#if !ZEN_PLATFORM_WINDOWS + if (s_InLiveMode) + { + tcsetattr(STDIN_FILENO, TCSANOW, &s_SavedAttrs); + s_InLiveMode = false; + } +#endif +} + +void +TuiCursorHome() +{ + printf("\033[H"); +} + +uint32_t +TuiConsoleRows(uint32_t Default) +{ +#if ZEN_PLATFORM_WINDOWS + CONSOLE_SCREEN_BUFFER_INFO Csbi = {}; + if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &Csbi)) + { + return static_cast(Csbi.srWindow.Bottom - Csbi.srWindow.Top + 1); + } +#else + struct winsize Ws = {}; + if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &Ws) == 0 && Ws.ws_row > 0) + { + return static_cast(Ws.ws_row); + } +#endif + return Default; +} + +bool +TuiPollQuit() +{ +#if ZEN_PLATFORM_WINDOWS + HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); + DWORD dwCount = 0; + if (!GetNumberOfConsoleInputEvents(hStdin, &dwCount) || dwCount == 0) + { + return false; + } + INPUT_RECORD Record{}; + DWORD dwRead = 0; + while (PeekConsoleInputA(hStdin, &Record, 1, &dwRead) && dwRead > 0) + { + ReadConsoleInputA(hStdin, &Record, 1, &dwRead); + if (Record.EventType == KEY_EVENT && Record.Event.KeyEvent.bKeyDown) + { + WORD vk = Record.Event.KeyEvent.wVirtualKeyCode; + char ch = Record.Event.KeyEvent.uChar.AsciiChar; + if (vk == VK_ESCAPE || ch == 'q' || ch == 'Q') + { + return true; + } + } + } + return false; +#else + // Non-blocking read: character 3 = Ctrl+C, 27 = Esc, 'q'/'Q' = quit + int b = ReadByteWithTimeout(0); + return (b == 3 || b == 27 || b == 'q' || b == 'Q'); +#endif +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h new file mode 100644 index 
000000000..7dc68c126 --- /dev/null +++ b/src/zenutil/include/zenutil/consoletui.h @@ -0,0 +1,59 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include +#include + +namespace zen { + +// Returns the width of the console in columns, or Default if it cannot be determined. +uint32_t TuiConsoleColumns(uint32_t Default = 120); + +// Enables ANSI/VT escape code processing and UTF-8 console output. +// Call once before printing ANSI escape sequences or multi-byte UTF-8 characters via printf. +// Safe to call multiple times. No-op on POSIX (escape codes are native there). +void TuiEnableOutput(); + +// Returns true if stdout is connected to a real terminal (not piped or redirected). +// Useful for deciding whether to use ANSI escape codes for progress output. +bool TuiIsStdoutTty(); + +// Returns true if both stdin and stdout are connected to an interactive terminal +// (i.e. not piped or redirected). Must be checked before calling TuiPickOne(). +bool IsTuiAvailable(); + +// Displays a cursor-navigable single-select list in the terminal. +// +// - Title: a short description printed once above the list +// - Items: pre-formatted display labels, one per selectable entry +// +// Arrow keys (↑/↓) navigate the selection, Enter confirms, Esc cancels. +// Returns the index of the selected item, or -1 if the user cancelled. +// +// Precondition: IsTuiAvailable() must be true. +int TuiPickOne(std::string_view Title, std::span Items); + +// Enter the alternate screen buffer for fullscreen live-update mode. +// Hides the cursor. On POSIX, switches to raw/unbuffered terminal input. +// Must be balanced by a call to TuiExitAlternateScreen(). +// Precondition: IsTuiAvailable() must be true. +void TuiEnterAlternateScreen(); + +// Exit alternate screen buffer. Restores the cursor and, on POSIX, the original +// terminal mode. Safe to call even if TuiEnterAlternateScreen() was not called. 
+void TuiExitAlternateScreen(); + +// Move the cursor to the top-left corner of the terminal (row 1, col 1). +void TuiCursorHome(); + +// Returns the height of the console in rows, or Default if it cannot be determined. +uint32_t TuiConsoleRows(uint32_t Default = 40); + +// Non-blocking check: returns true if the user has pressed a key that means quit +// (Esc, 'q', 'Q', or Ctrl+C). Consumes the event if one is pending. +// Should only be called while in alternate screen mode. +bool TuiPollQuit(); + +} // namespace zen -- cgit v1.2.3 From eb3079e2ec2969829cbc5b6921575d53df351f0f Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Tue, 24 Feb 2026 16:10:36 +0100 Subject: use partial blocks for oplog import (#780) Feature: Add --allow-partial-block-requests to zen oplog-import Improvement: zen oplog-import now uses partial block requests to reduce download size Improvement: Use latency to Cloud Storage host and Zen Cache host when calculating partial block requests --- src/zen/cmds/builds_cmd.cpp | 28 +- src/zen/cmds/projectstore_cmd.cpp | 28 +- src/zen/cmds/projectstore_cmd.h | 2 + src/zenhttp/httpclient.cpp | 38 + src/zenhttp/include/zenhttp/httpclient.h | 9 + src/zenremotestore/builds/buildstoragecache.cpp | 8 +- .../builds/buildstorageoperations.cpp | 45 +- src/zenremotestore/builds/buildstorageutil.cpp | 19 +- src/zenremotestore/chunking/chunkblock.cpp | 79 +- .../zenremotestore/builds/buildstoragecache.h | 1 + .../zenremotestore/builds/buildstorageoperations.h | 12 +- .../zenremotestore/builds/buildstorageutil.h | 4 + .../include/zenremotestore/chunking/chunkblock.h | 25 +- .../include/zenremotestore/jupiter/jupiterhost.h | 1 + .../include/zenremotestore/operationlogoutput.h | 5 +- .../zenremotestore/partialblockrequestmode.h | 20 + .../projectstore/buildsremoteprojectstore.h | 4 +- .../projectstore/remoteprojectstore.h | 67 +- src/zenremotestore/jupiter/jupiterhost.cpp | 8 +- src/zenremotestore/operationlogoutput.cpp | 2 +- 
src/zenremotestore/partialblockrequestmode.cpp | 27 + .../projectstore/buildsremoteprojectstore.cpp | 122 ++- .../projectstore/fileremoteprojectstore.cpp | 24 +- .../projectstore/jupiterremoteprojectstore.cpp | 19 +- .../projectstore/remoteprojectstore.cpp | 946 ++++++++++++++++----- .../projectstore/zenremoteprojectstore.cpp | 29 +- .../storage/projectstore/httpprojectstore.cpp | 33 +- 27 files changed, 1243 insertions(+), 362 deletions(-) create mode 100644 src/zenremotestore/include/zenremotestore/partialblockrequestmode.h create mode 100644 src/zenremotestore/partialblockrequestmode.cpp (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 849259013..5254ef3cf 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -2842,13 +2842,16 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) TempPath / "storage"); Result.StorageName = ResolveRes.HostName; - StorageDescription = fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'", - ResolveRes.HostName, - (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl), - Result.BuildStorageHttp->GetSessionId(), - m_Namespace, - m_Bucket); - ; + uint64_t HostLatencyNs = ResolveRes.HostLatencySec >= 0 ? uint64_t(ResolveRes.HostLatencySec * 1000000000.0) : 0; + + StorageDescription = fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'. Latency: {}", + ResolveRes.HostName, + (ResolveRes.HostUrl == ResolveRes.HostName) ? 
"" : fmt::format(" {}", ResolveRes.HostUrl), + Result.BuildStorageHttp->GetSessionId(), + m_Namespace, + m_Bucket, + NiceLatencyNs(HostLatencyNs)); + Result.BuildStorageLatencySec = ResolveRes.HostLatencySec; if (!ResolveRes.CacheUrl.empty()) { @@ -2874,12 +2877,17 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) : GetTinyWorkerPool(EWorkloadType::Background)); Result.CacheName = ResolveRes.CacheName; + uint64_t CacheLatencyNs = ResolveRes.CacheLatencySec >= 0 ? uint64_t(ResolveRes.CacheLatencySec * 1000000000.0) : 0; + CacheDescription = - fmt::format("Zen {}{}. SessionId: '{}'", + fmt::format("Zen {}{}. SessionId: '{}'. Latency: {}", ResolveRes.CacheName, (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl), - Result.CacheHttp->GetSessionId()); - ; + Result.CacheHttp->GetSessionId(), + NiceLatencyNs(CacheLatencyNs)); + + Result.CacheLatencySec = ResolveRes.CacheLatencySec; + if (!m_Namespace.empty()) { CacheDescription += fmt::format(". 
Namespace '{}'", m_Namespace); diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index 4de6ad25c..bedab3cfd 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -1469,6 +1469,20 @@ ImportOplogCommand::ImportOplogCommand() "Enables both 'boost-worker-count' and 'boost-worker-memory' - may cause computer to be less responsive", cxxopts::value(m_BoostWorkers), ""); + m_Options.add_option( + "", + "", + "allow-partial-block-requests", + "Allow request for partial chunk blocks.\n" + " false = only full block requests allowed\n" + " mixed = multiple partial block ranges requests per block allowed to zen cache, single partial block range " + "request per block to host\n" + " zencacheonly = multiple partial block ranges requests per block allowed to zen cache, only full block requests " + "allowed to host\n" + " true = multiple partial block ranges requests per block allowed to zen cache and host\n" + "Defaults to 'mixed'.", + cxxopts::value(m_AllowPartialBlockRequests), + ""); m_Options.parse_positional({"project", "oplog", "gcpath"}); m_Options.positional_help("[ []]"); @@ -1513,6 +1527,13 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("'--oplog' is required", m_Options.help()); } + EPartialBlockRequestMode Mode = PartialBlockRequestModeFromString(m_AllowPartialBlockRequests); + if (Mode == EPartialBlockRequestMode::Invalid) + { + throw OptionParseException(fmt::format("'--allow-partial-block-requests' ('{}') is invalid", m_AllowPartialBlockRequests), + m_Options.help()); + } + HttpClient Http(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) @@ -1649,6 +1670,9 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg { Writer.AddBool("boostworkermemory"sv, true); } + + Writer.AddString("partialblockrequestmode", m_AllowPartialBlockRequests); + if 
(!m_FileDirectoryPath.empty()) { Writer.BeginObject("file"sv); @@ -2571,6 +2595,7 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2; ClientSettings.MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u; Storage.BuildStorageHttp = std::make_unique(ResolveRes.HostUrl, ClientSettings); + Storage.BuildStorageLatencySec = ResolveRes.HostLatencySec; BuildStorageCache::Statistics StorageCacheStats; @@ -2589,7 +2614,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a .RetryCount = 0, .MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u}, [&AbortFlag]() { return AbortFlag.load(); }); - Storage.CacheName = ResolveRes.CacheName; + Storage.CacheName = ResolveRes.CacheName; + Storage.CacheLatencySec = ResolveRes.CacheLatencySec; } if (!m_Quiet) diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h index e415b41b7..17fd76e9f 100644 --- a/src/zen/cmds/projectstore_cmd.h +++ b/src/zen/cmds/projectstore_cmd.h @@ -209,6 +209,8 @@ private: bool m_BoostWorkerCount = false; bool m_BoostWorkerMemory = false; bool m_BoostWorkers = false; + + std::string m_AllowPartialBlockRequests = "mixed"; }; class SnapshotOplogCommand : public ProjectStoreCommand diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index d3b59df2b..078e27b34 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -21,6 +21,8 @@ #include "clients/httpclientcommon.h" +#include + #if ZEN_WITH_TESTS # include # include @@ -340,6 +342,42 @@ HttpClient::Authenticate() return m_Inner->Authenticate(); } +LatencyTestResult +MeasureLatency(HttpClient& Client, std::string_view Url) +{ + std::vector MeasurementTimes; + std::string ErrorMessage; + + for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; 
AttemptCount++) + { + HttpClient::Response MeasureResponse = Client.Get(Url); + if (MeasureResponse.IsSuccess()) + { + MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds); + Sleep(5); + } + else + { + ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); + } + } + + if (MeasurementTimes.empty()) + { + return {.Success = false, .FailureReason = ErrorMessage}; + } + + if (MeasurementTimes.size() > 2) + { + std::sort(MeasurementTimes.begin(), MeasurementTimes.end()); + MeasurementTimes.pop_back(); // Remove the worst time + } + + double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size(); + + return {.Success = true, .LatencySeconds = AverageLatency}; +} + ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 9a9b74d72..7a129a98c 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -260,6 +260,15 @@ private: const HttpClientSettings m_ConnectionSettings; }; +struct LatencyTestResult +{ + bool Success = false; + std::string FailureReason; + double LatencySeconds = -1.0; +}; + +LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url); + void httpclient_forcelink(); // internal } // namespace zen diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp index 07fcd62ba..faa85f81b 100644 --- a/src/zenremotestore/builds/buildstoragecache.cpp +++ b/src/zenremotestore/builds/buildstoragecache.cpp @@ -474,7 +474,13 @@ TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const boo HttpClient::Response TestResponse = TestHttpClient.Get("/status/builds"); if (TestResponse.IsSuccess()) { - return {.Success = true}; + LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, 
"/health"); + + if (!LatencyResult.Success) + { + return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + } + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds}; } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; }; diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 4f1b07c37..5219e86d8 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -484,24 +484,6 @@ private: uint64_t FilteredPerSecond = 0; }; -EPartialBlockRequestMode -PartialBlockRequestModeFromString(const std::string_view ModeString) -{ - switch (HashStringAsLowerDjb2(ModeString)) - { - case HashStringDjb2("false"): - return EPartialBlockRequestMode::Off; - case HashStringDjb2("zencacheonly"): - return EPartialBlockRequestMode::ZenCacheOnly; - case HashStringDjb2("mixed"): - return EPartialBlockRequestMode::Mixed; - case HashStringDjb2("true"): - return EPartialBlockRequestMode::All; - default: - return EPartialBlockRequestMode::Invalid; - } -} - std::filesystem::path ZenStateFilePath(const std::filesystem::path& ZenFolderPath) { @@ -903,7 +885,10 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) { ChunkBlockAnalyser BlockAnalyser(m_LogOutput, m_BlockDescriptions, - ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, .IsVerbose = m_Options.IsVerbose}); + ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, + .IsVerbose = m_Options.IsVerbose, + .HostLatencySec = m_Storage.BuildStorageLatencySec, + .HostHighSpeedLatencySec = m_Storage.CacheLatencySec}); std::vector NeededBlocks = BlockAnalyser.GetNeeded( m_RemoteLookup.ChunkHashToChunkIndex, @@ -1034,25 +1019,29 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) { BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } - 
else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::All) - { - BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::On); - } else { BlockPartialDownloadModes.reserve(m_BlockDescriptions.size()); for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) { const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); - if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::All) + { + BlockPartialDownloadModes.push_back(BlockExistInCache + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); + } + else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) { - BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::On - : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + BlockPartialDownloadModes.push_back(BlockExistInCache + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) { - BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::On - : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); + BlockPartialDownloadModes.push_back(BlockExistInCache + ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); } } } diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp index 36b45e800..b249d7d52 100644 --- a/src/zenremotestore/builds/buildstorageutil.cpp +++ b/src/zenremotestore/builds/buildstorageutil.cpp @@ -63,11 +63,13 @@ ResolveBuildStorage(OperationLogOutput& Output, std::string HostUrl; std::string HostName; + double HostLatencySec = -1.0; std::string CacheUrl; std::string CacheName; bool HostAssumeHttp2 = ClientSettings.AssumeHttp2; bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2; + double CacheLatencySec = -1.0; JupiterServerDiscovery DiscoveryResponse; const std::string_view DiscoveryHost = Host.empty() ? OverrideHost : Host; @@ -98,8 +100,9 @@ ResolveBuildStorage(OperationLogOutput& Output, { ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost); } - HostUrl = OverrideHost; - HostName = GetHostNameFromUrl(OverrideHost); + HostUrl = OverrideHost; + HostName = GetHostNameFromUrl(OverrideHost); + HostLatencySec = TestResult.LatencySeconds; } else { @@ -137,6 +140,7 @@ ResolveBuildStorage(OperationLogOutput& Output, HostUrl = ServerEndpoint.BaseUrl; HostAssumeHttp2 = ServerEndpoint.AssumeHttp2; HostName = ServerEndpoint.Name; + HostLatencySec = TestResult.LatencySeconds; break; } else @@ -183,6 +187,7 @@ ResolveBuildStorage(OperationLogOutput& Output, CacheUrl = CacheEndpoint.BaseUrl; CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2; CacheName = CacheEndpoint.Name; + CacheLatencySec = TestResult.LatencySeconds; break; } } @@ -204,6 +209,7 @@ ResolveBuildStorage(OperationLogOutput& Output, CacheUrl = ZenServerLocalHostUrl; CacheAssumeHttp2 = false; CacheName = "localhost"; + CacheLatencySec = TestResult.LatencySeconds; } } }); @@ -219,8 +225,9 @@ ResolveBuildStorage(OperationLogOutput& Output, if (ZenCacheEndpointTestResult 
TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose); TestResult.Success) { - CacheUrl = ZenCacheHost; - CacheName = GetHostNameFromUrl(ZenCacheHost); + CacheUrl = ZenCacheHost; + CacheName = GetHostNameFromUrl(ZenCacheHost); + CacheLatencySec = TestResult.LatencySeconds; } else { @@ -231,10 +238,12 @@ ResolveBuildStorage(OperationLogOutput& Output, return BuildStorageResolveResult{.HostUrl = HostUrl, .HostName = HostName, .HostAssumeHttp2 = HostAssumeHttp2, + .HostLatencySec = HostLatencySec, .CacheUrl = CacheUrl, .CacheName = CacheName, - .CacheAssumeHttp2 = CacheAssumeHttp2}; + .CacheAssumeHttp2 = CacheAssumeHttp2, + .CacheLatencySec = CacheLatencySec}; } std::vector diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index 06cedae3f..d203e0292 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -597,7 +597,7 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span if (MaybeBlockRanges.has_value()) { - const std::vector& BlockRanges = MaybeBlockRanges.value(); + std::vector BlockRanges = MaybeBlockRanges.value(); ZEN_ASSERT(!BlockRanges.empty()); uint64_t RequestedSize = @@ -606,12 +606,54 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span uint64_t(0), [](uint64_t Current, const BlockRangeDescriptor& Range) { return Current + Range.RangeLength; }); - if ((PartialBlockDownloadMode != EPartialBlockDownloadMode::Exact) && ((RequestedSize * 100) / TotalBlockSize) >= 200) + if (PartialBlockDownloadMode != EPartialBlockDownloadMode::Exact && BlockRanges.size() > 1) + { + // TODO: Once we have support in our http client to request multiple ranges in one request this + // logic would need to change as the per-request overhead would go away + + const double LatencySec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed + ? 
m_Options.HostHighSpeedLatencySec + : m_Options.HostLatencySec; + if (LatencySec > 0) + { + const uint64_t BytesPerSec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed + ? m_Options.HostHighSpeedBytesPerSec + : m_Options.HostSpeedBytesPerSec; + + const double ExtraRequestTimeSec = (BlockRanges.size() - 1) * LatencySec; + const uint64_t ExtraRequestTimeBytes = uint64_t(ExtraRequestTimeSec * BytesPerSec); + + const uint64_t FullRangeSize = + BlockRanges.back().RangeStart + BlockRanges.back().RangeLength - BlockRanges.front().RangeStart; + + if (ExtraRequestTimeBytes + RequestedSize >= FullRangeSize) + { + BlockRanges = std::vector{MergeBlockRanges(BlockRanges)}; + + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Merging {} chunks ({}) from block {} ({}) to single request (extra bytes {})", + NeededBlock.ChunkIndexes.size(), + NiceBytes(RequestedSize), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + NiceBytes(BlockRanges.front().RangeLength - RequestedSize)); + } + + RequestedSize = BlockRanges.front().RangeLength; + } + } + } + + if ((PartialBlockDownloadMode != EPartialBlockDownloadMode::Exact) && + ((TotalBlockSize - RequestedSize) < (512u * 1024u))) { if (m_Options.IsVerbose) { ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Requesting {} chunks ({}) from block {} ({}) using full block request (extra bytes {})", + "Requesting {} chunks ({}) from block {} ({}) using full block request due to small " + "total slack (extra bytes {})", NeededBlock.ChunkIndexes.size(), NiceBytes(RequestedSize), BlockDescription.BlockHash, @@ -624,19 +666,16 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span { Result.BlockRanges.insert(Result.BlockRanges.end(), BlockRanges.begin(), BlockRanges.end()); - if (RequestedSize > TotalWantedChunksSize) + if (m_Options.IsVerbose) { - if (m_Options.IsVerbose) - { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes 
{})", - NeededBlock.ChunkIndexes.size(), - NiceBytes(RequestedSize), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - BlockRanges.size(), - NiceBytes(RequestedSize - TotalWantedChunksSize)); - } + ZEN_OPERATION_LOG_INFO(m_LogOutput, + "Requesting {} chunks ({}) from block {} ({}) using {} requests (extra bytes {})", + NeededBlock.ChunkIndexes.size(), + NiceBytes(RequestedSize), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + BlockRanges.size(), + NiceBytes(RequestedSize - TotalWantedChunksSize)); } } } @@ -786,7 +825,7 @@ ChunkBlockAnalyser::CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std: }; uint64_t -ChunkBlockAnalyser::CalculateNextGap(std::span BlockRanges) +ChunkBlockAnalyser::CalculateNextGap(const uint64_t AlwaysAcceptableGap, std::span BlockRanges) { ZEN_ASSERT(BlockRanges.size() > 1); uint64_t AcceptableGap = (uint64_t)-1; @@ -798,7 +837,7 @@ ChunkBlockAnalyser::CalculateNextGap(std::span Block const uint64_t Gap = NextRange.RangeStart - (Range.RangeStart + Range.RangeLength); AcceptableGap = Min(Gap, AcceptableGap); } - AcceptableGap = RoundUp(AcceptableGap, 16u * 1024u); + AcceptableGap = RoundUp(AcceptableGap, AlwaysAcceptableGap); return AcceptableGap; }; @@ -949,10 +988,12 @@ ChunkBlockAnalyser::CalculateBlockRanges(uint32_t BlockIndex, return MakeOptionalBlockRangeVector(TotalBlockSize, MergedRange); } - std::vector CollapsedBlockRanges = CollapseBlockRanges(16u * 1024u, BlockRanges); + const uint64_t AlwaysAcceptableGap = 4u * 1024u; + + std::vector CollapsedBlockRanges = CollapseBlockRanges(AlwaysAcceptableGap, BlockRanges); while (GetBlockRangeLimitForRange(ForceMergeLimits, TotalBlockSize, CollapsedBlockRanges)) { - CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(CollapsedBlockRanges), CollapsedBlockRanges); + CollapsedBlockRanges = CollapseBlockRanges(CalculateNextGap(AlwaysAcceptableGap, CollapsedBlockRanges), CollapsedBlockRanges); } const std::uint64_t WantedCollapsedSize = diff --git 
a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h index bb5b1c5f4..f25ce5b5e 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h @@ -65,6 +65,7 @@ struct ZenCacheEndpointTestResult { bool Success = false; std::string FailureReason; + double LatencySeconds = -1.0; }; ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 6800444e0..31733569e 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -109,17 +110,6 @@ struct RebuildFolderStateStatistics uint64_t FinalizeTreeElapsedWallTimeUs = 0; }; -enum EPartialBlockRequestMode -{ - Off, - ZenCacheOnly, - Mixed, - All, - Invalid -}; - -EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString); - std::filesystem::path ZenStateFilePath(const std::filesystem::path& ZenFolderPath); std::filesystem::path ZenTempFolderPath(const std::filesystem::path& ZenFolderPath); diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h index ab3037c89..4b85d8f1e 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h @@ -17,10 +17,12 @@ struct BuildStorageResolveResult std::string HostUrl; std::string HostName; bool HostAssumeHttp2 = false; + double HostLatencySec = -1.0; std::string CacheUrl; std::string 
CacheName; bool CacheAssumeHttp2 = false; + double CacheLatencySec = -1.0; }; enum class ZenCacheResolveMode @@ -54,9 +56,11 @@ struct StorageInstance std::unique_ptr BuildStorageHttp; std::unique_ptr BuildStorage; std::string StorageName; + double BuildStorageLatencySec = -1.0; std::unique_ptr CacheHttp; std::unique_ptr BuildCacheStorage; std::string CacheName; + double CacheLatencySec = -1.0; }; } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index 57710fcf5..5a17ef79c 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -82,8 +82,12 @@ class ChunkBlockAnalyser public: struct Options { - bool IsQuiet = false; - bool IsVerbose = false; + bool IsQuiet = false; + bool IsVerbose = false; + double HostLatencySec = -1.0; + double HostHighSpeedLatencySec = -1.0; + uint64_t HostSpeedBytesPerSec = (1u * 1024u * 1024u * 1024u) / 8u; // 1GBit + uint64_t HostHighSpeedBytesPerSec = (2u * 1024u * 1024u * 1024u) / 8u; // 2GBit }; ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span BlockDescriptions, const Options& Options); @@ -110,7 +114,8 @@ public: { Off, SingleRange, - On, + MultiRange, + MultiRangeHighSpeed, Exact }; @@ -130,14 +135,14 @@ private: uint16_t MaxRangeCount; }; - static constexpr uint16_t FullBlockRangePercentLimit = 95; + static constexpr uint16_t FullBlockRangePercentLimit = 98; static constexpr BlockRangeLimit ForceMergeLimits[] = {{.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1}, - {.SizePercent = 90, .MaxRangeCount = 2}, - {.SizePercent = 85, .MaxRangeCount = 8}, - {.SizePercent = 80, .MaxRangeCount = 16}, - {.SizePercent = 75, .MaxRangeCount = 32}, - {.SizePercent = 70, .MaxRangeCount = 48}, + {.SizePercent = 90, .MaxRangeCount = 4}, + {.SizePercent = 85, .MaxRangeCount = 16}, + {.SizePercent = 80, .MaxRangeCount = 32}, + 
{.SizePercent = 75, .MaxRangeCount = 48}, + {.SizePercent = 70, .MaxRangeCount = 64}, {.SizePercent = 4, .MaxRangeCount = 82}, {.SizePercent = 0, .MaxRangeCount = 96}}; @@ -149,7 +154,7 @@ private: std::span Ranges); std::vector CollapseBlockRanges(const uint64_t AlwaysAcceptableGap, std::span BlockRanges); - uint64_t CalculateNextGap(std::span BlockRanges); + uint64_t CalculateNextGap(const uint64_t AlwaysAcceptableGap, std::span BlockRanges); std::optional> CalculateBlockRanges(uint32_t BlockIndex, const ChunkBlockDescription& BlockDescription, std::span BlockChunkIndexNeeded, diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h index 432496bc1..7bbf40dfa 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h +++ b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h @@ -28,6 +28,7 @@ struct JupiterEndpointTestResult { bool Success = false; std::string FailureReason; + double LatencySeconds = -1.0; }; JupiterEndpointTestResult TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h index 9693e69cf..6f10ab156 100644 --- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h +++ b/src/zenremotestore/include/zenremotestore/operationlogoutput.h @@ -3,6 +3,7 @@ #pragma once #include +#include namespace zen { @@ -57,9 +58,7 @@ public: virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) = 0; }; -struct LoggerRef; - -OperationLogOutput* CreateStandardLogOutput(LoggerRef& Log); +OperationLogOutput* CreateStandardLogOutput(LoggerRef Log); #define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) 
\ do \ diff --git a/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h new file mode 100644 index 000000000..54adea2b2 --- /dev/null +++ b/src/zenremotestore/include/zenremotestore/partialblockrequestmode.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen { + +enum EPartialBlockRequestMode +{ + Off, + ZenCacheOnly, + Mixed, + All, + Invalid +}; + +EPartialBlockRequestMode PartialBlockRequestModeFromString(const std::string_view ModeString); + +} // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h index e8b7c15c0..66dfcc62d 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h @@ -34,6 +34,8 @@ std::shared_ptr CreateJupiterBuildsRemoteStore(LoggerRef bool Quiet, bool Unattended, bool Hidden, - WorkerThreadPool& CacheBackgroundWorkerPool); + WorkerThreadPool& CacheBackgroundWorkerPool, + double& OutHostLatencySec, + double& OutCacheLatencySec); } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 008f94351..152c02ee2 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -6,6 +6,7 @@ #include #include +#include #include @@ -73,6 +74,16 @@ public: std::vector Blocks; }; + struct GetBlockDescriptionsResult : public Result + { + std::vector Blocks; + }; + + struct AttachmentExistsInCacheResult : public Result + { + std::vector HasBody; + }; + struct RemoteStoreInfo { bool CreateBlocks; @@ -111,10 
+122,20 @@ public: virtual FinalizeResult FinalizeContainer(const IoHash& RawHash) = 0; virtual SaveAttachmentsResult SaveAttachments(const std::vector& Payloads) = 0; - virtual LoadContainerResult LoadContainer() = 0; - virtual GetKnownBlocksResult GetKnownBlocks() = 0; - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) = 0; + virtual LoadContainerResult LoadContainer() = 0; + virtual GetKnownBlocksResult GetKnownBlocks() = 0; + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) = 0; + virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) = 0; + + struct AttachmentRange + { + uint64_t Offset = 0; + uint64_t Bytes = (uint64_t)-1; + + inline operator bool() const { return Offset != 0 || Bytes != (uint64_t)-1; } + }; + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) = 0; virtual void Flush() = 0; }; @@ -153,14 +174,15 @@ RemoteProjectStore::LoadContainerResult BuildContainer( class JobContext; -RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog, - const CbObject& ContainerObject, - const std::function RawHashes)>& OnReferencedAttachments, - const std::function& HasAttachment, - const std::function&& Chunks)>& OnNeedBlock, - const std::function& OnNeedAttachment, - const std::function& OnChunkedAttachment, - JobContext* OptionalContext); +RemoteProjectStore::Result SaveOplogContainer( + ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function RawHashes)>& OnReferencedAttachments, + const std::function& HasAttachment, + const std::function&& NeededChunkIndexes)>& OnNeedBlock, + const std::function& OnNeedAttachment, + const std::function& OnChunkedAttachment, + JobContext* OptionalContext); RemoteProjectStore::Result 
SaveOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, @@ -177,15 +199,18 @@ RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, bool IgnoreMissingAttachments, JobContext* OptionalContext); -RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - JobContext* OptionalContext); +RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + WorkerThreadPool& NetworkWorkerPool, + WorkerThreadPool& WorkerPool, + bool ForceDownload, + bool IgnoreMissingAttachments, + bool CleanOplog, + EPartialBlockRequestMode PartialBlockRequestMode, + double HostLatencySec, + double CacheLatencySec, + JobContext* OptionalContext); std::vector GetBlockHashesFromOplog(CbObjectView ContainerObject); std::vector GetBlocksFromOplog(CbObjectView ContainerObject, std::span IncludeBlockHashes); diff --git a/src/zenremotestore/jupiter/jupiterhost.cpp b/src/zenremotestore/jupiter/jupiterhost.cpp index 7706f00c2..2583cfc84 100644 --- a/src/zenremotestore/jupiter/jupiterhost.cpp +++ b/src/zenremotestore/jupiter/jupiterhost.cpp @@ -59,7 +59,13 @@ TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpClient::Response TestResponse = TestHttpClient.Get("/health/live"); if (TestResponse.IsSuccess()) { - return {.Success = true}; + LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health/ready"); + + if (!LatencyResult.Success) + { + return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + } + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds}; } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; } diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp index 
0837ed716..7ed93c947 100644 --- a/src/zenremotestore/operationlogoutput.cpp +++ b/src/zenremotestore/operationlogoutput.cpp @@ -95,7 +95,7 @@ StandardLogOutputProgressBar::Finish() } OperationLogOutput* -CreateStandardLogOutput(LoggerRef& Log) +CreateStandardLogOutput(LoggerRef Log) { return new StandardLogOutput(Log); } diff --git a/src/zenremotestore/partialblockrequestmode.cpp b/src/zenremotestore/partialblockrequestmode.cpp new file mode 100644 index 000000000..b3edf515b --- /dev/null +++ b/src/zenremotestore/partialblockrequestmode.cpp @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include + +namespace zen { + +EPartialBlockRequestMode +PartialBlockRequestModeFromString(const std::string_view ModeString) +{ + switch (HashStringAsLowerDjb2(ModeString)) + { + case HashStringDjb2("false"): + return EPartialBlockRequestMode::Off; + case HashStringDjb2("zencacheonly"): + return EPartialBlockRequestMode::ZenCacheOnly; + case HashStringDjb2("mixed"): + return EPartialBlockRequestMode::Mixed; + case HashStringDjb2("true"): + return EPartialBlockRequestMode::All; + default: + return EPartialBlockRequestMode::Invalid; + } +} + +} // namespace zen diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index a8e883dde..c42373e4d 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -441,7 +441,7 @@ public: catch (const HttpClientError& Ex) { Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. 
Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -451,7 +451,7 @@ public: catch (const std::exception& Ex) { Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -462,7 +462,94 @@ public: return Result; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + { + std::unique_ptr Output(CreateStandardLogOutput(Log())); + + ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); + + GetBlockDescriptionsResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); + + try + { + Result.Blocks = zen::GetBlockDescriptions(*Output, + *m_BuildStorage, + m_BuildCacheStorage.get(), + m_BuildId, + m_OplogBuildPartId, + BlockHashes, + /*AttemptFallback*/ false, + /*IsQuiet*/ false, + /*IsVerbose)*/ false); + } + catch (const HttpClientError& Ex) + { + Result.ErrorCode = MakeErrorCode(Ex); + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. 
Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + return Result; + } + + virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override + { + AttachmentExistsInCacheResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); + try + { + const std::vector CacheExistsResult = + m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes); + + if (CacheExistsResult.size() == RawHashes.size()) + { + Result.HasBody.reserve(CacheExistsResult.size()); + for (size_t BlobIndex = 0; BlobIndex < CacheExistsResult.size(); BlobIndex++) + { + Result.HasBody.push_back(CacheExistsResult[BlobIndex].HasBody); + } + } + } + catch (const HttpClientError& Ex) + { + Result.ErrorCode = MakeErrorCode(Ex); + Result.Reason = fmt::format("Remote cache: Failed finding known blobs for {}/{}/{}/{}. Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Remote cache: Failed finding known blobs for {}/{}/{}/{}. 
Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + Ex.what()); + } + return Result; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override { ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); @@ -474,7 +561,7 @@ public: { if (m_BuildCacheStorage) { - IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash); + IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash, Range.Offset, Range.Bytes); if (CachedBlob) { Result.Bytes = std::move(CachedBlob); @@ -482,20 +569,23 @@ public: } if (!Result.Bytes) { - Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); + Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash, Range.Offset, Range.Bytes); if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - RawHash, - Result.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(Result.Bytes))); + if (!Range) + { + m_BuildCacheStorage->PutBuildBlob(m_BuildId, + RawHash, + Result.Bytes.GetContentType(), + CompositeBuffer(SharedBuffer(Result.Bytes))); + } } } } catch (const HttpClientError& Ex) { Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -505,7 +595,7 @@ public: catch (const std::exception& Ex) { Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Failed listing know blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. 
Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, @@ -558,7 +648,7 @@ public: for (const IoHash& Hash : AttachmentsLeftToFind) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; @@ -623,7 +713,9 @@ CreateJupiterBuildsRemoteStore(LoggerRef InLog, bool Quiet, bool Unattended, bool Hidden, - WorkerThreadPool& CacheBackgroundWorkerPool) + WorkerThreadPool& CacheBackgroundWorkerPool, + double& OutHostLatencySec, + double& OutCacheLatencySec) { std::string Host = Options.Host; if (!Host.empty() && Host.find("://"sv) == std::string::npos) @@ -727,6 +819,10 @@ CreateJupiterBuildsRemoteStore(LoggerRef InLog, Options.ForceDisableBlocks, Options.ForceDisableTempBlocks, Options.PopulateCache); + + OutHostLatencySec = ResolveRes.HostLatencySec; + OutCacheLatencySec = ResolveRes.CacheLatencySec; + return RemoteStore; } diff --git a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp index 3a67d3842..ec7fb7bbc 100644 --- a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp @@ -217,7 +217,18 @@ public: return Result; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + { + ZEN_UNUSED(BlockHashes); + return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; + } + + virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override + { + return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override { Stopwatch Timer; LoadAttachmentResult Result; @@ -232,7 
+243,14 @@ public: { BasicFile ChunkFile; ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); - Result.Bytes = ChunkFile.ReadAll(); + if (Range) + { + Result.Bytes = ChunkFile.ReadRange(Range.Offset, Range.Bytes); + } + else + { + Result.Bytes = ChunkFile.ReadAll(); + } } AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; @@ -245,7 +263,7 @@ public: LoadAttachmentsResult Result; for (const IoHash& Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); if (ChunkResult.ErrorCode) { ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; diff --git a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp index 462de2988..f8179831c 100644 --- a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp @@ -212,7 +212,18 @@ public: return Result; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + { + ZEN_UNUSED(BlockHashes); + return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; + } + + virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override + { + return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override { JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); @@ -227,6 +238,10 @@ public: RawHash, Result.Reason); } + if (!Result.ErrorCode && Range) + { + Result.Bytes = IoBuffer(Result.Bytes, 
Range.Offset, Range.Bytes); + } return Result; } @@ -235,7 +250,7 @@ public: LoadAttachmentsResult Result; for (const IoHash& Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 8be8eb0df..2a9da6f58 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -229,29 +230,60 @@ namespace remotestore_impl { struct DownloadInfo { - uint64_t OplogSizeBytes = 0; - std::atomic AttachmentsDownloaded = 0; - std::atomic AttachmentBlocksDownloaded = 0; - std::atomic AttachmentBytesDownloaded = 0; - std::atomic AttachmentBlockBytesDownloaded = 0; - std::atomic AttachmentsStored = 0; - std::atomic AttachmentBytesStored = 0; - std::atomic_size_t MissingAttachmentCount = 0; + uint64_t OplogSizeBytes = 0; + std::atomic AttachmentsDownloaded = 0; + std::atomic AttachmentBlocksDownloaded = 0; + std::atomic AttachmentBlocksRangesDownloaded = 0; + std::atomic AttachmentBytesDownloaded = 0; + std::atomic AttachmentBlockBytesDownloaded = 0; + std::atomic AttachmentBlockRangeBytesDownloaded = 0; + std::atomic AttachmentsStored = 0; + std::atomic AttachmentBytesStored = 0; + std::atomic_size_t MissingAttachmentCount = 0; }; - void DownloadAndSaveBlockChunks(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - Latch& AttachmentsDownloadLatch, - Latch& AttachmentsWriteLatch, - AsyncRemoteResult& RemoteResult, - DownloadInfo& Info, - Stopwatch& LoadAttachmentsTimer, - std::atomic_uint64_t& DownloadStartMS, - const 
std::vector& Chunks) + class JobContextLogOutput : public OperationLogOutput + { + public: + JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} + virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override + { + ZEN_UNUSED(LogLevel); + if (m_OptionalContext) + { + fmt::basic_memory_buffer MessageBuffer; + fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); + remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); + } + } + + virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); } + virtual uint32_t GetProgressUpdateDelayMS() override { return 0; } + virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override + { + ZEN_UNUSED(InSubTask); + return nullptr; + } + + private: + JobContext* m_OptionalContext; + }; + + void DownloadAndSaveBlockChunks(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + bool IgnoreMissingAttachments, + JobContext* OptionalContext, + WorkerThreadPool& NetworkWorkerPool, + WorkerThreadPool& WorkerPool, + Latch& AttachmentsDownloadLatch, + Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + ThinChunkBlockDescription&& ThinBlockDescription, + std::vector&& NeededChunkIndexes) { AttachmentsDownloadLatch.AddCount(1); NetworkWorkerPool.ScheduleWork( @@ -261,7 +293,8 @@ namespace remotestore_impl { &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &RemoteResult, - Chunks = Chunks, + ThinBlockDescription = std::move(ThinBlockDescription), + NeededChunkIndexes = std::move(NeededChunkIndexes), &Info, &LoadAttachmentsTimer, &DownloadStartMS, @@ -276,6 +309,13 @@ namespace remotestore_impl { } try { + std::vector Chunks; + 
Chunks.reserve(NeededChunkIndexes.size()); + for (uint32_t ChunkIndex : NeededChunkIndexes) + { + Chunks.push_back(ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } + uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks); @@ -293,7 +333,12 @@ namespace remotestore_impl { } return; } - Info.AttachmentsDownloaded.fetch_add(Chunks.size()); + Info.AttachmentsDownloaded.fetch_add(Result.Chunks.size()); + for (const auto& It : Result.Chunks) + { + uint64_t ChunkSize = It.second.GetCompressedSize(); + Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); + } ZEN_INFO("Loaded {} bulk attachments in {}", Chunks.size(), NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000))); @@ -320,8 +365,6 @@ namespace remotestore_impl { for (const auto& It : Chunks) { - uint64_t ChunkSize = It.second.GetCompressedSize(); - Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); WriteAttachmentBuffers.push_back(It.second.GetCompressed().Flatten().AsIoBuffer()); WriteRawHashes.push_back(It.first); } @@ -350,28 +393,29 @@ namespace remotestore_impl { catch (const std::exception& Ex) { RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Failed to bulk load {} attachments", Chunks.size()), + fmt::format("Failed to bulk load {} attachments", NeededChunkIndexes.size()), Ex.what()); } }, WorkerThreadPool::EMode::EnableBacklog); }; - void DownloadAndSaveBlock(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - Latch& AttachmentsDownloadLatch, - Latch& AttachmentsWriteLatch, - AsyncRemoteResult& RemoteResult, - DownloadInfo& Info, - Stopwatch& LoadAttachmentsTimer, - std::atomic_uint64_t& DownloadStartMS, - const IoHash& BlockHash, - const std::vector& Chunks, - 
uint32_t RetriesLeft) + void DownloadAndSaveBlock(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + bool IgnoreMissingAttachments, + JobContext* OptionalContext, + WorkerThreadPool& NetworkWorkerPool, + WorkerThreadPool& WorkerPool, + Latch& AttachmentsDownloadLatch, + Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + const IoHash& BlockHash, + const tsl::robin_map& AllNeededPartialChunkHashesLookup, + std::span> ChunkDownloadedFlags, + uint32_t RetriesLeft) { AttachmentsDownloadLatch.AddCount(1); NetworkWorkerPool.ScheduleWork( @@ -381,7 +425,6 @@ namespace remotestore_impl { &RemoteStore, &NetworkWorkerPool, &WorkerPool, - BlockHash, &RemoteResult, &Info, &LoadAttachmentsTimer, @@ -389,7 +432,9 @@ namespace remotestore_impl { IgnoreMissingAttachments, OptionalContext, RetriesLeft, - Chunks = std::vector(Chunks)]() { + BlockHash = IoHash(BlockHash), + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags]() { ZEN_TRACE_CPU("DownloadBlock"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -401,7 +446,7 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); + RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash, {}); if (BlockResult.ErrorCode) { ReportMessage(OptionalContext, @@ -422,10 +467,10 @@ namespace remotestore_impl { } uint64_t BlockSize = BlockResult.Bytes.GetSize(); Info.AttachmentBlocksDownloaded.fetch_add(1); - ZEN_INFO("Loaded block attachment '{}' in {} ({})", - BlockHash, - NiceTimeSpanMs(static_cast(BlockResult.ElapsedSeconds * 1000)), - NiceBytes(BlockSize)); + ZEN_DEBUG("Loaded block attachment '{}' in {} ({})", + BlockHash, + 
NiceTimeSpanMs(static_cast(BlockResult.ElapsedSeconds * 1000)), + NiceBytes(BlockSize)); Info.AttachmentBlockBytesDownloaded.fetch_add(BlockSize); AttachmentsWriteLatch.AddCount(1); @@ -436,7 +481,6 @@ namespace remotestore_impl { &RemoteStore, &NetworkWorkerPool, &WorkerPool, - BlockHash, &RemoteResult, &Info, &LoadAttachmentsTimer, @@ -444,8 +488,10 @@ namespace remotestore_impl { IgnoreMissingAttachments, OptionalContext, RetriesLeft, - Chunks = std::move(Chunks), - Bytes = std::move(BlockResult.Bytes)]() { + BlockHash = IoHash(BlockHash), + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + Bytes = std::move(BlockResult.Bytes)]() { auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); if (RemoteResult.IsError()) { @@ -454,9 +500,6 @@ namespace remotestore_impl { try { ZEN_ASSERT(Bytes.Size() > 0); - std::unordered_set WantedChunks; - WantedChunks.reserve(Chunks.size()); - WantedChunks.insert(Chunks.begin(), Chunks.end()); std::vector WriteAttachmentBuffers; std::vector WriteRawHashes; @@ -485,7 +528,8 @@ namespace remotestore_impl { LoadAttachmentsTimer, DownloadStartMS, BlockHash, - std::move(Chunks), + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, RetriesLeft - 1); } ReportMessage( @@ -519,7 +563,8 @@ namespace remotestore_impl { LoadAttachmentsTimer, DownloadStartMS, BlockHash, - std::move(Chunks), + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, RetriesLeft - 1); } ReportMessage(OptionalContext, @@ -546,28 +591,36 @@ namespace remotestore_impl { uint64_t BlockSize = BlockPayload.GetSize(); uint64_t BlockHeaderSize = 0; - bool StoreChunksOK = IterateChunkBlock( - BlockPayload.Flatten(), - [&WantedChunks, &WriteAttachmentBuffers, &WriteRawHashes, &Info, &PotentialSize]( - CompressedBuffer&& Chunk, - const IoHash& AttachmentRawHash) { - if (WantedChunks.contains(AttachmentRawHash)) - { - WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); - IoHash RawHash; - 
uint64_t RawSize; - ZEN_ASSERT( - CompressedBuffer::ValidateCompressedHeader(WriteAttachmentBuffers.back(), - RawHash, - RawSize, - /*OutOptionalTotalCompressedSize*/ nullptr)); - ZEN_ASSERT(RawHash == AttachmentRawHash); - WriteRawHashes.emplace_back(AttachmentRawHash); - WantedChunks.erase(AttachmentRawHash); - PotentialSize += WriteAttachmentBuffers.back().GetSize(); - } - }, - BlockHeaderSize); + + bool StoreChunksOK = IterateChunkBlock( + BlockPayload.Flatten(), + [&AllNeededPartialChunkHashesLookup, + &ChunkDownloadedFlags, + &WriteAttachmentBuffers, + &WriteRawHashes, + &Info, + &PotentialSize](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(AttachmentRawHash); + if (ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) + { + WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); + IoHash RawHash; + uint64_t RawSize; + ZEN_ASSERT( + CompressedBuffer::ValidateCompressedHeader(WriteAttachmentBuffers.back(), + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)); + ZEN_ASSERT(RawHash == AttachmentRawHash); + WriteRawHashes.emplace_back(AttachmentRawHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + } + }, + BlockHeaderSize); if (!StoreChunksOK) { @@ -582,8 +635,6 @@ namespace remotestore_impl { return; } - ZEN_ASSERT(WantedChunks.empty()); - if (!WriteAttachmentBuffers.empty()) { auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); @@ -625,6 +676,293 @@ namespace remotestore_impl { WorkerThreadPool::EMode::EnableBacklog); }; + void DownloadAndSavePartialBlock(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + bool IgnoreMissingAttachments, + JobContext* OptionalContext, + WorkerThreadPool& NetworkWorkerPool, + WorkerThreadPool& WorkerPool, + Latch& AttachmentsDownloadLatch, 
+ Latch& AttachmentsWriteLatch, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + Stopwatch& LoadAttachmentsTimer, + std::atomic_uint64_t& DownloadStartMS, + const ChunkBlockDescription& BlockDescription, + std::span BlockRangeDescriptors, + size_t BlockRangeIndexStart, + size_t BlockRangeCount, + const tsl::robin_map& AllNeededPartialChunkHashesLookup, + std::span> ChunkDownloadedFlags, + uint32_t RetriesLeft) + { + AttachmentsDownloadLatch.AddCount(1); + NetworkWorkerPool.ScheduleWork( + [&AttachmentsDownloadLatch, + &AttachmentsWriteLatch, + &ChunkStore, + &RemoteStore, + &NetworkWorkerPool, + &WorkerPool, + &RemoteResult, + &Info, + &LoadAttachmentsTimer, + &DownloadStartMS, + IgnoreMissingAttachments, + OptionalContext, + RetriesLeft, + BlockDescription, + BlockRangeDescriptors, + BlockRangeIndexStart, + BlockRangeCount, + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags]() { + ZEN_TRACE_CPU("DownloadBlockRanges"); + + auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); + try + { + uint64_t Unset = (std::uint64_t)-1; + DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); + + double DownloadElapsedSeconds = 0; + uint64_t DownloadedBytes = 0; + + for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; + BlockRangeIndex++) + { + if (RemoteResult.IsError()) + { + return; + } + + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRangeDescriptors[BlockRangeIndex]; + + RemoteProjectStore::LoadAttachmentResult BlockResult = + RemoteStore.LoadAttachment(BlockDescription.BlockHash, + {.Offset = BlockRange.RangeStart, .Bytes = BlockRange.RangeLength}); + if (BlockResult.ErrorCode) + { + ReportMessage(OptionalContext, + fmt::format("Failed to download block attachment '{}' range {},{} ({}): {}", + BlockDescription.BlockHash, + BlockRange.RangeStart, + BlockRange.RangeLength, + BlockResult.ErrorCode, + 
BlockResult.Reason)); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + } + return; + } + if (RemoteResult.IsError()) + { + return; + } + uint64_t BlockPartSize = BlockResult.Bytes.GetSize(); + if (BlockPartSize != BlockRange.RangeLength) + { + std::string ErrorString = + fmt::format("Failed to download block attachment '{}' range {},{}, got {} bytes ({}): {}", + BlockDescription.BlockHash, + BlockRange.RangeStart, + BlockRange.RangeLength, + BlockPartSize, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + + ReportMessage(OptionalContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), + "Mismatching block part range received", + ErrorString); + } + return; + } + Info.AttachmentBlocksRangesDownloaded.fetch_add(1); + + DownloadElapsedSeconds += BlockResult.ElapsedSeconds; + DownloadedBytes += BlockPartSize; + + Info.AttachmentBlockRangeBytesDownloaded.fetch_add(BlockPartSize); + + AttachmentsWriteLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&AttachmentsDownloadLatch, + &AttachmentsWriteLatch, + &ChunkStore, + &RemoteStore, + &NetworkWorkerPool, + &WorkerPool, + &RemoteResult, + &Info, + &LoadAttachmentsTimer, + &DownloadStartMS, + IgnoreMissingAttachments, + OptionalContext, + RetriesLeft, + BlockDescription, + BlockRange, + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + BlockPayload = std::move(BlockResult.Bytes)]() { + auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + try + { + ZEN_ASSERT(BlockPayload.Size() > 0); + std::vector WriteAttachmentBuffers; + std::vector WriteRawHashes; + + uint64_t PotentialSize = 0; + uint64_t UsedSize = 0; + uint64_t BlockPartSize = BlockPayload.GetSize(); + + uint32_t OffsetInBlock = 0; + 
for (uint32_t ChunkBlockIndex = BlockRange.ChunkBlockIndexStart; + ChunkBlockIndex < BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount; + ChunkBlockIndex++) + { + const uint32_t ChunkCompressedSize = BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + + if (auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(ChunkHash); + ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) + { + IoHash VerifyChunkHash; + uint64_t VerifyChunkSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed( + SharedBuffer(IoBuffer(BlockPayload, OffsetInBlock, ChunkCompressedSize)), + VerifyChunkHash, + VerifyChunkSize); + if (!CompressedChunk) + { + std::string ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' is not a valid compressed buffer", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash); + ReportMessage(OptionalContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), + "Malformed chunk block", + ErrorString); + } + continue; + } + if (VerifyChunkHash != ChunkHash) + { + std::string ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' has mismatching hash, expected {}, got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + ChunkHash, + VerifyChunkHash); + ReportMessage(OptionalContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), + "Malformed chunk block", + ErrorString); + } + continue; + } + if (VerifyChunkSize != BlockDescription.ChunkRawLengths[ChunkBlockIndex]) + { + std::string ErrorString = fmt::format( + "Chunk at {},{} in 
block attachment '{}' has mismatching raw size, expected {}, " + "got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + BlockDescription.ChunkRawLengths[ChunkBlockIndex], + VerifyChunkSize); + ReportMessage(OptionalContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), + "Malformed chunk block", + ErrorString); + } + continue; + } + + WriteAttachmentBuffers.emplace_back(CompressedChunk.GetCompressed().Flatten().AsIoBuffer()); + WriteRawHashes.emplace_back(ChunkHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + } + OffsetInBlock += ChunkCompressedSize; + } + + if (!WriteAttachmentBuffers.empty()) + { + auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + for (size_t Index = 0; Index < Results.size(); Index++) + { + const auto& Result = Results[Index]; + if (Result.New) + { + Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); + Info.AttachmentsStored.fetch_add(1); + UsedSize += WriteAttachmentBuffers[Index].GetSize(); + } + } + ZEN_DEBUG("Used {} (matching {}) out of {} for block {} range {}, {} ({} %) (use of matching {}%)", + NiceBytes(UsedSize), + NiceBytes(PotentialSize), + NiceBytes(BlockPartSize), + BlockDescription.BlockHash, + BlockRange.RangeStart, + BlockRange.RangeLength, + (100 * UsedSize) / BlockPartSize, + PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); + } + } + catch (const std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), + fmt::format("Failed save block attachment {} range {}, {}", + BlockDescription.BlockHash, + BlockRange.RangeStart, + BlockRange.RangeLength), + Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})", + BlockRangeCount, + BlockDescription.BlockHash, + NiceTimeSpanMs(static_cast(DownloadElapsedSeconds * 1000)), + NiceBytes(DownloadedBytes)); + } + catch (const std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), + fmt::format("Failed to download block attachment {} ranges", BlockDescription.BlockHash), + Ex.what()); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + }; + void DownloadAndSaveAttachment(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, bool IgnoreMissingAttachments, @@ -664,7 +1002,7 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); + RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash, {}); if (AttachmentResult.ErrorCode) { ReportMessage(OptionalContext, @@ -680,10 +1018,10 @@ namespace remotestore_impl { return; } uint64_t AttachmentSize = AttachmentResult.Bytes.GetSize(); - ZEN_INFO("Loaded large attachment '{}' in {} ({})", - RawHash, - NiceTimeSpanMs(static_cast(AttachmentResult.ElapsedSeconds * 1000)), - NiceBytes(AttachmentSize)); + ZEN_DEBUG("Loaded large attachment '{}' in {} ({})", + RawHash, + NiceTimeSpanMs(static_cast(AttachmentResult.ElapsedSeconds * 1000)), + NiceBytes(AttachmentSize)); Info.AttachmentsDownloaded.fetch_add(1); if (RemoteResult.IsError()) { @@ -1224,35 +1562,7 @@ 
BuildContainer(CidStore& ChunkStore, { using namespace std::literals; - class JobContextLogOutput : public OperationLogOutput - { - public: - JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override - { - ZEN_UNUSED(LogLevel); - if (m_OptionalContext) - { - fmt::basic_memory_buffer MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); - remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); - } - } - - virtual void SetLogOperationName(std::string_view Name) override { ZEN_UNUSED(Name); } - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { ZEN_UNUSED(StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() override { return 0; } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override - { - ZEN_UNUSED(InSubTask); - return nullptr; - } - - private: - JobContext* m_OptionalContext; - }; - - std::unique_ptr LogOutput(std::make_unique(OptionalContext)); + std::unique_ptr LogOutput(std::make_unique(OptionalContext)); size_t OpCount = 0; @@ -2768,14 +3078,15 @@ SaveOplog(CidStore& ChunkStore, }; RemoteProjectStore::Result -ParseOplogContainer(const CbObject& ContainerObject, - const std::function RawHashes)>& OnReferencedAttachments, - const std::function& HasAttachment, - const std::function&& Chunks)>& OnNeedBlock, - const std::function& OnNeedAttachment, - const std::function& OnChunkedAttachment, - CbObject& OutOplogSection, - JobContext* OptionalContext) +ParseOplogContainer( + const CbObject& ContainerObject, + const std::function RawHashes)>& OnReferencedAttachments, + const std::function& HasAttachment, + const std::function&& NeededChunkIndexes)>& OnNeedBlock, + const std::function& OnNeedAttachment, + const std::function& OnChunkedAttachment, + CbObject& OutOplogSection, + 
JobContext* OptionalContext) { using namespace std::literals; @@ -2801,12 +3112,12 @@ ParseOplogContainer(const CbObject& ContainerObject, "Section has unexpected data type", "Failed to save oplog container"}; } - std::unordered_set OpsAttachments; + std::unordered_set NeededAttachments; { CbArrayView OpsArray = OutOplogSection["ops"sv].AsArrayView(); for (CbFieldView OpEntry : OpsArray) { - OpEntry.IterateAttachments([&](CbFieldView FieldView) { OpsAttachments.insert(FieldView.AsAttachment()); }); + OpEntry.IterateAttachments([&](CbFieldView FieldView) { NeededAttachments.insert(FieldView.AsAttachment()); }); if (remotestore_impl::IsCancelled(OptionalContext)) { return RemoteProjectStore::Result{.ErrorCode = gsl::narrow(HttpResponseCode::OK), @@ -2816,7 +3127,7 @@ ParseOplogContainer(const CbObject& ContainerObject, } } { - std::vector ReferencedAttachments(OpsAttachments.begin(), OpsAttachments.end()); + std::vector ReferencedAttachments(NeededAttachments.begin(), NeededAttachments.end()); OnReferencedAttachments(ReferencedAttachments); } @@ -2827,24 +3138,27 @@ ParseOplogContainer(const CbObject& ContainerObject, .Reason = "Operation cancelled"}; } - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", OpsAttachments.size())); + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Oplog references {} attachments", NeededAttachments.size())); CbArrayView ChunkedFilesArray = ContainerObject["chunkedfiles"sv].AsArrayView(); for (CbFieldView ChunkedFileField : ChunkedFilesArray) { CbObjectView ChunkedFileView = ChunkedFileField.AsObjectView(); IoHash RawHash = ChunkedFileView["rawhash"sv].AsHash(); - if (OpsAttachments.contains(RawHash) && (!HasAttachment(RawHash))) + if (NeededAttachments.erase(RawHash) == 1) { - ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView); - - OnReferencedAttachments(Chunked.ChunkHashes); - OpsAttachments.insert(Chunked.ChunkHashes.begin(), Chunked.ChunkHashes.end()); - 
OnChunkedAttachment(Chunked); - ZEN_INFO("Requesting chunked attachment '{}' ({}) built from {} chunks", - Chunked.RawHash, - NiceBytes(Chunked.RawSize), - Chunked.ChunkHashes.size()); + if (!HasAttachment(RawHash)) + { + ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView); + + OnReferencedAttachments(Chunked.ChunkHashes); + NeededAttachments.insert(Chunked.ChunkHashes.begin(), Chunked.ChunkHashes.end()); + OnChunkedAttachment(Chunked); + ZEN_INFO("Requesting chunked attachment '{}' ({}) built from {} chunks", + Chunked.RawHash, + NiceBytes(Chunked.RawSize), + Chunked.ChunkHashes.size()); + } } if (remotestore_impl::IsCancelled(OptionalContext)) { @@ -2854,6 +3168,8 @@ ParseOplogContainer(const CbObject& ContainerObject, } } + std::vector ThinBlocksDescriptions; + size_t NeedBlockCount = 0; CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView(); for (CbFieldView BlockField : BlocksArray) @@ -2863,45 +3179,38 @@ ParseOplogContainer(const CbObject& ContainerObject, CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView(); - std::vector NeededChunks; - NeededChunks.reserve(ChunksArray.Num()); - if (BlockHash == IoHash::Zero) + std::vector ChunkHashes; + ChunkHashes.reserve(ChunksArray.Num()); + for (CbFieldView ChunkField : ChunksArray) { - for (CbFieldView ChunkField : ChunksArray) - { - IoHash ChunkHash = ChunkField.AsBinaryAttachment(); - if (OpsAttachments.erase(ChunkHash) == 1) - { - if (!HasAttachment(ChunkHash)) - { - NeededChunks.emplace_back(ChunkHash); - } - } - } + ChunkHashes.push_back(ChunkField.AsHash()); } - else + ThinBlocksDescriptions.push_back(ThinChunkBlockDescription{.BlockHash = BlockHash, .ChunkRawHashes = std::move(ChunkHashes)}); + } + + for (ThinChunkBlockDescription& ThinBlockDescription : ThinBlocksDescriptions) + { + std::vector NeededBlockChunkIndexes; + for (uint32_t ChunkIndex = 0; ChunkIndex < ThinBlockDescription.ChunkRawHashes.size(); ChunkIndex++) { - for (CbFieldView ChunkField : ChunksArray) + const IoHash& 
ChunkHash = ThinBlockDescription.ChunkRawHashes[ChunkIndex]; + if (NeededAttachments.erase(ChunkHash) == 1) { - const IoHash ChunkHash = ChunkField.AsHash(); - if (OpsAttachments.erase(ChunkHash) == 1) + if (!HasAttachment(ChunkHash)) { - if (!HasAttachment(ChunkHash)) - { - NeededChunks.emplace_back(ChunkHash); - } + NeededBlockChunkIndexes.push_back(ChunkIndex); } } } - - if (!NeededChunks.empty()) + if (!NeededBlockChunkIndexes.empty()) { - OnNeedBlock(BlockHash, std::move(NeededChunks)); - if (BlockHash != IoHash::Zero) + if (ThinBlockDescription.BlockHash != IoHash::Zero) { NeedBlockCount++; } + OnNeedBlock(std::move(ThinBlockDescription), std::move(NeededBlockChunkIndexes)); } + if (remotestore_impl::IsCancelled(OptionalContext)) { return RemoteProjectStore::Result{.ErrorCode = gsl::narrow(HttpResponseCode::OK), @@ -2909,6 +3218,7 @@ ParseOplogContainer(const CbObject& ContainerObject, .Reason = "Operation cancelled"}; } } + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Requesting {} of {} attachment blocks", NeedBlockCount, BlocksArray.Num())); @@ -2918,7 +3228,7 @@ ParseOplogContainer(const CbObject& ContainerObject, { IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment(); - if (OpsAttachments.erase(AttachmentHash) == 1) + if (NeededAttachments.erase(AttachmentHash) == 1) { if (!HasAttachment(AttachmentHash)) { @@ -2941,14 +3251,15 @@ ParseOplogContainer(const CbObject& ContainerObject, } RemoteProjectStore::Result -SaveOplogContainer(ProjectStore::Oplog& Oplog, - const CbObject& ContainerObject, - const std::function RawHashes)>& OnReferencedAttachments, - const std::function& HasAttachment, - const std::function&& Chunks)>& OnNeedBlock, - const std::function& OnNeedAttachment, - const std::function& OnChunkedAttachment, - JobContext* OptionalContext) +SaveOplogContainer( + ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function RawHashes)>& OnReferencedAttachments, + const std::function& HasAttachment, 
+ const std::function&& NeededChunkIndexes)>& OnNeedBlock, + const std::function& OnNeedAttachment, + const std::function& OnChunkedAttachment, + JobContext* OptionalContext) { using namespace std::literals; @@ -2972,18 +3283,23 @@ SaveOplogContainer(ProjectStore::Oplog& Oplog, } RemoteProjectStore::Result -LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - JobContext* OptionalContext) +LoadOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + WorkerThreadPool& NetworkWorkerPool, + WorkerThreadPool& WorkerPool, + bool ForceDownload, + bool IgnoreMissingAttachments, + bool CleanOplog, + EPartialBlockRequestMode PartialBlockRequestMode, + double HostLatencySec, + double CacheLatencySec, + JobContext* OptionalContext) { using namespace std::literals; + std::unique_ptr LogOutput(std::make_unique(OptionalContext)); + remotestore_impl::DownloadInfo Info; Stopwatch Timer; @@ -3035,6 +3351,14 @@ LoadOplog(CidStore& ChunkStore, return false; }; + struct NeededBlockDownload + { + ThinChunkBlockDescription ThinBlockDescription; + std::vector NeededChunkIndexes; + }; + + std::vector NeededBlockDownloads; + auto OnNeedBlock = [&RemoteStore, &ChunkStore, &NetworkWorkerPool, @@ -3047,8 +3371,9 @@ LoadOplog(CidStore& ChunkStore, &Info, &LoadAttachmentsTimer, &DownloadStartMS, + &NeededBlockDownloads, IgnoreMissingAttachments, - OptionalContext](const IoHash& BlockHash, std::vector&& Chunks) { + OptionalContext](ThinChunkBlockDescription&& ThinBlockDescription, std::vector&& NeededChunkIndexes) { if (RemoteResult.IsError()) { return; @@ -3056,7 +3381,7 @@ LoadOplog(CidStore& ChunkStore, BlockCountToDownload++; AttachmentCount.fetch_add(1); - if (BlockHash == IoHash::Zero) + if (ThinBlockDescription.BlockHash == IoHash::Zero) { 
DownloadAndSaveBlockChunks(ChunkStore, RemoteStore, @@ -3070,25 +3395,13 @@ LoadOplog(CidStore& ChunkStore, Info, LoadAttachmentsTimer, DownloadStartMS, - Chunks); + std::move(ThinBlockDescription), + std::move(NeededChunkIndexes)); } else { - DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, - AttachmentsDownloadLatch, - AttachmentsWriteLatch, - RemoteResult, - Info, - LoadAttachmentsTimer, - DownloadStartMS, - BlockHash, - Chunks, - 3); + NeededBlockDownloads.push_back(NeededBlockDownload{.ThinBlockDescription = std::move(ThinBlockDescription), + .NeededChunkIndexes = std::move(NeededChunkIndexes)}); } }; @@ -3132,12 +3445,7 @@ LoadOplog(CidStore& ChunkStore, }; std::vector FilesToDechunk; - auto OnChunkedAttachment = [&Oplog, &ChunkStore, &FilesToDechunk, ForceDownload](const ChunkedInfo& Chunked) { - if (ForceDownload || !ChunkStore.ContainsChunk(Chunked.RawHash)) - { - FilesToDechunk.push_back(Chunked); - } - }; + auto OnChunkedAttachment = [&FilesToDechunk](const ChunkedInfo& Chunked) { FilesToDechunk.push_back(Chunked); }; auto OnReferencedAttachments = [&Oplog](std::span RawHashes) { Oplog.CaptureAddedAttachments(RawHashes); }; @@ -3165,6 +3473,185 @@ LoadOplog(CidStore& ChunkStore, BlockCountToDownload, FilesToDechunk.size())); + std::vector BlockHashes; + std::vector AllNeededChunkHashes; + BlockHashes.reserve(NeededBlockDownloads.size()); + for (const NeededBlockDownload& BlockDownload : NeededBlockDownloads) + { + BlockHashes.push_back(BlockDownload.ThinBlockDescription.BlockHash); + for (uint32_t ChunkIndex : BlockDownload.NeededChunkIndexes) + { + AllNeededChunkHashes.push_back(BlockDownload.ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } + } + + tsl::robin_map AllNeededPartialChunkHashesLookup = BuildHashLookup(AllNeededChunkHashes); + std::vector> ChunkDownloadedFlags(AllNeededChunkHashes.size()); + std::vector 
DownloadedViaLegacyChunkFlag(AllNeededChunkHashes.size(), false); + ChunkBlockAnalyser::BlockResult PartialBlocksResult; + + RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions = RemoteStore.GetBlockDescriptions(BlockHashes); + std::vector BlocksWithDescription; + BlocksWithDescription.reserve(BlockDescriptions.Blocks.size()); + for (const ChunkBlockDescription& BlockDescription : BlockDescriptions.Blocks) + { + BlocksWithDescription.push_back(BlockDescription.BlockHash); + } + { + auto WantIt = NeededBlockDownloads.begin(); + auto FindIt = BlockDescriptions.Blocks.begin(); + while (WantIt != NeededBlockDownloads.end()) + { + if (FindIt == BlockDescriptions.Blocks.end()) + { + // Fall back to full download as we can't get enough information about the block + DownloadAndSaveBlock(ChunkStore, + RemoteStore, + IgnoreMissingAttachments, + OptionalContext, + NetworkWorkerPool, + WorkerPool, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + WantIt->ThinBlockDescription.BlockHash, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + 3); + for (uint32_t BlockChunkIndex : WantIt->NeededChunkIndexes) + { + const IoHash& ChunkHash = WantIt->ThinBlockDescription.ChunkRawHashes[BlockChunkIndex]; + auto It = AllNeededPartialChunkHashesLookup.find(ChunkHash); + ZEN_ASSERT(It != AllNeededPartialChunkHashesLookup.end()); + uint32_t ChunkIndex = It->second; + DownloadedViaLegacyChunkFlag[ChunkIndex] = true; + } + WantIt++; + } + else if (WantIt->ThinBlockDescription.BlockHash == FindIt->BlockHash) + { + // Found + FindIt++; + WantIt++; + } + else + { + // Not a requested block? 
+ ZEN_ASSERT(false); + } + } + } + if (!AllNeededChunkHashes.empty()) + { + std::vector PartialBlockDownloadModes; + + if (PartialBlockRequestMode == EPartialBlockRequestMode::Off) + { + PartialBlockDownloadModes.resize(BlocksWithDescription.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else + { + RemoteProjectStore::AttachmentExistsInCacheResult CacheExistsResult = + RemoteStore.AttachmentExistsInCache(BlocksWithDescription); + if (CacheExistsResult.ErrorCode != 0 || CacheExistsResult.HasBody.size() != BlocksWithDescription.size()) + { + CacheExistsResult.HasBody.resize(BlocksWithDescription.size(), false); + } + + PartialBlockDownloadModes.reserve(BlocksWithDescription.size()); + + for (bool ExistsInCache : CacheExistsResult.HasBody) + { + if (PartialBlockRequestMode == EPartialBlockRequestMode::All) + { + PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); + } + else if (PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) + { + PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + } + else if (PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) + { + PartialBlockDownloadModes.push_back(ExistsInCache ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); + } + } + } + + ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size()); + + ChunkBlockAnalyser PartialAnalyser(*LogOutput, + BlockDescriptions.Blocks, + ChunkBlockAnalyser::Options{.IsQuiet = false, + .IsVerbose = false, + .HostLatencySec = HostLatencySec, + .HostHighSpeedLatencySec = CacheLatencySec}); + + std::vector NeededBlocks = + PartialAnalyser.GetNeeded(AllNeededPartialChunkHashesLookup, + [&](uint32_t ChunkIndex) { return !DownloadedViaLegacyChunkFlag[ChunkIndex]; }); + + PartialBlocksResult = PartialAnalyser.CalculatePartialBlockDownloads(NeededBlocks, PartialBlockDownloadModes); + for (uint32_t FullBlockIndex : PartialBlocksResult.FullBlockIndexes) + { + DownloadAndSaveBlock(ChunkStore, + RemoteStore, + IgnoreMissingAttachments, + OptionalContext, + NetworkWorkerPool, + WorkerPool, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + BlockDescriptions.Blocks[FullBlockIndex].BlockHash, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + 3); + } + + for (size_t BlockRangeIndex = 0; BlockRangeIndex < PartialBlocksResult.BlockRanges.size();) + { + size_t RangeCount = 1; + size_t RangesLeft = PartialBlocksResult.BlockRanges.size() - BlockRangeIndex; + const ChunkBlockAnalyser::BlockRangeDescriptor& CurrentBlockRange = PartialBlocksResult.BlockRanges[BlockRangeIndex]; + while (RangeCount < RangesLeft && + CurrentBlockRange.BlockIndex == PartialBlocksResult.BlockRanges[BlockRangeIndex + RangeCount].BlockIndex) + { + RangeCount++; + } + + DownloadAndSavePartialBlock(ChunkStore, + RemoteStore, + IgnoreMissingAttachments, + OptionalContext, + NetworkWorkerPool, + WorkerPool, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + 
BlockDescriptions.Blocks[CurrentBlockRange.BlockIndex], + PartialBlocksResult.BlockRanges, + BlockRangeIndex, + RangeCount, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + 3); + + BlockRangeIndex += RangeCount; + } + } + AttachmentsDownloadLatch.CountDown(); while (!AttachmentsDownloadLatch.Wait(1000)) { @@ -3478,21 +3965,30 @@ LoadOplog(CidStore& ChunkStore, } } - remotestore_impl::ReportMessage( - OptionalContext, - fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}), Stored: {} ({}), Missing: {} {}", - RemoteStoreInfo.ContainerName, - Result.ErrorCode == 0 ? "SUCCESS" : "FAILURE", - NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), - NiceBytes(Info.OplogSizeBytes), - Info.AttachmentBlocksDownloaded.load(), - NiceBytes(Info.AttachmentBlockBytesDownloaded.load()), - Info.AttachmentsDownloaded.load(), - NiceBytes(Info.AttachmentBytesDownloaded.load()), - Info.AttachmentsStored.load(), - NiceBytes(Info.AttachmentBytesStored.load()), - Info.MissingAttachmentCount.load(), - remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); + uint64_t TotalDownloads = + 1 + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); + uint64_t TotalBytesDownloaded = Info.OplogSizeBytes + Info.AttachmentBlockBytesDownloaded.load() + + Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load(); + + remotestore_impl::ReportMessage(OptionalContext, + fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), BlockRanges: {} ({}), Attachments: {} " + "({}), Total: {} ({}), Stored: {} ({}), Missing: {} {}", + RemoteStoreInfo.ContainerName, + Result.ErrorCode == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), + NiceBytes(Info.OplogSizeBytes), + Info.AttachmentBlocksDownloaded.load(), + NiceBytes(Info.AttachmentBlockBytesDownloaded.load()), + Info.AttachmentBlocksRangesDownloaded.load(), + NiceBytes(Info.AttachmentBlockRangeBytesDownloaded.load()), + Info.AttachmentsDownloaded.load(), + NiceBytes(Info.AttachmentBytesDownloaded.load()), + TotalDownloads, + NiceBytes(TotalBytesDownloaded), + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Info.MissingAttachmentCount.load(), + remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); return Result; } @@ -3697,6 +4193,9 @@ TEST_CASE_TEMPLATE("project.store.export", /*Force*/ false, /*IgnoreMissingAttachments*/ false, /*CleanOplog*/ false, + EPartialBlockRequestMode::Mixed, + /*HostLatencySec*/ -1.0, + /*CacheLatencySec*/ -1.0, nullptr); CHECK(ImportResult.ErrorCode == 0); @@ -3708,6 +4207,9 @@ TEST_CASE_TEMPLATE("project.store.export", /*Force*/ true, /*IgnoreMissingAttachments*/ false, /*CleanOplog*/ false, + EPartialBlockRequestMode::Mixed, + /*HostLatencySec*/ -1.0, + /*CacheLatencySec*/ -1.0, nullptr); CHECK(ImportForceResult.ErrorCode == 0); @@ -3719,6 +4221,9 @@ TEST_CASE_TEMPLATE("project.store.export", /*Force*/ false, /*IgnoreMissingAttachments*/ false, /*CleanOplog*/ true, + EPartialBlockRequestMode::Mixed, + /*HostLatencySec*/ -1.0, + /*CacheLatencySec*/ -1.0, nullptr); CHECK(ImportCleanResult.ErrorCode == 0); @@ -3730,6 +4235,9 @@ TEST_CASE_TEMPLATE("project.store.export", /*Force*/ true, /*IgnoreMissingAttachments*/ false, /*CleanOplog*/ true, + EPartialBlockRequestMode::Mixed, + /*HostLatencySec*/ -1.0, + /*CacheLatencySec*/ -1.0, nullptr); CHECK(ImportForceCleanResult.ErrorCode == 0); } diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp index ab82edbef..b4c1156ac 100644 --- 
a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp @@ -249,7 +249,18 @@ public: return GetKnownBlocksResult{{.ErrorCode = static_cast(HttpResponseCode::NoContent)}}; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + { + ZEN_UNUSED(BlockHashes); + return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; + } + + virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override + { + return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override { std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); HttpClient::Response Response = @@ -257,12 +268,7 @@ public: AddStats(Response); LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)}; - if (!Result.ErrorCode) - { - Result.Bytes = Response.ResponsePayload; - Result.Bytes.MakeOwned(); - } - if (!Result.ErrorCode) + if (Result.ErrorCode) { Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. 
Reason: '{}'", m_ProjectStoreUrl, @@ -271,6 +277,15 @@ public: RawHash, Result.Reason); } + if (!Result.ErrorCode && Range) + { + Result.Bytes = IoBuffer(Response.ResponsePayload, Range.Offset, Range.Bytes); + } + else + { + Result.Bytes = Response.ResponsePayload; + } + Result.Bytes.MakeOwned(); return Result; } diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index 416e2ed69..2b5474d00 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -244,6 +244,8 @@ namespace { { std::shared_ptr Store; std::string Description; + double HostLatencySec = -1.0; + double CacheLatencySec = -1.0; }; CreateRemoteStoreResult CreateRemoteStore(LoggerRef InLog, @@ -261,6 +263,8 @@ namespace { using namespace std::literals; std::shared_ptr RemoteStore; + double HostLatencySec = -1.0; + double CacheLatencySec = -1.0; if (CbObjectView File = Params["file"sv].AsObjectView(); File) { @@ -495,7 +499,9 @@ namespace { /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true, - GetTinyWorkerPool(EWorkloadType::Background)); + GetTinyWorkerPool(EWorkloadType::Background), + HostLatencySec, + CacheLatencySec); } if (!RemoteStore) @@ -503,7 +509,10 @@ namespace { return {nullptr, "Unknown remote store type"}; } - return {std::move(RemoteStore), ""}; + return CreateRemoteStoreResult{.Store = std::move(RemoteStore), + .Description = "", + .HostLatencySec = HostLatencySec, + .CacheLatencySec = CacheLatencySec}; } std::pair ConvertResult(const RemoteProjectStore::Result& Result) @@ -2356,15 +2365,19 @@ HttpProjectService::HandleOplogSaveRequest(HttpRouterRequest& Req) tsl::robin_set Attachments; auto HasAttachment = [this](const IoHash& RawHash) { return m_CidStore.ContainsChunk(RawHash); }; - auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector&& ChunkHashes) { + auto OnNeedBlock = [&AttachmentsLock, 
&Attachments](ThinChunkBlockDescription&& ThinBlockDescription, + std::vector&& NeededChunkIndexes) { RwLock::ExclusiveLockScope _(AttachmentsLock); - if (BlockHash != IoHash::Zero) + if (ThinBlockDescription.BlockHash != IoHash::Zero) { - Attachments.insert(BlockHash); + Attachments.insert(ThinBlockDescription.BlockHash); } else { - Attachments.insert(ChunkHashes.begin(), ChunkHashes.end()); + for (uint32_t ChunkIndex : NeededChunkIndexes) + { + Attachments.insert(ThinBlockDescription.ChunkRawHashes[ChunkIndex]); + } } }; auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) { @@ -2663,6 +2676,8 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) bool CleanOplog = Params["clean"].AsBool(false); bool BoostWorkerCount = Params["boostworkercount"].AsBool(false); bool BoostWorkerMemory = Params["boostworkermemory"sv].AsBool(false); + EPartialBlockRequestMode PartialBlockRequestMode = + PartialBlockRequestModeFromString(Params["partialblockrequestmode"sv].AsString("true")); CreateRemoteStoreResult RemoteStoreResult = CreateRemoteStore(Log(), Params, @@ -2688,6 +2703,9 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) Force, IgnoreMissingAttachments, CleanOplog, + PartialBlockRequestMode, + HostLatencySec = RemoteStoreResult.HostLatencySec, + CacheLatencySec = RemoteStoreResult.CacheLatencySec, BoostWorkerCount](JobContext& Context) { Context.ReportMessage(fmt::format("Loading oplog '{}/{}' from {}", Oplog->GetOuterProjectIdentifier(), @@ -2709,6 +2727,9 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) Force, IgnoreMissingAttachments, CleanOplog, + PartialBlockRequestMode, + HostLatencySec, + CacheLatencySec, &Context); auto Response = ConvertResult(Result); ZEN_INFO("LoadOplog: Status: {} '{}'", ToString(Response.first), Response.second); -- cgit v1.2.3 From fb19c8a86e89762ea89df3b361494a055680b432 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Tue, 24 Feb 2026 22:24:11 +0100 Subject: Fix zencore 
bugs and propagate content type through IoBufferBuilder (#783) - Add missing includes in hashutils.h (``, ``) - Add `ZenContentType` parameter to all `IoBufferBuilder` factory methods so content type is set at buffer creation time - Fix null dereference in `SharedBuffer::GetFileReference()` when buffer is null - Fix out-of-bounds read in trace command-line argument parsing when arg length exactly matches option length - Add unit tests for 32-bit `CountLeadingZeros` --- src/zencore/include/zencore/hashutils.h | 3 +++ src/zencore/include/zencore/iobuffer.h | 37 ++++++++++++++++++++++-------- src/zencore/include/zencore/sharedbuffer.h | 13 ++++++----- src/zencore/intmath.cpp | 6 +++++ src/zencore/iobuffer.cpp | 20 +++++++++------- src/zencore/trace.cpp | 13 ++++++++--- 6 files changed, 65 insertions(+), 27 deletions(-) (limited to 'src') diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h index 4e877e219..6b9902b3a 100644 --- a/src/zencore/include/zencore/hashutils.h +++ b/src/zencore/include/zencore/hashutils.h @@ -2,6 +2,9 @@ #pragma once +#include +#include + namespace zen { template diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h index 182768ff6..82c201edd 100644 --- a/src/zencore/include/zencore/iobuffer.h +++ b/src/zencore/include/zencore/iobuffer.h @@ -426,22 +426,39 @@ private: class IoBufferBuilder { public: - static IoBuffer MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset = 0, uint64_t Size = ~0ull); - static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName); - static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull); - /** Make sure buffer data is memory resident, but avoid memory mapping data from files - */ - static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer); - inline static IoBuffer MakeFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Wrap, Memory.GetData(), 
Memory.GetSize()); } - inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz) + static IoBuffer MakeFromFile(const std::filesystem::path& FileName, + uint64_t Offset = 0, + uint64_t Size = ~0ull, + ZenContentType ContentType = ZenContentType::kBinary); + static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType = ZenContentType::kBinary); + static IoBuffer MakeFromFileHandle(void* FileHandle, + uint64_t Offset = 0, + uint64_t Size = ~0ull, + ZenContentType ContentType = ZenContentType::kBinary); + inline static IoBuffer MakeFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary) + { + IoBuffer NewBuffer(IoBuffer::Wrap, Memory.GetData(), Memory.GetSize()); + NewBuffer.SetContentType(ContentType); + return NewBuffer; + } + inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz, ZenContentType ContentType = ZenContentType::kBinary) { if (Sz) { - return IoBuffer(IoBuffer::Clone, Ptr, Sz); + IoBuffer NewBuffer(IoBuffer::Clone, Ptr, Sz); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } return {}; } - inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize()); } + inline static IoBuffer MakeCloneFromMemory(MemoryView Memory, ZenContentType ContentType = ZenContentType::kBinary) + { + return MakeCloneFromMemory(Memory.GetData(), Memory.GetSize(), ContentType); + } + + /** Make sure buffer data is memory resident, but avoid memory mapping data from files + */ + static IoBuffer ReadFromFileMaybe(const IoBuffer& InBuffer); }; void iobuffer_forcelink(); diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h index c57e9f568..3d4c19282 100644 --- a/src/zencore/include/zencore/sharedbuffer.h +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -116,14 +116,15 @@ public: inline void Reset() { m_Buffer = nullptr; } inline bool 
GetFileReference(IoBufferFileReference& OutRef) const { - if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore()) + if (!IsNull()) { - return Core->GetFileReference(OutRef); - } - else - { - return false; + if (const IoBufferExtendedCore* Core = m_Buffer->ExtendedCore()) + { + return Core->GetFileReference(OutRef); + } } + + return false; } [[nodiscard]] MemoryView GetView() const diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp index 5a686dc8e..32f82b486 100644 --- a/src/zencore/intmath.cpp +++ b/src/zencore/intmath.cpp @@ -43,6 +43,12 @@ TEST_CASE("intmath") CHECK(FloorLog2_64(0x0000'0001'0000'0000ull) == 32); CHECK(FloorLog2_64(0x8000'0000'0000'0000ull) == 63); + CHECK(CountLeadingZeros(0x8000'0000u) == 0); + CHECK(CountLeadingZeros(0x0000'0000u) == 32); + CHECK(CountLeadingZeros(0x0000'0001u) == 31); + CHECK(CountLeadingZeros(0x0000'8000u) == 16); + CHECK(CountLeadingZeros(0x0001'0000u) == 15); + CHECK(CountLeadingZeros64(0x8000'0000'0000'0000ull) == 0); CHECK(CountLeadingZeros64(0x0000'0000'0000'0000ull) == 64); CHECK(CountLeadingZeros64(0x0000'0000'0000'0001ull) == 63); diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index be9b39e7a..1c31d6620 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -592,15 +592,17 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer) } IoBuffer -IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size) +IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size, ZenContentType ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFileHandle"); - return IoBuffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size); + IoBuffer Buffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size); + Buffer.SetContentType(ContentType); + return Buffer; } IoBuffer -IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size) +IoBufferBuilder::MakeFromFile(const std::filesystem::path& 
FileName, uint64_t Offset, uint64_t Size, ZenContentType ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromFile"); @@ -632,8 +634,6 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of FileSize = Stat.st_size; #endif // ZEN_PLATFORM_WINDOWS - // TODO: should validate that offset is in range - if (Size == ~0ull) { Size = FileSize - Offset; @@ -652,7 +652,9 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of #if ZEN_PLATFORM_WINDOWS void* Fd = DataFile.Detach(); #endif - return IoBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize); + IoBuffer NewBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } #if !ZEN_PLATFORM_WINDOWS @@ -664,7 +666,7 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of } IoBuffer -IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) +IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName, ZenContentType ContentType) { ZEN_TRACE_CPU("IoBufferBuilder::MakeFromTemporaryFile"); @@ -703,7 +705,9 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) Handle = (void*)uintptr_t(Fd); #endif // ZEN_PLATFORM_WINDOWS - return IoBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); + IoBuffer NewBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); + NewBuffer.SetContentType(ContentType); + return NewBuffer; } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp index 87035554f..a026974c0 100644 --- a/src/zencore/trace.cpp +++ b/src/zencore/trace.cpp @@ -165,10 +165,17 @@ GetTraceOptionsFromCommandline(TraceOptions& OutOptions) auto MatchesArg = [](std::string_view Option, std::string_view Arg) -> std::optional { if (Arg.starts_with(Option)) { 
- std::string_view::value_type DelimChar = Arg[Option.length()]; - if (DelimChar == ' ' || DelimChar == '=') + if (Arg.length() > Option.length()) { - return Arg.substr(Option.size() + 1); + std::string_view::value_type DelimChar = Arg[Option.length()]; + if (DelimChar == ' ' || DelimChar == '=') + { + return Arg.substr(Option.size() + 1); + } + } + else + { + return ""sv; } } return {}; -- cgit v1.2.3 From 241e4faf64be83711dc509ad8a25ff4e8ae95c12 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 25 Feb 2026 10:15:41 +0100 Subject: HttpService/Frontend improvements (#782) - zenhttp: added `GetServiceUri()`/`GetExternalHost()` - enables code to quickly generate an externally reachable URI for a given service - frontend: improved Uri handling (better defaults) - added support for 404 page (to make it easier to find a good URL) --- src/zenhttp/httpserver.cpp | 24 +++++++++++- src/zenhttp/include/zenhttp/httpserver.h | 16 ++++++++ src/zenhttp/servers/httpasio.cpp | 48 ++++++++++++++++++++---- src/zenhttp/servers/httpmulti.cpp | 10 +++++ src/zenhttp/servers/httpmulti.h | 13 ++++--- src/zenhttp/servers/httpsys.cpp | 49 ++++++++++++++++++++++--- src/zenserver/frontend/frontend.cpp | 57 +++++++++++++++++++++-------- src/zenserver/frontend/html.zip | Bin 183939 -> 279965 bytes src/zenserver/storage/zenstorageserver.cpp | 9 +++++ 9 files changed, 191 insertions(+), 35 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index f2fe4738f..3cefa0ad8 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -1014,7 +1015,28 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) int HttpServer::Initialize(int BasePort, std::filesystem::path DataDir) { - return OnInitialize(BasePort, std::move(DataDir)); + m_EffectivePort = OnInitialize(BasePort, std::move(DataDir)); + m_ExternalHost = OnGetExternalHost(); + return 
m_EffectivePort; +} + +std::string +HttpServer::OnGetExternalHost() const +{ + return GetMachineName(); +} + +std::string +HttpServer::GetServiceUri(const HttpService* Service) const +{ + if (Service) + { + return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri()); + } + else + { + return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort); + } } void diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 350532126..00cbc6c14 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -219,8 +219,21 @@ public: void RequestExit(); void Close(); + /** Returns a canonical http:// URI for the given service, using the external + * IP and the port the server is actually listening on. Only valid + * after Initialize() has returned successfully. + */ + std::string GetServiceUri(const HttpService* Service) const; + + /** Returns the external host string (IP or hostname) determined during Initialize(). + * Only valid after Initialize() has returned successfully. 
+ */ + std::string_view GetExternalHost() const { return m_ExternalHost; } + private: std::vector m_KnownServices; + int m_EffectivePort = 0; + std::string m_ExternalHost; virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; @@ -228,6 +241,9 @@ private: virtual void OnRun(bool IsInteractiveSession) = 0; virtual void OnRequestExit() = 0; virtual void OnClose() = 0; + +protected: + virtual std::string OnGetExternalHost() const; }; struct HttpServerPluginConfig diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index fbc7fe401..0c0238886 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -506,6 +507,8 @@ public: HttpService* RouteRequest(std::string_view Url); IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + bool IsLoopbackOnly() const; + asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr m_Acceptor; @@ -1601,7 +1604,8 @@ struct HttpAcceptor void StopAccepting() { m_IsStopped = true; } - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } bool IsValid() const { return m_IsValid; } @@ -1975,6 +1979,12 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) return RequestFilter->FilterRequest(Request); } +bool +HttpAsioServerImpl::IsLoopbackOnly() const +{ + return m_Acceptor && m_Acceptor->IsLoopbackOnly(); +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1987,12 +1997,13 @@ public: HttpAsioServer(const AsioConfig& Config); ~HttpAsioServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int 
OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; private: Event m_ShutdownEvent; @@ -2067,6 +2078,29 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) return m_BasePort; } +std::string +HttpAsioServer::OnGetExternalHost() const +{ + if (m_Impl->IsLoopbackOnly()) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. 
+ try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + void HttpAsioServer::OnRun(bool IsInteractive) { diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 310ac9dc0..584e06cbf 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -117,6 +117,16 @@ HttpMultiServer::OnClose() } } +std::string +HttpMultiServer::OnGetExternalHost() const +{ + if (!m_Servers.empty()) + { + return std::string(m_Servers.front()->GetExternalHost()); + } + return HttpServer::OnGetExternalHost(); +} + void HttpMultiServer::AddServer(Ref Server) { diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index 1897587a9..97699828a 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -15,12 +15,13 @@ public: HttpMultiServer(); ~HttpMultiServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; void AddServer(Ref Server); diff --git a/src/zenhttp/servers/httpsys.cpp 
b/src/zenhttp/servers/httpsys.cpp index 6995ffca9..e93ae4853 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ # include # include +# include // for resolving addresses for GetExternalHost namespace zen { @@ -93,12 +95,13 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; WorkerThreadPool& WorkPool(); @@ -2290,6 +2293,40 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. 
+ try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + void HttpSysServer::OnRegisterService(HttpService& Service) { diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp index 1cf451e91..579a65c5a 100644 --- a/src/zenserver/frontend/frontend.cpp +++ b/src/zenserver/frontend/frontend.cpp @@ -114,6 +114,8 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) { using namespace std::literals; + ExtendableStringBuilder<256> UriBuilder; + std::string_view Uri = Request.RelativeUriWithExtension(); for (; Uri.length() > 0 && Uri[0] == '/'; Uri = Uri.substr(1)) ; @@ -121,6 +123,11 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) { Uri = "index.html"sv; } + else if (Uri.back() == '/') + { + UriBuilder << Uri << "index.html"sv; + Uri = UriBuilder; + } // Dismiss if the URI contains .. 
anywhere to prevent arbitrary file reads if (Uri.find("..") != Uri.npos) @@ -145,27 +152,47 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) return Request.WriteResponse(HttpResponseCode::Forbidden); } - // The given content directory overrides any zip-fs discovered in the binary - if (!m_Directory.empty()) - { - auto FullPath = m_Directory / std::filesystem::path(Uri).make_preferred(); - FileContents File = ReadFile(FullPath); - - if (!File.ErrorCode) + auto WriteResponseForUri = [this, + &Request](std::string_view InUri, HttpResponseCode ResponseCode, HttpContentType ContentType) -> bool { + // The given content directory overrides any zip-fs discovered in the binary + if (!m_Directory.empty()) { - return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]); + auto FullPath = m_Directory / std::filesystem::path(InUri).make_preferred(); + FileContents File = ReadFile(FullPath); + + if (!File.ErrorCode) + { + Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); + + return true; + } } - } - if (m_ZipFs) - { - if (IoBuffer FileBuffer = m_ZipFs->GetFile(Uri)) + if (m_ZipFs) { - return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + if (IoBuffer FileBuffer = m_ZipFs->GetFile(InUri)) + { + Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + + return true; + } } - } - Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); + return false; + }; + + if (WriteResponseForUri(Uri, HttpResponseCode::OK, ContentType)) + { + return; + } + else if (WriteResponseForUri("404.html"sv, HttpResponseCode::NotFound, HttpContentType::kHTML)) + { + return; + } + else + { + Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); + } } } // namespace zen diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index d70a5a62b..3d90c18a8 100644 Binary files a/src/zenserver/frontend/html.zip and 
b/src/zenserver/frontend/html.zip differ diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index ff854b72d..3d81db656 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -700,6 +700,15 @@ ZenStorageServer::Run() ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", GetCurrentProcessId()); + if (m_FrontendService) + { + ZEN_INFO("frontend link: {}", m_Http->GetServiceUri(m_FrontendService.get())); + } + else + { + ZEN_INFO("frontend service disabled"); + } + #if ZEN_PLATFORM_WINDOWS if (zen::windows::IsRunningOnWine()) { -- cgit v1.2.3 From d7354c2ad34858d8ee99fb307685956c24abd897 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 25 Feb 2026 18:49:31 +0100 Subject: work around doctest shutdown issues with static CRT (#784) * tweaked doctest.h to avoid shutdown issues due to thread_local variables running destructors after the main thread has torn down everything including the heap * disabled zenserver exit thread waiting since doctest should hopefully not be causing issues during shutdown anymore after my workaround This should help reduce the duration of tests spawning lots of server instances --- src/zenserver/main.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'src') diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index ee783d2a6..571dd3b4f 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -267,6 +267,14 @@ main(int argc, char* argv[]) using namespace zen; using namespace std::literals; + // note: doctest has locally (in thirdparty) been fixed to not cause shutdown + // crashes due to TLS destructors + // + // mimalloc on the other hand might still be causing issues, in which case + // we should work out either how to eliminate the mimalloc dependency or how + // to configure it in a way that doesn't cause shutdown issues + +#if 0 auto _ = zen::MakeGuard([] { // Allow some time for worker threads to unravel, in 
an effort // to prevent shutdown races in TLS object destruction, mainly due to @@ -277,6 +285,7 @@ main(int argc, char* argv[]) // shutdown crashes observed in some situations. WaitForThreads(1000); }); +#endif enum { -- cgit v1.2.3 From c1838da092c31c4ebe1e9c3f3909a1bef37d34a2 Mon Sep 17 00:00:00 2001 From: zousar Date: Thu, 26 Feb 2026 10:58:50 -0700 Subject: updatefrontend --- src/zenserver/frontend/html.zip | Bin 183939 -> 238188 bytes 1 file changed, 0 insertions(+), 0 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index d70a5a62b..4767029c0 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ -- cgit v1.2.3 From 7c7e25d55ebb593aaa6a42903e2db4629f3b7051 Mon Sep 17 00:00:00 2001 From: zousar Date: Thu, 26 Feb 2026 11:09:36 -0700 Subject: updatefrontend --- src/zenserver/frontend/html.zip | Bin 279965 -> 238188 bytes 1 file changed, 0 insertions(+), 0 deletions(-) (limited to 'src') diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index 3d90c18a8..4767029c0 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ -- cgit v1.2.3 From 91885b9fc6b1954d78d14bdf39e2ba91a5aa9f67 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Thu, 26 Feb 2026 20:14:07 +0100 Subject: adding HttpClient tests (#785) Add comprehensive `HttpClient` test suite. 
Covers: - **HTTP verbs** -- GET, POST, PUT, DELETE, HEAD dispatch correctly - **GET/POST/PUT/Upload/Download** -- payload round-trips (IoBuffer, CbObject, CompositeBuffer), content types, large payloads, file-spill downloads - **Status codes** -- 2xx/4xx/5xx classification, exact code matching - **Response API** -- IsSuccess, AsText, AsObject, ToText, ErrorMessage, ThrowError - **Error handling** -- connection refused, request timeout, nonexistent endpoints - **Session management** -- default ID, SetSessionId, reset to zero - **Authentication** -- token provider, expired tokens, bearer verification - **Content type detection** -- text, JSON, binary, CbObject - **Request metadata** -- elapsed time, upload/download byte counts - **Retry logic** -- retry after transient 503s, no-retry baseline - **Latency measurement** -- MeasureLatency against live and unreachable servers - **KeyValueMap** -- construction from pairs, string_views, initializer lists - **Transport-level faults (GET)** -- connection reset/close before response, partial headers, truncated body, mid-body reset, stalled response timeout, retry after RST - **Transport-level faults (POST)** -- server reset/close before consuming body, mid-body reset, early 503 without consuming upload, stalled upload timeout, retry with large body after transient failures Also adds zenhttp-test to the xmake test runner (xmake test --run=http). --- src/zenhttp/httpclient_test.cpp | 1362 ++++++++++++++++++++++++++++++ src/zenhttp/include/zenhttp/httpclient.h | 3 +- src/zenhttp/zenhttp.cpp | 1 + 3 files changed, 1365 insertions(+), 1 deletion(-) create mode 100644 src/zenhttp/httpclient_test.cpp (limited to 'src') diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp new file mode 100644 index 000000000..509b56371 --- /dev/null +++ b/src/zenhttp/httpclient_test.cpp @@ -0,0 +1,1362 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include +#include + +#if ZEN_WITH_TESTS + +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include "servers/httpasio.h" + +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Test service + +class HttpClientTestService : public HttpService +{ +public: + HttpClientTestService() + { + m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool { + for (char C : Str) + { + if (C < '0' || C > '9') + { + return false; + } + } + return !Str.empty(); + }); + + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/headers", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + CbObjectWriter Writer; + if (!Auth.empty()) + { + Writer.AddString("Authorization", Auth); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "json", + 
[](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddString("message", "test"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "created", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created"); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "content-type/text", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}"); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/binary", + [](HttpRouterRequest& Req) { + uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF}; + IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data)); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/cbobject", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddString("type", "cbobject"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "auth/bearer", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + if (Auth.starts_with("Bearer ") && Auth.size() > 7) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::Unauthorized, 
HttpContentType::kText, "unauthorized"); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Sleep(2000); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + constexpr size_t Size = 64 * 1024; + IoBuffer Buf(Size); + uint8_t* Ptr = static_cast(Buf.MutableData()); + for (size_t i = 0; i < Size; ++i) + { + Ptr[i] = static_cast(i & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status/{statuscode}", + [](HttpRouterRequest& Req) { + std::string_view CodeStr = Req.GetCapture(1); + int Code = std::stoi(std::string{CodeStr}); + const HttpResponseCode ResponseCode = static_cast(Code); + Req.ServerRequest().WriteResponse(ResponseCode); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "attempt-counter", + [this](HttpRouterRequest& Req) { + uint32_t Count = m_AttemptCounter.fetch_add(1); + if (Count < m_FailCount) + { + Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries"); + } + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + + void ResetAttemptCounter(uint32_t FailCount) + { + m_AttemptCounter.store(0); + m_FailCount = FailCount; + } + +private: + HttpRequestRouter m_Router; + std::atomic m_AttemptCounter{0}; + uint32_t m_FailCount = 2; +}; + +////////////////////////////////////////////////////////////////////////// +// Test server fixture + 
+struct TestServerFixture +{ + HttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref Server; + std::thread ServerThread; + int Port = -1; + + TestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(7600, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~TestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_CASE("httpclient.verbs") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + HttpClient::Response Resp = Client.Get("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + HttpClient::Response Resp = Client.Put("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + HttpClient::Response Resp = Client.Delete("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + HttpClient::Response Resp = Client.Head("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("httpclient.get") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + 
HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET with auth header via echo") + { + HttpClient::Response Resp = + Client.Get("/api/test/echo/headers", std::pair("Authorization", "Bearer test-token-123")); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123"); + } + + SUBCASE("GET returning CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("GET large payload") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + + const uint8_t* Data = static_cast(Resp.ResponsePayload.GetData()); + bool Valid = true; + for (size_t i = 0; i < 64 * 1024; ++i) + { + if (Data[i] != static_cast(i & 0xFF)) + { + Valid = false; + break; + } + } + CHECK(Valid); + } +} + +TEST_CASE("httpclient.post") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("POST with IoBuffer payload echo round-trip") + { + const char* Payload = "test payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload data"); + } + + SUBCASE("POST with IoBuffer and explicit content type") + { + const char* Payload = "{\"key\":\"value\"}"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); + } + + SUBCASE("POST with 
CbObject payload round-trip") + { + CbObjectWriter Writer; + Writer.AddBool("enabled", true); + Writer.AddString("name", "testobj"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj); + CHECK(Resp.IsSuccess()); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["enabled"].AsBool() == true); + CHECK_EQ(RoundTripped["name"].AsString(), "testobj"); + } + + SUBCASE("POST with CompositeBuffer payload") + { + const char* Part1 = "hello "; + const char* Part2 = "composite"; + IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1)); + IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2)); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello composite"); + } + + SUBCASE("POST with custom headers") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}); + CHECK(Resp.IsSuccess()); + } + + SUBCASE("POST with empty body to nocontent endpoint") + { + HttpClient::Response Resp = Client.Post("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("httpclient.put") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("PUT with IoBuffer payload echo round-trip") + { + const char* Payload = "put payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload data"); + } + + SUBCASE("PUT with parameters only") + { + HttpClient::Response Resp = Client.Put("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, 
HttpResponseCode::NoContent); + } + + SUBCASE("PUT to created endpoint") + { + const char* Payload = "new resource"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/created", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created); + CHECK_EQ(Resp.AsText(), "resource created"); + } +} + +TEST_CASE("httpclient.upload") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("Upload IoBuffer") + { + constexpr size_t Size = 128 * 1024; + IoBuffer Blob = CreateSemiRandomBlob(Size); + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), Size); + } + + SUBCASE("Upload CompositeBuffer") + { + IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024); + IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.download") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("Download small payload stays in memory") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + 
CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.status-codes") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("2xx are success") + { + CHECK(Client.Get("/api/test/status/200").IsSuccess()); + CHECK(Client.Get("/api/test/status/201").IsSuccess()); + CHECK(Client.Get("/api/test/status/204").IsSuccess()); + } + + SUBCASE("4xx are not success") + { + CHECK(!Client.Get("/api/test/status/400").IsSuccess()); + CHECK(!Client.Get("/api/test/status/401").IsSuccess()); + CHECK(!Client.Get("/api/test/status/403").IsSuccess()); + CHECK(!Client.Get("/api/test/status/404").IsSuccess()); + CHECK(!Client.Get("/api/test/status/409").IsSuccess()); + } + + SUBCASE("5xx are not success") + { + CHECK(!Client.Get("/api/test/status/500").IsSuccess()); + CHECK(!Client.Get("/api/test/status/502").IsSuccess()); + CHECK(!Client.Get("/api/test/status/503").IsSuccess()); + } + + SUBCASE("status code values match") + { + CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK); + CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created); + CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent); + CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest); + CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized); + CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden); + CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound); + CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict); + CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError); + CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway); + CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + 
+TEST_CASE("httpclient.response") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("IsSuccess and operator bool for success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(static_cast(Resp)); + } + + SUBCASE("IsSuccess and operator bool for failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + CHECK(!Resp.IsSuccess()); + CHECK(!static_cast(Resp)); + } + + SUBCASE("AsText returns body") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("AsText returns empty for no-content") + { + HttpClient::Response Resp = Client.Get("/api/test/nocontent"); + CHECK(Resp.AsText().empty()); + } + + SUBCASE("AsObject parses CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("AsObject returns empty for non-CB content") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CbObject Obj = Resp.AsObject(); + CHECK(!Obj); + } + + SUBCASE("ToText for text content") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK_EQ(Resp.ToText(), "plain text"); + } + + SUBCASE("ToText for CbObject content") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + std::string Text = Resp.ToText(); + CHECK(!Text.empty()); + // ToText for CbObject converts to JSON string representation + CHECK(Text.find("ok") != std::string::npos); + CHECK(Text.find("test") != std::string::npos); + } + + SUBCASE("ErrorMessage includes status code on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + std::string Msg = Resp.ErrorMessage("test-prefix"); + CHECK(Msg.find("test-prefix") != std::string::npos); + CHECK(Msg.find("404") != std::string::npos); + } + + SUBCASE("ThrowError throws on 
failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/500"); + CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError); + } + + SUBCASE("ThrowError does not throw on success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_NOTHROW(Resp.ThrowError("test")); + } + + SUBCASE("HttpClientError carries response code") + { + HttpClient::Response Resp = Client.Get("/api/test/status/403"); + try + { + Resp.ThrowError("test"); + CHECK(false); // should not reach + } + catch (const HttpClientError& Err) + { + CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden); + } + } +} + +TEST_CASE("httpclient.error-handling") +{ + SUBCASE("Connection refused") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("Request timeout") + { + TestServerFixture Fixture; + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/slow"); + CHECK(!Resp.IsSuccess()); + } + + SUBCASE("Nonexistent endpoint returns failure") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/does-not-exist"); + CHECK(!Resp.IsSuccess()); + } +} + +TEST_CASE("httpclient.session") +{ + TestServerFixture Fixture; + + SUBCASE("Default session ID is non-empty") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.GetSessionId().empty()); + } + + SUBCASE("SetSessionId changes ID") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + std::string OldId = std::string(Client.GetSessionId()); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + CHECK_NE(Client.GetSessionId(), OldId); + } + + SUBCASE("SetSessionId with 
Zero resets") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + Client.SetSessionId(Oid::Zero); + // After resetting, should get a session string (not empty, not the custom one) + CHECK(!Client.GetSessionId().empty()); + CHECK_NE(Client.GetSessionId(), NewId.ToString()); + } +} + +TEST_CASE("httpclient.authentication") +{ + TestServerFixture Fixture; + + SUBCASE("Authenticate returns false without provider") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Authenticate returns true with valid token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "valid-token", + .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(Client.Authenticate()); + } + + SUBCASE("Authenticate returns false with expired token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "expired-token", + .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Bearer token verified by auth endpoint") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response AuthResp = + Client.Get("/api/test/auth/bearer", std::pair("Authorization", "Bearer my-secret-token")); + CHECK(AuthResp.IsSuccess()); + CHECK_EQ(AuthResp.AsText(), "authenticated"); + } + + SUBCASE("Request without token to auth endpoint gets 401") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/auth/bearer"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized); + } +} + 
+TEST_CASE("httpclient.content-types") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("text content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + + SUBCASE("JSON content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/json"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON); + } + + SUBCASE("binary content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/binary"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary); + } + + SUBCASE("CbObject content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject); + } +} + +TEST_CASE("httpclient.metadata") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("ElapsedSeconds is positive") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.ElapsedSeconds > 0.0); + } + + SUBCASE("DownloadedBytes populated for GET") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.DownloadedBytes > 0); + } + + SUBCASE("UploadedBytes populated for POST with payload") + { + const char* Payload = "some upload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK(Resp.UploadedBytes > 0); + } +} + +TEST_CASE("httpclient.retry") +{ + TestServerFixture Fixture; + + SUBCASE("Retry succeeds after transient failures") + { + Fixture.TestService.ResetAttemptCounter(2); + + 
HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "success after retries"); + } + + SUBCASE("No retry returns 503 immediately") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 0; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.measurelatency") +{ + SUBCASE("Successful measurement against live server") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(Result.Success); + CHECK(Result.LatencySeconds > 0.0); + } + + SUBCASE("Failed measurement against unreachable port") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(!Result.Success); + CHECK(!Result.FailureReason.empty()); + } +} + +TEST_CASE("httpclient.keyvaluemap") +{ + SUBCASE("Default construction is empty") + { + HttpClient::KeyValueMap Map; + CHECK(Map->empty()); + } + + SUBCASE("Construction from pair") + { + HttpClient::KeyValueMap Map(std::pair("key", "value")); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from string_view pair") + { + HttpClient::KeyValueMap Map(std::pair("key"sv, "value"sv)); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from initializer list") + { + HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}}); + CHECK_EQ(Map->size(), 2u); + CHECK_EQ(Map->at("a"), "1"); + CHECK_EQ(Map->at("b"), "2"); + } +} + 
+////////////////////////////////////////////////////////////////////////// +// Transport fault testing + +static std::string +MakeRawHttpResponse(int StatusCode, std::string_view Body) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + StatusCode, + Body.size(), + Body); +} + +static std::string +MakeRawHttpHeaders(int StatusCode, size_t ContentLength) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Length: {}\r\n" + "\r\n", + StatusCode, + ContentLength); +} + +static void +DrainHttpRequest(asio::ip::tcp::socket& Socket) +{ + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); +} + +static void +DrainFullHttpRequest(asio::ip::tcp::socket& Socket) +{ + // Read until end of headers + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Extract headers to find Content-Length + std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + + size_t ContentLength = 0; + auto Pos = Headers.find("Content-Length: "); + if (Pos == std::string::npos) + { + Pos = Headers.find("content-length: "); + } + if (Pos != std::string::npos) + { + size_t ValStart = Pos + 16; // length of "Content-Length: " + size_t ValEnd = Headers.find("\r\n", ValStart); + if (ValEnd != std::string::npos) + { + ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart)); + } + } + + // Calculate how many body bytes were already read past the header boundary. + // asio::read_until may read past the delimiter, so Buf.data() contains everything read. + size_t HeaderEnd = Headers.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0; + size_t Remaining = ContentLength > BodyBytesInBuf ? 
ContentLength - BodyBytesInBuf : 0; + + if (Remaining > 0) + { + std::vector BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +static void +DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead) +{ + // Read headers first + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Determine how many body bytes were already buffered past headers + std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + size_t HeaderEnd = All.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0; + + if (BodyBytesInBuf < BytesToRead) + { + size_t Remaining = BytesToRead - BodyBytesInBuf; + std::vector BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +struct FaultTcpServer +{ + using FaultHandler = std::function; + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + FaultHandler m_Handler; + std::thread m_Thread; + int m_Port; + + explicit FaultTcpServer(FaultHandler Handler) + : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0)) + , m_Handler(std::move(Handler)) + { + m_Port = m_Acceptor.local_endpoint().port(); + StartAccept(); + m_Thread = std::thread([this]() { m_IoContext.run(); }); + } + + ~FaultTcpServer() + { + std::error_code Ec; + m_Acceptor.close(Ec); + m_IoContext.stop(); + if (m_Thread.joinable()) + { + m_Thread.join(); + } + } + + FaultTcpServer(const FaultTcpServer&) = delete; + FaultTcpServer& operator=(const FaultTcpServer&) = delete; + + void StartAccept() + { + m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) { + if (!Ec) + { + m_Handler(Socket); + } + if (m_Acceptor.is_open()) + { + StartAccept(); + } + }); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {}); + 
} +}; + +TEST_CASE("httpclient.transport-faults") +{ + SUBCASE("connection reset before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection closed before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("partial headers then close") + { + // libcurl parses the status line (200 OK) and accepts the response even though + // headers are truncated mid-field. It reports success with an empty body instead + // of an error. Ideally this should be detected as a transport failure. 
+ FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Partial = "HTTP/1.1 200 OK\r\nContent-"; + std::error_code Ec; + asio::write(Socket, asio::buffer(Partial), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + WARN(!Resp.IsSuccess()); + WARN(Resp.Error.has_value()); + } + + SUBCASE("truncated body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(100, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection reset mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 10000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(1000, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("stalled response triggers timeout") + { + std::atomic StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = 
std::chrono::milliseconds(500); + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("retry succeeds after transient failures") + { + std::atomic ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + DrainHttpRequest(Socket); + if (N < 2) + { + // Connection reset produces NETWORK_SEND_FAILURE which is retryable + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + std::string Response = MakeRawHttpResponse(200, "recovered"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "recovered"); + } +} + +TEST_CASE("httpclient.transport-faults-post") +{ + constexpr size_t kPostBodySize = 256 * 1024; + + auto MakePostBody = []() -> IoBuffer { + IoBuffer Buf(kPostBodySize); + uint8_t* Ptr = static_cast(Buf.MutableData()); + for (size_t i = 0; i < kPostBodySize; ++i) + { + Ptr[i] = static_cast(i & 0xFF); + } + Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + + SUBCASE("POST: server resets before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server closes before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + 
DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server resets mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainPartialBody(Socket, 8 * 1024); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: early error response before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(503, "service busy"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + // With a large upload body, the server may RST the connection before the client + // reads the 503 response. Either outcome is valid: the client sees the HTTP 503 + // status, or it sees a transport-level error from the RST. 
+ CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value())); + } + + SUBCASE("POST: stalled upload triggers timeout") + { + std::atomic StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + // Stop reading body — TCP window will fill and client send will stall + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(2000); + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("POST: retry with large body after transient failure") + { + std::atomic ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + if (N < 2) + { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + DrainFullHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "upload-ok"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "upload-ok"); + } +} + +void +httpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 7a129a98c..336a3deee 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -269,6 +269,7 @@ struct LatencyTestResult LatencyTestResult 
MeasureLatency(HttpClient& Client, std::string_view Url); -void httpclient_forcelink(); // internal +void httpclient_forcelink(); // internal +void httpclient_test_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 0b5408453..ad14ecb8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -16,6 +16,7 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); } -- cgit v1.2.3 From 4d5caf7d011bf73c7b90ff1d8c1cfdad817fa2f5 Mon Sep 17 00:00:00 2001 From: Martin Ridgers Date: Fri, 27 Feb 2026 11:47:00 +0100 Subject: Ported "lane trace" feature from UE (by way of IAX) (#771) * Ported "lane trace" feature from UE (by way of IAX) --- src/zencore/include/zencore/trace.h | 1 + 1 file changed, 1 insertion(+) (limited to 'src') diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h index 99a565151..d17e018ea 100644 --- a/src/zencore/include/zencore/trace.h +++ b/src/zencore/include/zencore/trace.h @@ -13,6 +13,7 @@ ZEN_THIRD_PARTY_INCLUDES_START # define TRACE_IMPLEMENT 0 #endif #include +#include #undef TRACE_IMPLEMENT ZEN_THIRD_PARTY_INCLUDES_END -- cgit v1.2.3 From 9e7019aa16b19cd87aa6af3ef39825edb039c8be Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 27 Feb 2026 13:12:10 +0100 Subject: add support in http client to accept multi-range responses (#788) * add support in http client to accept multi-range responses --- src/zenhttp/clients/httpclientcommon.cpp | 315 ++++++++++++ src/zenhttp/clients/httpclientcommon.h | 109 ++++- src/zenhttp/clients/httpclientcpr.cpp | 530 ++++++++++++--------- src/zenhttp/clients/httpclientcpr.h | 16 +- src/zenhttp/httpclient.cpp | 41 ++ src/zenhttp/include/zenhttp/httpclient.h | 14 + .../builds/buildstorageoperations.cpp | 152 +++--- 7 files changed, 858 insertions(+), 319 deletions(-) (limited to 'src') diff --git 
a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 47425e014..312ca16d2 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -382,6 +382,178 @@ namespace detail { return Result; } + MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {} + + bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue) + { + std::string LowerCaseValue = ToLower(ContentTypeHeaderValue); + if (LowerCaseValue.starts_with("multipart/byteranges")) + { + size_t BoundaryPos = LowerCaseValue.find("boundary="); + if (BoundaryPos != std::string::npos) + { + // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string + std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + if (!BoundaryName.empty()) + { + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') // a value made only of spaces must not be indexed past its end + { + BoundaryName = BoundaryName.substr(1); + } + if (!BoundaryName.empty()) + { + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + { + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) + { + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); + return true; + } + } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } + } + } + } + } + return false; + } + + void MultipartBoundaryParser::InternalParseInput(std::string_view data) + { + size_t ScanPos = 0; + while (ScanPos < data.length()) + { + const char ScanChar = data[ScanPos]; + if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) +
{ + BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + break; + } + } + + BoundaryHeader.Append(ScanChar); + + HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + + if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset(); + const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset(); + const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset; + std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength)); + + uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1; + + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType = HttpContentType::kBinary; + + ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) { + const std::pair KeyAndValue = GetHeaderKeyAndValue(Line); + const std::string_view Key = KeyAndValue.first; + const std::string_view Value = KeyAndValue.second; + if (Key == "Content-Range") + { + std::pair ContentRange = ParseContentRange(Value); + if (ContentRange.second != 0) + { + RangeOffset = ContentRange.first; + RangeLength = ContentRange.second; + } + } + else if (Key == "Content-Type") + { + ContentType = ParseContentType(Value); + } + + return true; + }); + + if (RangeLength > 0) + { + Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload, + .RangeOffset = RangeOffset, + .RangeLength = RangeLength, + .ContentType = ContentType}); + } + + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + } + } + else + { + BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar); + } + ScanPos++; + } + PayloadOffset += data.length(); + } + + std::pair 
GetHeaderKeyAndValue(std::string_view HeaderString) + { + size_t DelimiterPos = HeaderString.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string_view Key = HeaderString.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string_view Value = HeaderString.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + return std::make_pair(Key, Value); + } + return std::make_pair(HeaderString, std::string_view{}); + } + + std::pair ParseContentRange(std::string_view Value) + { + if (Value.starts_with("bytes ")) + { + size_t RangeSplitPos = Value.find('-', 6); + if (RangeSplitPos != std::string::npos) + { + size_t RangeEndLength = Value.find('/', RangeSplitPos + 1); + if (RangeEndLength == std::string::npos) + { + RangeEndLength = Value.length() - (RangeSplitPos + 1); + } + else + { + RangeEndLength = RangeEndLength - (RangeSplitPos + 1); + } + std::optional RequestedRangeStart = ParseInt(Value.substr(6, RangeSplitPos - 6)); + std::optional RequestedRangeEnd = ParseInt(Value.substr(RangeSplitPos + 1, RangeEndLength)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value() && RequestedRangeEnd.value() >= RequestedRangeStart.value()) // reject reversed ranges: End - Start + 1 would wrap around to a huge length + { + uint64_t RangeOffset = RequestedRangeStart.value(); + uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1; + return std::make_pair(RangeOffset, RangeLength); + } + } + } + return {0, 0}; + } + + } // namespace detail } // namespace zen @@ -470,5 +642,148 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("MultipartBoundaryParser") +{ + uint64_t Range1Offset = 2638; + uint64_t Range1Length = (5111437 - Range1Offset) + 1; + + uint64_t Range2Offset = 5118199; + uint64_t Range2Length = (9147741 - Range2Offset) + 1; + +
std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229"; + std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\""; + + { + std::string_view Example1 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/44369878\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample1; + ParserExample1.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 7; + for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow) + { + ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow))); + } + + CHECK(ParserExample1.Boundaries.size() == 2); + + CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example2 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample2; + ParserExample2.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow) + { + std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow)); + ParserExample2.ParseInput(Window); + } + + 
CHECK(ParserExample2.Boundaries.size() == 2); + + CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example3 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita"; + + detail::MultipartBoundaryParser ParserExample3; + ParserExample3.Init(ContentTypeHeaderValue2); + + const size_t InputWindow = 31; + for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow) + { + ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow))); + } + + CHECK(ParserExample3.Boundaries.size() == 2); + + CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example4 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "Not: really\r\n" + "\r\n" + "datadatadatadata" + "\r\n--000000000bait0019229\r\n" + "\r\n--00\r\n--000000000bait001922\r\n" + "\r\n\r\n\r\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "Content-Type: application/x-ue-comp\r\n" + "ditaditadita" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n---\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample4; + 
ParserExample4.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow) + { + std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow)); + ParserExample4.ParseInput(Window); + } + + CHECK(ParserExample4.Boundaries.size() == 2); + + CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length); + } +} + } // namespace zen #endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 1d0b7f9ea..8bb1e9268 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include @@ -87,7 +88,7 @@ namespace detail { std::error_code Write(std::string_view DataString); IoBuffer DetachToIoBuffer(); IoBuffer BorrowIoBuffer(); - inline uint64_t GetSize() const { return m_WriteOffset; } + inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; } void ResetWritePos(uint64_t WriteOffset); private: @@ -143,6 +144,112 @@ namespace detail { uint64_t m_BytesLeftInSegment; }; + class IncrementalStringMatcher + { + public: + enum class EMatchState + { + None, + Partial, + Complete + }; + + EMatchState MatchState = EMatchState::None; + + IncrementalStringMatcher() {} + + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) {} + + void Init(std::string&& InMatchString) { MatchString = std::move(InMatchString); } + + void Reset() + { + MatchLength = 0; + MatchStartOffset = 0; + MatchState = EMatchState::None; + } + + inline uint64_t GetMatchEndOffset() const + { + if (MatchState == EMatchState::Complete) + { + return MatchStartOffset + MatchString.length(); + } + return 0; + 
} + + inline uint64_t GetMatchStartOffset() const + { + ZEN_ASSERT(MatchState == EMatchState::Complete); + return MatchStartOffset; + } + + void Match(uint64_t Offset, char C) + { + ZEN_ASSERT_SLOW(!MatchString.empty()); + + if (MatchState == EMatchState::Complete) + { + Reset(); + } + if (C == MatchString[MatchLength]) + { + if (MatchLength == 0) + { + MatchStartOffset = Offset; + } + MatchLength++; + if (MatchLength == MatchString.length()) + { + MatchState = EMatchState::Complete; + } + else + { + MatchState = EMatchState::Partial; + } + } + else if (MatchLength != 0) + { + Reset(); + Match(Offset, C); + } + else + { + Reset(); + } + } + inline const std::string& GetMatchString() const { return MatchString; } + + private: + std::string MatchString; + + uint64_t MatchLength = 0; + uint64_t MatchStartOffset = 0; + }; + + class MultipartBoundaryParser + { + public: + std::vector Boundaries; + + MultipartBoundaryParser(); + bool Init(const std::string_view ContentTypeHeaderValue); + inline void ParseInput(std::string_view data) { InternalParseInput(data); } + + private: + IncrementalStringMatcher BoundaryBeginMatcher; + IncrementalStringMatcher BoundaryEndMatcher; + IncrementalStringMatcher HeaderEndMatcher; + + ExtendableStringBuilder<64> BoundaryHeader; + uint64_t PayloadOffset = 0; + + void InternalParseInput(std::string_view data); + }; + + std::pair GetHeaderKeyAndValue(std::string_view HeaderString); + std::pair ParseContentRange(std::string_view Value); + } // namespace detail } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 5d92b3b6b..6bc63db09 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -162,10 +162,11 @@ CprHttpClient::~CprHttpClient() } HttpClient::Response -CprHttpClient::ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload) 
+CprHttpClient::ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector&& BoundaryPositions) { // This ends up doing a memcpy, would be good to get rid of it by streaming results // into buffer directly @@ -174,7 +175,6 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) { const HttpContentType ContentType = ParseContentType(It->second); - ResponseBuffer.SetContentType(ContentType); } @@ -188,16 +188,26 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, } } + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), .UploadedBytes = gsl::narrow(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + .ElapsedSeconds = HttpResponse.elapsed, + .Ranges = std::move(BoundaryPositions)}; } HttpClient::Response -CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload) +CprHttpClient::CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector&& BoundaryPositions) { const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); if (HttpResponse.error) @@ -235,7 +245,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe } else { - return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload)); + return 
ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); } } @@ -896,236 +906,280 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::string PayloadString; std::unique_ptr PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return false; - } - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. 
Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional RequestedRangeStart = - ParseInt(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional RequestedRangeEnd = ParseInt(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional ContentLength = ParseInt(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) - { - PayloadFile = std::make_unique(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. 
Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::vector> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair Header = GetHeader(header); - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt(Header.second.substr(6, RangeStartEnd - 6)); - if 
(Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile); - - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? 
PayloadFile->DetachToIoBuffer() : IoBuffer{}); + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (IsMultiRangeResponse) + { + BoundaryParser.ParseInput(data); + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional RequestedRangeStart = ParseInt(RangeString.substr(0, RangeSplitPos)); + std::optional RequestedRangeEnd = ParseInt(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - 
1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + cpr::Response Response; + { + std::vector> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair Header = detail::GetHeaderKeyAndValue(header); + if (Header.first == "Content-Length"sv) + { + std::optional ContentLength = ParseInt(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + else if (Header.first == "Content-Type") + { + IsMultiRangeResponse = BoundaryParser.Init(Header.second); + if (!IsMultiRangeResponse) + { + ContentType = ParseContentType(Header.second); + } + } + else if (Header.first == "Content-Range") + { + if (!IsMultiRangeResponse) + { + std::pair Range = detail::ParseContentRange(Header.second); + if (Range.second != 0) + { + BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = ContentType}); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } 
+ if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::vector> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair Header = detail::GetHeaderKeyAndValue(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? 
PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Response), + PayloadFile ? 
PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); } } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 40af53b5d..cf2d3bd14 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -157,12 +157,16 @@ private: bool ValidatePayload(cpr::Response& Response, std::unique_ptr& PayloadFile); - HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload); - - HttpClient::Response ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload); + HttpClient::Response CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector&& BoundaryPositions = {}); + + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector&& BoundaryPositions); }; } // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 078e27b34..998eb27ea 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -104,6 +104,47 @@ HttpClientBase::GetAccessToken() ////////////////////////////////////////////////////////////////////////// +std::vector> +HttpClient::Response::GetRanges(std::span> OffsetAndLengthPairs) const +{ + std::vector> Result; + Result.reserve(OffsetAndLengthPairs.size()); + if (Ranges.empty()) + { + for (const std::pair& Range : OffsetAndLengthPairs) + { + Result.emplace_back(std::make_pair(Range.first, Range.second)); + } + return Result; + } + + auto BoundaryIt = Ranges.begin(); + auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); + while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) + { + uint64_t Offset = OffsetAndLengthPairIt->first; + uint64_t Length = OffsetAndLengthPairIt->second; + 
while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength) + { + BoundaryIt++; + if (BoundaryIt == Ranges.end()) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + } + if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset; + uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange; + Result.emplace_back(std::make_pair(RangePayloadOffset, Length)); + + OffsetAndLengthPairIt++; + } + return Result; +} + CbObject HttpClient::Response::AsObject() const { diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 336a3deee..41a7ce621 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -179,6 +179,20 @@ public: // The elapsed time in seconds for the request to execute double ElapsedSeconds; + struct MultipartBoundary + { + uint64_t OffsetInPayload; + uint64_t RangeOffset; + uint64_t RangeLength; + HttpContentType ContentType; + }; + + // Ranges will map out all received ranges, both single and multi-range responses + // If no range was requested Ranges will be empty + std::vector Ranges; + + std::vector> GetRanges(std::span> OffsetAndLengthPairs) const; + // This contains any errors from the HTTP stack. It won't contain information on // why the server responded with a non-success HTTP status, that may be gleaned // from the response payload itself depending on what the server provides. 
diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 5219e86d8..08a896f37 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -7828,105 +7828,109 @@ TEST_CASE("buildstorageoperations.memorychunkingcache") TEST_CASE("buildstorageoperations.upload.multipart") { - using namespace buildstorageoperations_testutils; + // Disabled since it relies on authentication and specific block being present in cloud storage + if (false) + { + using namespace buildstorageoperations_testutils; - FastRandom BaseRandom; + FastRandom BaseRandom; - const size_t FileCount = 11; + const size_t FileCount = 11; - const std::string Paths[FileCount] = {{"file_1"}, - {"file_2.exe"}, - {"file_3.txt"}, - {"dir_1/dir1_file_1.exe"}, - {"dir_1/dir1_file_2.pdb"}, - {"dir_1/dir1_file_3.txt"}, - {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"}, - {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"}, - {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"}, - {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"}, - {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}}; - const uint64_t Sizes[FileCount] = - {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 1024u, 3u * 1024u}; + const std::string Paths[FileCount] = {{"file_1"}, + {"file_2.exe"}, + {"file_3.txt"}, + {"dir_1/dir1_file_1.exe"}, + {"dir_1/dir1_file_2.pdb"}, + {"dir_1/dir1_file_3.txt"}, + {"dir_2/dir2_dir1/dir2_dir1_file_1.exe"}, + {"dir_2/dir2_dir1/dir2_dir1_file_2.pdb"}, + {"dir_2/dir2_dir1/dir2_dir1_file_3.dll"}, + {"dir_2/dir2_dir2/dir2_dir2_file_1.txt"}, + {"dir_2/dir2_dir2/dir2_dir2_file_2.json"}}; + const uint64_t Sizes[FileCount] = + {6u * 1024u, 0, 798, 19u * 1024u, 7u * 1024u, 93, 31u * 1024u, 17u * 1024u, 13u * 1024u, 2u * 1024u, 3u * 1024u}; - ScopedTemporaryDirectory SourceFolder; - TestState State(SourceFolder.Path()); - State.Initialize(); - State.CreateSourceData("source", Paths, Sizes); + 
ScopedTemporaryDirectory SourceFolder; + TestState State(SourceFolder.Path()); + State.Initialize(); + State.CreateSourceData("source", Paths, Sizes); - std::span ManifestFiles1(Paths); - ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2); + std::span ManifestFiles1(Paths); + ManifestFiles1 = ManifestFiles1.subspan(0, FileCount / 2); - std::span ManifestSizes1(Sizes); - ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2); + std::span ManifestSizes1(Sizes); + ManifestSizes1 = ManifestSizes1.subspan(0, FileCount / 2); - std::span ManifestFiles2(Paths); - ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1); + std::span ManifestFiles2(Paths); + ManifestFiles2 = ManifestFiles2.subspan(FileCount / 2 - 1); - std::span ManifestSizes2(Sizes); - ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1); + std::span ManifestSizes2(Sizes); + ManifestSizes2 = ManifestSizes2.subspan(FileCount / 2 - 1); - const Oid BuildPart1Id = Oid::NewOid(); - const std::string BuildPart1Name = "part1"; - const Oid BuildPart2Id = Oid::NewOid(); - const std::string BuildPart2Name = "part2"; - { - CbObjectWriter Writer; - Writer.BeginObject("parts"sv); + const Oid BuildPart1Id = Oid::NewOid(); + const std::string BuildPart1Name = "part1"; + const Oid BuildPart2Id = Oid::NewOid(); + const std::string BuildPart2Name = "part2"; { - Writer.BeginObject(BuildPart1Name); + CbObjectWriter Writer; + Writer.BeginObject("parts"sv); { - Writer.AddObjectId("partId"sv, BuildPart1Id); - Writer.BeginArray("files"sv); - for (const std::string& ManifestFile : ManifestFiles1) + Writer.BeginObject(BuildPart1Name); { - Writer.AddString(ManifestFile); + Writer.AddObjectId("partId"sv, BuildPart1Id); + Writer.BeginArray("files"sv); + for (const std::string& ManifestFile : ManifestFiles1) + { + Writer.AddString(ManifestFile); + } + Writer.EndArray(); // files } - Writer.EndArray(); // files - } - Writer.EndObject(); // part1 + Writer.EndObject(); // part1 - Writer.BeginObject(BuildPart2Name); - { - 
Writer.AddObjectId("partId"sv, BuildPart2Id); - Writer.BeginArray("files"sv); - for (const std::string& ManifestFile : ManifestFiles2) + Writer.BeginObject(BuildPart2Name); { - Writer.AddString(ManifestFile); + Writer.AddObjectId("partId"sv, BuildPart2Id); + Writer.BeginArray("files"sv); + for (const std::string& ManifestFile : ManifestFiles2) + { + Writer.AddString(ManifestFile); + } + Writer.EndArray(); // files } - Writer.EndArray(); // files + Writer.EndObject(); // part2 } - Writer.EndObject(); // part2 - } - Writer.EndObject(); // parts + Writer.EndObject(); // parts - ExtendableStringBuilder<1024> Manifest; - CompactBinaryToJson(Writer.Save(), Manifest); - WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size())); - } + ExtendableStringBuilder<1024> Manifest; + CompactBinaryToJson(Writer.Save(), Manifest); + WriteFile(State.RootPath / "manifest.json", IoBuffer(IoBuffer::Wrap, Manifest.Data(), Manifest.Size())); + } - const Oid BuildId = Oid::NewOid(); + const Oid BuildId = Oid::NewOid(); - auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json"); + auto Result = State.Upload(BuildId, {}, {}, "source", State.RootPath / "manifest.json"); - CHECK_EQ(Result.size(), 2u); - CHECK_EQ(Result[0].first, BuildPart1Id); - CHECK_EQ(Result[0].second, BuildPart1Name); - CHECK_EQ(Result[1].first, BuildPart2Id); - CHECK_EQ(Result[1].second, BuildPart2Name); - State.ValidateUpload(BuildId, Result); + CHECK_EQ(Result.size(), 2u); + CHECK_EQ(Result[0].first, BuildPart1Id); + CHECK_EQ(Result[0].second, BuildPart1Name); + CHECK_EQ(Result[1].first, BuildPart2Id); + CHECK_EQ(Result[1].second, BuildPart2Name); + State.ValidateUpload(BuildId, Result); - FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, "download", /* Append */ false); - State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent); + FolderContent DownloadContent = State.Download(BuildId, Oid::Zero, {}, 
"download", /* Append */ false); + State.ValidateDownload(Paths, Sizes, "source", "download", DownloadContent); - FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false); - State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent); + FolderContent Part1DownloadContent = State.Download(BuildId, BuildPart1Id, {}, "download_part1", /* Append */ false); + State.ValidateDownload(ManifestFiles1, ManifestSizes1, "source", "download_part1", Part1DownloadContent); - FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false); - State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent); + FolderContent Part2DownloadContent = State.Download(BuildId, Oid::Zero, BuildPart2Name, "download_part2", /* Append */ false); + State.ValidateDownload(ManifestFiles2, ManifestSizes2, "source", "download_part2", Part2DownloadContent); - (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false); - FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, "download_part1+2", /* Append */ true); - State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent); + (void)State.Download(BuildId, BuildPart1Id, BuildPart1Name, "download_part1+2", /* Append */ false); + FolderContent Part1And2DownloadContent = State.Download(BuildId, BuildPart2Id, {}, "download_part1+2", /* Append */ true); + State.ValidateDownload(Paths, Sizes, "source", "download_part1+2", Part1And2DownloadContent); + } } void -- cgit v1.2.3 From 226ba2cf432ae6c0a787c6156a172f343fc71887 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 27 Feb 2026 13:57:18 +0100 Subject: MeasureLatency now bails out quickly if it experiences a connection error (#789) previously it would stall some 40s in this case --- 
src/zenhttp/clients/httpclientcpr.cpp | 15 +++++++++++++++ src/zenhttp/httpclient.cpp | 7 +++++++ src/zenhttp/include/zenhttp/httpclient.h | 3 +++ 3 files changed, 25 insertions(+) (limited to 'src') diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 6bc63db09..90dcfacbb 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -23,6 +23,21 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic HttpClientRequestIdCounter{0}; +bool +HttpClient::ErrorContext::IsConnectionError() const +{ + switch (static_cast(ErrorCode)) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return true; + default: + return false; + } +} + // If we want to support different HTTP client implementations then we'll need to make this more abstract HttpClientError::ResponseClass diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 998eb27ea..1cfddb366 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -400,6 +400,13 @@ MeasureLatency(HttpClient& Client, std::string_view Url) else { ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); + + // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. + // Bail out immediately — retrying will just burn the connect timeout each time. 
+ if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) + { + break; + } } } diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 41a7ce621..f00bbec63 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -125,6 +125,9 @@ public: { int ErrorCode; std::string ErrorMessage; + + /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ + bool IsConnectionError() const; }; struct KeyValueMap -- cgit v1.2.3 From 87aff23c1246abd2838d8b0e589fe61015effa9c Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 27 Feb 2026 16:39:04 +0100 Subject: optimize string matching (#791) --- src/zenhttp/clients/httpclientcommon.cpp | 12 +++++++----- src/zenhttp/clients/httpclientcommon.h | 26 ++++++++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 312ca16d2..c016e1c3c 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -425,12 +425,14 @@ namespace detail { return false; } - void MultipartBoundaryParser::InternalParseInput(std::string_view data) + void MultipartBoundaryParser::ParseInput(std::string_view data) { - size_t ScanPos = 0; - while (ScanPos < data.length()) + const char* InputPtr = data.data(); + size_t InputLength = data.length(); + size_t ScanPos = 0; + while (ScanPos < InputLength) { - const char ScanChar = data[ScanPos]; + const char ScanChar = InputPtr[ScanPos]; if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) { if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) @@ -504,7 +506,7 @@ namespace detail { } ScanPos++; } - PayloadOffset += data.length(); + PayloadOffset += InputLength; } std::pair 
GetHeaderKeyAndValue(std::string_view HeaderString) diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 8bb1e9268..5ed946541 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -158,11 +158,18 @@ namespace detail { IncrementalStringMatcher() {} - IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) {} + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) + { + RawMatchString = MatchString.data(); + } - void Init(std::string&& InMatchString) { MatchString = std::move(InMatchString); } + void Init(std::string&& InMatchString) + { + MatchString = std::move(InMatchString); + RawMatchString = MatchString.data(); + } - void Reset() + inline void Reset() { MatchLength = 0; MatchStartOffset = 0; @@ -186,13 +193,13 @@ namespace detail { void Match(uint64_t Offset, char C) { - ZEN_ASSERT_SLOW(!MatchString.empty()); + ZEN_ASSERT_SLOW(RawMatchString != nullptr); if (MatchState == EMatchState::Complete) { Reset(); } - if (C == MatchString[MatchLength]) + if (C == RawMatchString[MatchLength]) { if (MatchLength == 0) { @@ -222,8 +229,9 @@ namespace detail { private: std::string MatchString; + const char* RawMatchString = nullptr; + uint64_t MatchLength = 0; - uint64_t MatchLength = 0; uint64_t MatchStartOffset = 0; }; @@ -233,8 +241,8 @@ namespace detail { std::vector Boundaries; MultipartBoundaryParser(); - bool Init(const std::string_view ContentTypeHeaderValue); - inline void ParseInput(std::string_view data) { InternalParseInput(data); } + bool Init(const std::string_view ContentTypeHeaderValue); + void ParseInput(std::string_view data); private: IncrementalStringMatcher BoundaryBeginMatcher; @@ -243,8 +251,6 @@ namespace detail { ExtendableStringBuilder<64> BoundaryHeader; uint64_t PayloadOffset = 0; - - void InternalParseInput(std::string_view data); }; std::pair 
GetHeaderKeyAndValue(std::string_view HeaderString); -- cgit v1.2.3 From 65eefdfe5a216b546f0d3d8fdfc5e9e58916e5f8 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 27 Feb 2026 17:12:00 +0100 Subject: add sentry-sdk logger (#793) eliminates spurious sentry log output during startup as the new channel defaults to WARN The level can be overridden via `--log-debug=sentry-sdk` or `--log-info=sentry-sdk` --- src/zencore/sentryintegration.cpp | 62 ++++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 00e67dc85..636e182b4 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -145,6 +145,8 @@ SentryAssertImpl::OnAssert(const char* Filename, namespace zen { # if ZEN_USE_SENTRY +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogSentry, "sentry-sdk"); + static void SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) { @@ -163,26 +165,61 @@ SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[may MessagePtr = LogMessage.c_str(); } + // SentryLogFunction can be called before the logging system is initialized + // (during sentry_init which runs before InitializeLogging). Fall back to + // console logging when the category logger is not yet available. 
+ // + // Since we want to default to WARN level but this runs before logging has + // been configured, we ignore the callbacks for DEBUG/INFO explicitly here + // which means users don't see every possible log message if they're trying + // to configure the levels using --log-debug=sentry-sdk + if (!TheDefaultLogger) + { + switch (Level) + { + case SENTRY_LEVEL_DEBUG: + // ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_INFO: + // ZEN_CONSOLE_INFO("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_WARNING: + ZEN_CONSOLE_WARN("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_ERROR: + ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_FATAL: + ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr); + break; + } + return; + } + switch (Level) { case SENTRY_LEVEL_DEBUG: - ZEN_CONSOLE_DEBUG("sentry: {}", MessagePtr); + ZEN_LOG_DEBUG(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_INFO: - ZEN_CONSOLE_INFO("sentry: {}", MessagePtr); + ZEN_LOG_INFO(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_WARNING: - ZEN_CONSOLE_WARN("sentry: {}", MessagePtr); + ZEN_LOG_WARN(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_ERROR: - ZEN_CONSOLE_ERROR("sentry: {}", MessagePtr); + ZEN_LOG_ERROR(LogSentry, "sentry: {}", MessagePtr); break; case SENTRY_LEVEL_FATAL: - ZEN_CONSOLE_CRITICAL("sentry: {}", MessagePtr); + ZEN_LOG_CRITICAL(LogSentry, "sentry: {}", MessagePtr); break; } } @@ -310,22 +347,31 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine void SentryIntegration::LogStartupInformation() { + // Initialize the sentry-sdk log category at Warn level to reduce startup noise. 
+ // The level can be overridden via --log-debug=sentry-sdk or --log-info=sentry-sdk + LogSentry.Logger().SetLogLevel(logging::level::Warn); + if (m_IsInitialized) { if (m_SentryErrorCode == 0) { if (m_AllowPII) { - ZEN_INFO("sentry initialized, username: '{}', hostname: '{}', id: '{}'", m_SentryUserName, m_SentryHostName, m_SentryId); + ZEN_LOG_INFO(LogSentry, + "sentry initialized, username: '{}', hostname: '{}', id: '{}'", + m_SentryUserName, + m_SentryHostName, + m_SentryId); } else { - ZEN_INFO("sentry initialized with anonymous reports"); + ZEN_LOG_INFO(LogSentry, "sentry initialized with anonymous reports"); } } else { - ZEN_WARN( + ZEN_LOG_WARN( + LogSentry, "sentry_init returned failure! (error code: {}) note that sentry expects crashpad_handler to exist alongside the running " "executable", m_SentryErrorCode); -- cgit v1.2.3 From 0a41fd42aa43080fbc991e7d976dde70aeaec594 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 27 Feb 2026 17:13:40 +0100 Subject: add full WebSocket (RFC 6455) client/server support for zenhttp (#792) * This branch adds full WebSocket (RFC 6455) support to the HTTP server layer, covering both transport backends, a client, and tests. - **`websocket.h`** -- Core interfaces: `WebSocketOpcode`, `WebSocketMessage`, `WebSocketConnection` (ref-counted), and `IWebSocketHandler`. Services opt in to WebSocket support by implementing `IWebSocketHandler` alongside their existing `HttpService`. - **`httpwsclient.h`** -- `HttpWsClient`: an ASIO-backed `ws://` client with both standalone (own thread) and shared `io_context` modes. Supports connect timeout and optional auth token injection via `IWsClientHandler` callbacks. - **`wsasio.cpp/h`** -- `WsAsioConnection`: WebSocket over ASIO TCP. Takes over the socket after the HTTP 101 handshake and runs an async read/write loop with a queued write path (guarded by `RwLock`). - **`wshttpsys.cpp/h`** -- `WsHttpSysConnection`: WebSocket over http.sys opaque-mode connections (Windows only). 
Uses `HttpReceiveRequestEntityBody` / `HttpSendResponseEntityBody` via IOCP, sharing the same threadpool as normal http.sys traffic. Self-ref lifetime management ensures graceful drain of outstanding async ops. - **`httpsys_iocontext.h`** -- Tagged `OVERLAPPED` wrapper (`HttpSysIoContext`) used to distinguish normal HTTP transactions from WebSocket read/write completions in the single IOCP callback. - **`wsframecodec.cpp/h`** -- `WsFrameCodec`: static helpers for parsing (unmasked and masked) and building (unmasked server frames and masked client frames) RFC 6455 frames across all three payload length encodings (7-bit, 16-bit, 64-bit). Also computes `Sec-WebSocket-Accept` keys. - **`clients/httpwsclient.cpp`** -- `HttpWsClient::Impl`: ASIO-based client that performs the HTTP upgrade handshake, then hands off to the frame codec for the read loop. Manages its own `io_context` thread or plugs into an external one. - **`httpasio.cpp`** -- ASIO server now detects `Upgrade: websocket` requests, checks the matching `HttpService` for `IWebSocketHandler` via `dynamic_cast`, performs the RFC 6455 handshake (101 response), and spins up a `WsAsioConnection`. - **`httpsys.cpp`** -- Same upgrade detection and handshake logic for the http.sys backend, using `WsHttpSysConnection` and `HTTP_SEND_RESPONSE_FLAG_OPAQUE`. - **`httpparser.cpp/h`** -- Extended to surface the `Upgrade` / `Connection` / `Sec-WebSocket-Key` headers needed by the handshake. - **`httpcommon.h`** -- Minor additions (probably new header constants or response codes for the WS upgrade). - **`httpserver.h`** -- Small interface changes to support WebSocket registration. - **`zenhttp.cpp` / `xmake.lua`** -- New source files wired in; build config updated. - **Unit tests** (`websocket.framecodec`): round-trip encode/decode for text, binary, close frames; all three payload sizes; masked and unmasked variants; RFC 6455 `Sec-WebSocket-Accept` test vector. 
- **Integration tests** (`websocket.integration`): full ASIO server tests covering handshake (101), normal HTTP coexistence, echo, server-push broadcast, client close handshake, ping/pong auto-response, sequential messages, and rejection of upgrades on non-WS services. - **Client tests** (`websocket.client`): `HttpWsClient` connect+echo+close, connection failure (bad port -> close code 1006), and server-initiated close. * changed HttpRequestParser::ParseCurrentHeader to use switch instead of if/else chain * remove spurious printf --------- Co-authored-by: Stefan Boberg --- src/zencore/filesystem.cpp | 1 - src/zenhttp/clients/httpwsclient.cpp | 568 ++++++++++++++++++ src/zenhttp/include/zenhttp/httpcommon.h | 7 + src/zenhttp/include/zenhttp/httpserver.h | 3 +- src/zenhttp/include/zenhttp/httpwsclient.h | 79 +++ src/zenhttp/include/zenhttp/websocket.h | 65 ++ src/zenhttp/servers/httpasio.cpp | 49 ++ src/zenhttp/servers/httpparser.cpp | 148 +++-- src/zenhttp/servers/httpparser.h | 7 + src/zenhttp/servers/httpsys.cpp | 180 ++++-- src/zenhttp/servers/httpsys_iocontext.h | 40 ++ src/zenhttp/servers/wsasio.cpp | 297 ++++++++++ src/zenhttp/servers/wsasio.h | 71 +++ src/zenhttp/servers/wsframecodec.cpp | 229 +++++++ src/zenhttp/servers/wsframecodec.h | 74 +++ src/zenhttp/servers/wshttpsys.cpp | 466 +++++++++++++++ src/zenhttp/servers/wshttpsys.h | 104 ++++ src/zenhttp/servers/wstest.cpp | 922 +++++++++++++++++++++++++++++ src/zenhttp/xmake.lua | 1 + src/zenhttp/zenhttp.cpp | 1 + 20 files changed, 3203 insertions(+), 109 deletions(-) create mode 100644 src/zenhttp/clients/httpwsclient.cpp create mode 100644 src/zenhttp/include/zenhttp/httpwsclient.h create mode 100644 src/zenhttp/include/zenhttp/websocket.h create mode 100644 src/zenhttp/servers/httpsys_iocontext.h create mode 100644 src/zenhttp/servers/wsasio.cpp create mode 100644 src/zenhttp/servers/wsasio.h create mode 100644 src/zenhttp/servers/wsframecodec.cpp create mode 100644 src/zenhttp/servers/wsframecodec.h 
create mode 100644 src/zenhttp/servers/wshttpsys.cpp create mode 100644 src/zenhttp/servers/wshttpsys.h create mode 100644 src/zenhttp/servers/wstest.cpp (limited to 'src') diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 553897407..03398860b 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3533,7 +3533,6 @@ TEST_CASE("PathBuilder") Path.Reset(); Path.Append(fspath(L"/\u0119oo/")); Path /= L"bar"; - printf("%ls\n", Path.ToPath().c_str()); CHECK(Path.ToView() == L"/\u0119oo/bar"); CHECK(Path.ToPath() == L"\\\u0119oo\\bar"); # endif diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..36a6f081b --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,568 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include "../servers/wsframecodec.h" + +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + if (m_Socket) + { + asio::error_code Ec; + m_Socket->close(Ec); + } + + if (m_IoThread.joinable()) + { + 
m_IoThread.join(); + } + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if (m_OwnedIoContext) + { + m_WorkGuard = std::make_unique(m_IoContext); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + + void DoResolve() + { + m_Resolver = std::make_unique(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_Socket = std::make_unique(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + if (m_Socket) + { + asio::error_code CloseEc; + m_Socket->close(CloseEc); + } + } + }); + + asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + 
ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared(ReqStr); + + asio::async_write(*m_Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + } + + void DoReadHandshakeResponse() + { + asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, 
"WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer 
= m_ReadBuffer.data(); + const auto* RawData = static_cast(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with masked pong + std::vector PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void 
FlushWriteQueue() + { + std::vector Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + } + + void OnWriteComplete(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + 
std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr m_OwnedIoContext; + asio::io_context& m_IoContext; + std::unique_ptr m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr m_Resolver; + std::unique_ptr m_Socket; + std::unique_ptr m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic m_IsOpen{false}; + bool m_CloseSent = false; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index bc18549c9..8fca35ac5 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept return IsHttpSuccessCode(int(HttpCode)); } +[[nodiscard]] inline bool +IsHttpOk(HttpResponseCode HttpCode) noexcept +{ + return HttpCode == 
HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted || + HttpCode == HttpResponseCode::NoContent; +} + std::string_view ToString(HttpResponseCode HttpCode); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 00cbc6c14..fee932daa 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -462,6 +462,7 @@ struct IHttpStatsService virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; }; -void http_forcelink(); // internal +void http_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h new file mode 100644 index 000000000..926ec1e3d --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace zen { + +/** + * Callback interface for WebSocket client events + * + * Separate from the server-side IWebSocketHandler because the caller + * already owns the HttpWsClient — no Ref needed. + */ +class IWsClientHandler +{ +public: + virtual ~IWsClientHandler() = default; + + virtual void OnWsOpen() = 0; + virtual void OnWsMessage(const WebSocketMessage& Msg) = 0; + virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0; +}; + +struct HttpWsClientSettings +{ + std::string LogCategory = "wsclient"; + std::chrono::milliseconds ConnectTimeout{5000}; + std::optional> AccessTokenProvider; +}; + +/** + * WebSocket client over TCP (ws:// scheme) + * + * Uses ASIO for async I/O. 
Two construction modes: + * - Internal io_context + background thread (standalone use) + * - External io_context (shared event loop, no internal thread) + * + * Thread-safe for SendText/SendBinary/Close. + */ +class HttpWsClient +{ +public: + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {}); + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {}); + + ~HttpWsClient(); + + HttpWsClient(const HttpWsClient&) = delete; + HttpWsClient& operator=(const HttpWsClient&) = delete; + + void Connect(); + void SendText(std::string_view Text); + void SendBinary(std::span Data); + void Close(uint16_t Code = 1000, std::string_view Reason = {}); + bool IsOpen() const; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..7a6fb33dd --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include +#include +#include + +namespace zen { + +enum class WebSocketOpcode : uint8_t +{ + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA +}; + +struct WebSocketMessage +{ + WebSocketOpcode Opcode; + IoBuffer Payload; + uint16_t CloseCode = 0; +}; + +/** + * Represents an active WebSocket connection + * + * Derived classes implement the actual transport (e.g. ASIO sockets). + * Instances are reference-counted so that both the service layer and + * the async read/write loop can share ownership. 
+ */ +class WebSocketConnection : public RefCounted +{ +public: + virtual ~WebSocketConnection() = default; + + virtual void SendText(std::string_view Text) = 0; + virtual void SendBinary(std::span Data) = 0; + virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; + virtual bool IsOpen() const = 0; +}; + +/** + * Interface for services that accept WebSocket upgrades + * + * An HttpService may additionally implement this interface to indicate + * it supports WebSocket connections. The HTTP server checks for this + * via dynamic_cast when it sees an Upgrade: websocket request. + */ +class IWebSocketHandler +{ +public: + virtual ~IWebSocketHandler() = default; + + virtual void OnWebSocketOpen(Ref Connection) = 0; + virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; + virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 0c0238886..8c2dcd116 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -14,6 +14,8 @@ #include #include "httpparser.h" +#include "wsasio.h" +#include "wsframecodec.h" #include @@ -1159,6 +1161,53 @@ HttpServerConnection::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); + // WebSocket upgrade detection must happen before the keep-alive check below, + // because Upgrade requests have "Connection: Upgrade" which the HTTP parser + // treats as non-keep-alive, causing a premature shutdown of the receive side. 
+ if (m_RequestData.IsWebSocketUpgrade()) + { + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + IWebSocketHandler* WsHandler = dynamic_cast(Service); + if (WsHandler && !m_RequestData.SecWebSocketKey().empty()) + { + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey()); + + auto ResponseStr = std::make_shared(); + ResponseStr->reserve(256); + ResponseStr->append( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "); + ResponseStr->append(AcceptKey); + ResponseStr->append("\r\n\r\n"); + + // Send the 101 response on the current socket, then hand the socket off + // to a WsAsioConnection for the WebSocket protocol. + asio::async_write(*m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Ref WsConn(new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler)); + Ref WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); + + m_RequestState = RequestState::kDone; + return; + } + } + // Service doesn't support WebSocket or missing key — fall through to normal handling + } + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index f0485aa25..3b1229375 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -12,14 +12,17 @@ namespace zen { using namespace std::literals; -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit 
uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); -static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); +static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); +static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // @@ -143,45 +146,62 @@ HttpRequestParser::ParseCurrentHeader() const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); - if (HeaderHash == HashContentLength) + switch (HeaderHash) { - m_ContentLengthHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAuthorization) - { - m_AuthorizationHeaderIndex = CurrentHeaderIndex; - } - else if 
(HeaderHash == HashSession) - { - m_SessionId = Oid::TryFromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = CurrentHeaderIndex; + case HashContentLength: + m_ContentLengthHeaderIndex = CurrentHeaderIndex; + break; + + case HashAccept: + m_AcceptHeaderIndex = CurrentHeaderIndex; + break; + + case HashContentType: + m_ContentTypeHeaderIndex = CurrentHeaderIndex; + break; + + case HashAuthorization: + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + break; + + case HashSession: + m_SessionId = Oid::TryFromHexString(HeaderValue); + break; + + case HashRequest: + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + break; + + case HashExpect: + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + break; + + case HashRange: + m_RangeHeaderIndex = CurrentHeaderIndex; + break; + + case HashUpgrade: + m_UpgradeHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketKey: + m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketVersion: + m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; + break; + + default: + break; } } @@ -361,14 +381,17 @@ HttpRequestParser::ResetState() m_HeaderEntries.clear(); - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_AuthorizationHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; + 
m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; + m_UpgradeHeaderIndex = -1; + m_SecWebSocketKeyHeaderIndex = -1; + m_SecWebSocketVersionHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); @@ -425,4 +448,21 @@ HttpRequestParser::OnMessageComplete() } } +bool +HttpRequestParser::IsWebSocketUpgrade() const +{ + std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); + if (Upgrade.empty()) + { + return false; + } + + // Case-insensitive check for "websocket" + if (Upgrade.size() != 9) + { + return false; + } + return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; +} + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index ff56ca970..d40a5aeb0 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -48,6 +48,10 @@ struct HttpRequestParser std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); } + std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); } + bool IsWebSocketUpgrade() const; + private: struct HeaderRange { @@ -86,6 +90,9 @@ private: int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; int8_t m_AuthorizationHeaderIndex; + int8_t m_UpgradeHeaderIndex; + int8_t m_SecWebSocketKeyHeaderIndex; + int8_t m_SecWebSocketVersionHeaderIndex; HttpVerb m_RequestVerb; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index e93ae4853..23d57af57 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -156,6 +156,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" 
+# include "wshttpsys.h" +# include "wsframecodec.h" + # include # include # pragma comment(lib, "httpapi.lib") @@ -380,7 +384,7 @@ public: PTP_IO Iocp(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -397,8 +401,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -1555,7 +1559,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { @@ -2111,64 +2131,118 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = 
reinterpret_cast(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast(ReqInfo.pInfo); + const 
std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector UnknownHeaders; + + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. 
+ OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Ref WsConn( + new WsHttpSysConnection(RequestQueueHandle, RequestId, *WsHandler, Transaction().Iocp())); + Ref WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } - ZEN_INFO(""); + ZEN_WARN("WebSocket 101 send failed: {}", SendResult); + + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h new file mode 100644 index 000000000..4f8a97012 --- /dev/null +++ b/src/zenhttp/servers/httpsys_iocontext.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include + +# include + +namespace zen { + +/** + * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch + * + * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection + * (for WebSocket read/write) embed this struct. 
The single IoCompletionCallback + * bound to the request queue uses the ContextType tag to dispatch to the correct + * handler. + * + * The Overlapped member must be first so that CONTAINING_RECORD works to recover + * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool. + */ +struct HttpSysIoContext +{ + OVERLAPPED Overlapped{}; + + enum class Type : uint8_t + { + kTransaction, + kWebSocketRead, + kWebSocketWrite, + } ContextType = Type::kTransaction; + + void* Owner = nullptr; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp new file mode 100644 index 000000000..dfc1eac38 --- /dev/null +++ b/src/zenhttp/servers/wsasio.cpp @@ -0,0 +1,297 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsasio.h" +#include "wsframecodec.h" + +#include + +namespace zen::asio_http { + +static LoggerRef +WsLog() +{ + static LoggerRef g_Logger = logging::Get("ws"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsAsioConnection::WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler) +: m_Socket(std::move(Socket)) +, m_Handler(Handler) +{ +} + +WsAsioConnection::~WsAsioConnection() +{ + m_IsOpen.store(false); +} + +void +WsAsioConnection::Start() +{ + EnqueueRead(); +} + +bool +WsAsioConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Read loop +// + +void +WsAsioConnection::EnqueueRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + Ref Self(this); + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { + Self->OnDataReceived(Ec, ByteCount); + }); +} + +void +WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + 
{ + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } +} + +void +WsAsioConnection::ProcessReceivedData() +{ + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* Data = static_cast(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + 
m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Shut down the socket + std::error_code ShutdownEc; + m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); + m_Socket->close(ShutdownEc); + return; + } + + default: + ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Write queue +// + +void +WsAsioConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::SendBinary(std::span Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent) + { + m_CloseSent = true; + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); +} + +void +WsAsioConnection::EnqueueWrite(std::vector Frame) +{ + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsAsioConnection::FlushWriteQueue() +{ + std::vector Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + 
m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + Ref Self(this); + + // Move Frame into a shared_ptr so we can create the buffer and capture ownership + // in the same async_write call without evaluation order issues. + auto OwnedFrame = std::make_shared>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); +} + +void +WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + return; + } + + FlushWriteQueue(); +} + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h new file mode 100644 index 000000000..a638ea836 --- /dev/null +++ b/src/zenhttp/servers/wsasio.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include +#include +#include + +namespace zen::asio_http { + +/** + * WebSocket connection over an ASIO TCP socket + * + * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * and runs an async read/write loop to exchange WebSocket frames. + * + * Lifetime is managed solely through intrusive reference counting (RefCounted). + * The async read/write callbacks capture Ref to keep the + * connection alive for the duration of the async operation. The service layer + * also holds a Ref. 
+ */ +class WsAsioConnection : public WebSocketConnection +{ +public: + WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler); + ~WsAsioConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and the 101 response has been sent. + */ + void Start(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void ProcessReceivedData(); + + void EnqueueWrite(std::vector Frame); + void FlushWriteQueue(); + void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount); + + void DoClose(uint16_t Code, std::string_view Reason); + + std::unique_ptr m_Socket; + IWebSocketHandler& m_Handler; + asio::streambuf m_ReadBuffer; + + RwLock m_WriteLock; + std::deque> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic m_IsOpen{true}; + bool m_CloseSent = false; +}; + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..a4c5e0f16 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,229 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "wsframecodec.h" + +#include +#include + +#include +#include + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? 
(Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span Payload) +{ + std::vector Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (client-to-server, with masking) +// + +std::vector +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span 
Payload) +{ + std::vector Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + 
char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h new file mode 100644 index 000000000..2d90b6fa1 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace zen { + +/** + * Result of attempting to parse a single WebSocket frame from a byte buffer + */ +struct WsFrameParseResult +{ + bool IsValid = false; // true if a complete frame was successfully parsed + size_t BytesConsumed = 0; // number of bytes consumed from the input buffer + WebSocketOpcode Opcode = WebSocketOpcode::kText; + bool Fin = false; + std::vector Payload; +}; + +/** + * RFC 6455 WebSocket frame codec + * + * Provides static helpers for parsing client-to-server frames (which are + * always masked) and building server-to-client frames (which are never masked). + */ +struct WsFrameCodec +{ + /** + * Try to parse one complete frame from the front of the buffer. + * + * Returns a result with IsValid == false and BytesConsumed == 0 when + * there is not enough data yet. The caller should accumulate more data + * and retry. 
+ */ + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + + /** + * Build a server-to-client frame (no masking) + */ + static std::vector BuildFrame(WebSocketOpcode Opcode, std::span Payload); + + /** + * Build a close frame with a status code and optional reason string + */ + static std::vector BuildCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Build a client-to-server frame (with masking per RFC 6455) + */ + static std::vector BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload); + + /** + * Build a masked close frame with status code and optional reason + */ + static std::vector BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2 + * + * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + */ + static std::string ComputeAcceptKey(std::string_view ClientKey); +}; + +} // namespace zen diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp new file mode 100644 index 000000000..3f0f0b447 --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -0,0 +1,466 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "wshttpsys.h" + +#if ZEN_WITH_HTTPSYS + +# include "wsframecodec.h" + +# include + +namespace zen { + +static LoggerRef +WsHttpSysLog() +{ + static LoggerRef g_Logger = logging::Get("ws_httpsys"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp) +: m_RequestQueueHandle(RequestQueueHandle) +, m_RequestId(RequestId) +, m_Handler(Handler) +, m_Iocp(Iocp) +, m_ReadBuffer(8192) +{ + m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; + m_ReadIoContext.Owner = this; + m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; + m_WriteIoContext.Owner = this; +} + +WsHttpSysConnection::~WsHttpSysConnection() +{ + ZEN_ASSERT(m_OutstandingOps.load() == 0); + + if (m_IsOpen.exchange(false)) + { + Disconnect(); + } +} + +void +WsHttpSysConnection::Start() +{ + m_SelfRef = Ref(this); + IssueAsyncRead(); +} + +void +WsHttpSysConnection::Shutdown() +{ + m_ShutdownRequested.store(true, std::memory_order_relaxed); + + if (!m_IsOpen.exchange(false)) + { + return; + } + + // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +bool +WsHttpSysConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Async read path +// + +void +WsHttpSysConnection::IssueAsyncRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) + { + MaybeReleaseSelfRef(); + return; + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, + m_RequestId, + 
0, // Flags + m_ReadBuffer.data(), + (ULONG)m_ReadBuffer.size(), + nullptr, // BytesRead (ignored for async) + &m_ReadIoContext.Overlapped); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef + Ref Guard(this); + + if (IoResult != NO_ERROR) + { + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + if (IoResult == ERROR_HANDLE_EOF) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); + } + else if (IoResult != ERROR_OPERATION_ABORTED) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + } + + MaybeReleaseSelfRef(); + return; + } + + if (NumberOfBytesTransferred > 0) + { + m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); + ProcessReceivedData(); + } + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + IssueAsyncRead(); + } + else + { + MaybeReleaseSelfRef(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +void +WsHttpSysConnection::ProcessReceivedData() +{ + while (!m_Accumulated.empty()) + { + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + // Remove consumed bytes + m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + 
WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent) + { + m_CloseSent = true; + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + Disconnect(); + return; + } + + default: + ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Async write path +// + +void +WsHttpSysConnection::EnqueueWrite(std::vector Frame) +{ + bool ShouldFlush = false; + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.push_back(std::move(Frame)); + + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + } + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsHttpSysConnection::FlushWriteQueue() +{ + { + RwLock::ExclusiveLockScope _(m_WriteLock); + + if 
(m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + + m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); + m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; + m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); + m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); + + ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA, + 1, + &m_WriteChunk, + nullptr, + nullptr, + 0, + &m_WriteIoContext.Overlapped, + nullptr); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + m_CurrentWriteBuffer.clear(); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + // Hold a transient ref to prevent mid-callback destruction + Ref Guard(this); + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + m_CurrentWriteBuffer.clear(); + + if (IoResult != NO_ERROR) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + return; + } + + FlushWriteQueue(); +} + 
+////////////////////////////////////////////////////////////////////////// +// +// Send interface +// + +void +WsHttpSysConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::SendBinary(std::span Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent) + { + m_CloseSent = true; + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +////////////////////////////////////////////////////////////////////////// +// +// Lifetime management +// + +void +WsHttpSysConnection::MaybeReleaseSelfRef() +{ + if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) + { + m_SelfRef = nullptr; + } +} + +void +WsHttpSysConnection::Disconnect() +{ + // Send final empty body with DISCONNECT to tell http.sys the connection is done + HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT, + 0, + nullptr, + nullptr, + nullptr, + 
0, + nullptr, + nullptr); +} + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h new file mode 100644 index 000000000..ab0ca381a --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.h @@ -0,0 +1,104 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include "httpsys_iocontext.h" + +#include + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include +# include + +# include +# include +# include + +namespace zen { + +/** + * WebSocket connection over an http.sys opaque-mode connection + * + * After the 101 Switching Protocols response is sent with + * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the + * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody / + * HttpSendResponseEntityBody using the original RequestId. + * + * All I/O is performed asynchronously via the same IOCP threadpool used + * for normal http.sys traffic, eliminating per-connection threads. + * + * Lifetime is managed through intrusive reference counting (RefCounted). + * A self-reference (m_SelfRef) is held from Start() until all outstanding + * async operations have drained, preventing premature destruction. + */ +class WsHttpSysConnection : public WebSocketConnection +{ +public: + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp); + ~WsHttpSysConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and after the 101 response has been sent. + */ + void Start(); + + /** + * Shut down the connection. Cancels pending I/O; IOCP completions + * will fire with ERROR_OPERATION_ABORTED and drain naturally. 
+ */ + void Shutdown(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + + // Called from IoCompletionCallback via tagged dispatch + void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + +private: + void IssueAsyncRead(); + void ProcessReceivedData(); + void EnqueueWrite(std::vector Frame); + void FlushWriteQueue(); + void DoClose(uint16_t Code, std::string_view Reason); + void Disconnect(); + void MaybeReleaseSelfRef(); + + HANDLE m_RequestQueueHandle; + HTTP_REQUEST_ID m_RequestId; + IWebSocketHandler& m_Handler; + PTP_IO m_Iocp; + + // Tagged OVERLAPPED contexts for concurrent read and write + HttpSysIoContext m_ReadIoContext{}; + HttpSysIoContext m_WriteIoContext{}; + + // Read state + std::vector m_ReadBuffer; + std::vector m_Accumulated; + + // Write state + RwLock m_WriteLock; + std::deque> m_WriteQueue; + std::vector m_CurrentWriteBuffer; + HTTP_DATA_CHUNK m_WriteChunk{}; + bool m_IsWriting = false; + + // Lifetime management + std::atomic m_OutstandingOps{0}; + Ref m_SelfRef; + std::atomic m_ShutdownRequested{false}; + std::atomic m_IsOpen{true}; + bool m_CloseSent = false; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp new file mode 100644 index 000000000..95f8587df --- /dev/null +++ b/src/zenhttp/servers/wstest.cpp @@ -0,0 +1,922 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS + +# include +# include +# include + +# include +# include +# include + +# include "httpasio.h" +# include "wsframecodec.h" + +ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_PLATFORM_WINDOWS +# include +# else +# include +# include +# endif +# include +ZEN_THIRD_PARTY_INCLUDES_END + +# include +# include +# include +# include +# include +# include +# include +# include + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests: WsFrameCodec +// + +TEST_CASE("websocket.framecodec") +{ + SUBCASE("ComputeAcceptKey RFC 6455 test vector") + { + // RFC 6455 section 4.2.2 example + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - text") + { + std::string_view Text = "Hello, WebSocket!"; + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + // Server frames are unmasked — TryParseFrame should handle them + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - binary") + { + std::vector BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + 
SUBCASE("BuildFrame - medium payload (126-65535 bytes)") + { + std::vector Payload(300, 0x42); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildFrame - large payload (>65535 bytes)") + { + std::vector Payload(70000, 0xAB); + + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildCloseFrame roundtrip") + { + std::vector Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure"); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } + + SUBCASE("TryParseFrame - partial data returns invalid") + { + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span{}); + + // Pass only 1 byte — not enough for a frame header + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); + CHECK_FALSE(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, 0u); + } + + SUBCASE("TryParseFrame - empty payload") + { + std::vector Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span{}); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK(Result.Payload.empty()); + } + 
+ SUBCASE("TryParseFrame - masked client frame") + { + // Build a masked frame manually as a client would send + // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello" + uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D}; + uint8_t MaskedPayload[5] = {}; + const char* Original = "Hello"; + for (int i = 0; i < 5; ++i) + { + MaskedPayload[i] = static_cast(Original[i]) ^ MaskKey[i % 4]; + } + + std::vector Frame; + Frame.push_back(0x81); // FIN + text + Frame.push_back(0x85); // MASK + len=5 + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), 5u); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), 5), "Hello"sv); + } + + SUBCASE("BuildMaskedFrame roundtrip - text") + { + std::string_view Text = "Hello, masked WebSocket!"; + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + // Verify mask bit is set + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildMaskedFrame roundtrip - binary") + { + std::vector BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + 
CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)") + { + std::vector Payload(300, 0x42); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)") + { + std::vector Payload(70000, 0xAB); + + std::vector Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildMaskedCloseFrame roundtrip") + { + std::vector Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: WebSocket over ASIO +// + +namespace { + + /** + * Helper: Build a masked client-to-server frame per RFC 6455 + */ + std::vector BuildMaskedFrame(WebSocketOpcode Opcode, std::span Payload) + { + 
std::vector Frame; + + // FIN + opcode + Frame.push_back(0x80 | static_cast(Opcode)); + + // Payload length with mask bit set + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + // Mask key (use a fixed key for deterministic tests) + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + + std::vector BuildMaskedTextFrame(std::string_view Text) + { + std::span Payload(reinterpret_cast(Text.data()), Text.size()); + return BuildMaskedFrame(WebSocketOpcode::kText, Payload); + } + + std::vector BuildMaskedCloseFrame(uint16_t Code) + { + std::vector Payload; + Payload.push_back(static_cast((Code >> 8) & 0xFF)); + Payload.push_back(static_cast(Code & 0xFF)); + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); + } + + /** + * Test service that implements IWebSocketHandler + */ + struct WsTestService : public HttpService, public IWebSocketHandler + { + const char* BaseUri() const override { return "/wstest/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest"); + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref Connection) override + { + m_OpenCount.fetch_add(1); + + m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); + } + + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if 
(Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + + // Echo the message back + Conn.SendText(Text); + } + } + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + + m_ConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref& C) { + return C.Get() == &Conn; + }); + m_Connections.erase(It, m_Connections.end()); + }); + } + + void SendToAll(std::string_view Text) + { + RwLock::SharedLockScope _(m_ConnectionsLock); + for (auto& Conn : m_Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Text); + } + } + } + + std::atomic m_OpenCount{0}; + std::atomic m_MessageCount{0}; + std::atomic m_CloseCount{0}; + std::atomic m_LastCloseCode{0}; + std::string m_LastMessage; + + RwLock m_ConnectionsLock; + std::vector> m_Connections; + }; + + /** + * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket + * + * Returns true on success (101 response), false otherwise. 
+ */ + bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port) + { + // Send HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << Path << " HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + // Read the response (look for "101") + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + return Response.find("101") != std::string::npos; + } + + /** + * Helper: Read a single server-to-client frame from a socket + * + * Uses a background thread with a synchronous ASIO read and a timeout. + */ + WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000) + { + std::vector Buffer; + WsFrameParseResult Result; + std::atomic Done{false}; + + std::thread Reader([&] { + while (!Done.load()) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (Frame.IsValid) + { + Result = std::move(Frame); + Done.store(true); + return; + } + } + }); + + auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs); + while (!Done.load() && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + if (!Done.load()) + { + // Timeout — cancel the read + asio::error_code Ec; + Sock.cancel(Ec); + } + + if (Reader.joinable()) + { + Reader.join(); + } + + return Result; + } + +} // anonymous namespace + 
+TEST_CASE("websocket.integration") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + // Give server a moment to start accepting + Sleep(100); + + SUBCASE("handshake succeeds with 101") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + CHECK(Ok); + + Sleep(50); + CHECK_EQ(TestService.m_OpenCount.load(), 1); + + Sock.close(); + } + + SUBCASE("normal HTTP still works alongside WebSocket service") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + // Send a normal HTTP GET (not upgrade) + std::string HttpReq = fmt::format( + "GET /wstest/hello HTTP/1.1\r\n" + "Host: 127.0.0.1:{}\r\n" + "Connection: close\r\n" + "\r\n", + Port); + + asio::write(Sock, asio::buffer(HttpReq)); + + asio::streambuf ResponseBuf; + asio::error_code Ec; + asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + CHECK(Response.find("200") != std::string::npos); + } + + SUBCASE("echo message roundtrip") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a text 
message (masked, as client) + std::vector Frame = BuildMaskedTextFrame("ping test"); + asio::write(Sock, asio::buffer(Frame)); + + // Read the echo reply + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "ping test"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "ping test"); + + Sock.close(); + } + + SUBCASE("server push to client") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Server pushes a message + TestService.SendToAll("server says hello"); + + WsFrameParseResult Frame = ReadOneFrame(Sock); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "server says hello"sv); + + Sock.close(); + } + + SUBCASE("client close handshake") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send close frame + std::vector CloseFrame = BuildMaskedCloseFrame(1000); + asio::write(Sock, asio::buffer(CloseFrame)); + + // Server should echo close back + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + + Sleep(50); + CHECK_EQ(TestService.m_CloseCount.load(), 1); + CHECK_EQ(TestService.m_LastCloseCode.load(), 1000); + + Sock.close(); + } + + SUBCASE("multiple concurrent connections") + { + constexpr int NumClients = 5; + 
+ asio::io_context IoCtx; + std::vector Sockets; + + for (int i = 0; i < NumClients; ++i) + { + Sockets.emplace_back(IoCtx); + Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port); + REQUIRE(Ok); + } + + Sleep(100); + CHECK_EQ(TestService.m_OpenCount.load(), NumClients); + + // Broadcast from server + TestService.SendToAll("broadcast"); + + // Each client should receive the message + for (int i = 0; i < NumClients; ++i) + { + WsFrameParseResult Frame = ReadOneFrame(Sockets[i]); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "broadcast"sv); + } + + // Close all + for (auto& S : Sockets) + { + S.close(); + } + } + + SUBCASE("service without IWebSocketHandler rejects upgrade") + { + // Register a plain HTTP service (no WebSocket) + struct PlainService : public HttpService + { + const char* BaseUri() const override { return "/plain/"; } + void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); } + }; + + PlainService Plain; + Server->RegisterService(Plain); + + Sleep(50); + + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + // Attempt WebSocket upgrade on the plain service + ExtendableStringBuilder<512> Request; + Request << "GET /plain/ws HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, 
"\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + // Should NOT get 101 — should fall through to normal request handling + CHECK(Response.find("101") == std::string::npos); + + Sock.close(); + } + + SUBCASE("ping/pong auto-response") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a ping frame with payload "test" + std::string_view PingPayload = "test"; + std::span PingData(reinterpret_cast(PingPayload.data()), PingPayload.size()); + std::vector PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData); + asio::write(Sock, asio::buffer(PingFrame)); + + // Should receive a pong with the same payload + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong); + CHECK_EQ(Reply.Payload.size(), 4u); + std::string_view PongText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(PongText, "test"sv); + + Sock.close(); + } + + SUBCASE("multiple messages in sequence") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + for (int i = 0; i < 10; ++i) + { + std::string Msg = fmt::format("message {}", i); + std::vector Frame = BuildMaskedTextFrame(Msg); + asio::write(Sock, asio::buffer(Frame)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, Msg); + } + + 
CHECK_EQ(TestService.m_MessageCount.load(), 10); + + Sock.close(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: HttpWsClient +// + +namespace { + + struct TestWsClientHandler : public IWsClientHandler + { + void OnWsOpen() override { m_OpenCount.fetch_add(1); } + + void OnWsMessage(const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + } + } + + void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + } + + std::atomic m_OpenCount{0}; + std::atomic m_MessageCount{0}; + std::atomic m_CloseCount{0}; + std::atomic m_LastCloseCode{0}; + std::string m_LastMessage; + }; + +} // anonymous namespace + +TEST_CASE("websocket.client") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7576, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello from client"); + + Deadline = 
std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello from client"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + // The server echoes the close frame, which triggers OnWsClose on the client side + // with the server's close code. Allow the connection to settle. + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } + + SUBCASE("connect to bad port") + { + TestWsClientHandler Handler; + std::string Url = "ws://127.0.0.1:1/wstest/ws"; + + HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1006); + CHECK_EQ(Handler.m_OpenCount.load(), 0); + } + + SUBCASE("server-initiated close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + + // Copy connections then close them outside the lock to avoid deadlocking + // with OnWebSocketClose which acquires an exclusive lock + std::vector> Conns; + TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; }); + for (auto& Conn : Conns) + { + Conn->Close(1001, "going away"); + } + + Deadline = 
std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1001); + CHECK_FALSE(Client.IsOpen()); + } +} + +void +websocket_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_WITH_TESTS diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 78876d21b..e8f87b668 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -6,6 +6,7 @@ target('zenhttp') add_headerfiles("**.h") add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) + add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") add_packages("http_parser", "json11") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index ad14ecb8d..3ac8eea8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -19,6 +19,7 @@ zenhttp_forcelinktests() httpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); + websocket_forcelink(); } } // namespace zen -- cgit v1.2.3 From 1ed3139e577f6c8aa6d07f7e76afa3a80d9d4852 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 27 Feb 2026 19:36:22 +0100 Subject: Add test summary table and failure reporting to xmake test (#794) - Add a summary table printed after all test suites complete, showing per-suite test case counts, assertion counts, timings and pass/fail status. - Add failure reporting: individual failing test cases are listed at the end with their file path and line number for easy navigation. - Made zenserver instances spawned by a hub not create new console windows for a better background testing experience - The TestListener in testing.cpp now writes a machine-readable summary file (via `ZEN_TEST_SUMMARY_FILE` env var) containing aggregate counts and per-test-case failure details. 
This runs as a doctest listener alongside any active reporter, so it works with both console and JUnit modes. - Tests now run in a deterministic order defined by a single ordered list that also serves as the test name/target mapping, replacing the previous unordered table + separate order list. - The `--run` option now accepts comma-separated values (e.g. `--run=core,http,util`) and validates each name, reporting unknown test names early. - Fix platform detection in `xmake test`: the config command now passes `-p` explicitly, fixing "mingw" misdetection when running from Git Bash on Windows. - Add missing "util" entry to the help text for `--run`. --- src/zencore/testing.cpp | 56 ++++++++++++++++++++++++-- src/zenutil/include/zenutil/zenserverprocess.h | 1 + src/zenutil/zenserverprocess.cpp | 3 +- 3 files changed, 55 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 936424e0f..ef8fb0480 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -5,6 +5,12 @@ #if ZEN_WITH_TESTS +# include +# include +# include +# include +# include + # include namespace zen::testing { @@ -21,9 +27,35 @@ struct TestListener : public doctest::IReporter void report_query(const doctest::QueryData& /*in*/) override {} - void test_run_start() override {} + void test_run_start() override { RunStart = std::chrono::steady_clock::now(); } + + void test_run_end(const doctest::TestRunStats& in) override + { + auto elapsed = std::chrono::steady_clock::now() - RunStart; + double elapsedSeconds = std::chrono::duration_cast(elapsed).count() / 1000.0; - void test_run_end(const doctest::TestRunStats& /*in*/) override {} + // Write machine-readable summary to file if requested (used by xmake test summary table) + const char* summaryFile = std::getenv("ZEN_TEST_SUMMARY_FILE"); + if (summaryFile && summaryFile[0] != '\0') + { + if (FILE* f = std::fopen(summaryFile, "w")) + { + std::fprintf(f, + 
"cases_total=%u\ncases_passed=%u\nassertions_total=%d\nassertions_passed=%d\n" + "elapsed_seconds=%.3f\n", + in.numTestCasesPassingFilters, + in.numTestCasesPassingFilters - in.numTestCasesFailed, + in.numAsserts, + in.numAsserts - in.numAssertsFailed, + elapsedSeconds); + for (const auto& failure : FailedTests) + { + std::fprintf(f, "failed=%s|%s|%u\n", failure.Name.c_str(), failure.File.c_str(), failure.Line); + } + std::fclose(f); + } + } + } void test_case_start(const doctest::TestCaseData& in) override { @@ -37,7 +69,14 @@ struct TestListener : public doctest::IReporter ZEN_CONSOLE("{}-------------------------------------------------------------------------------{}", ColorYellow, ColorNone); } - void test_case_end(const doctest::CurrentTestCaseStats& /*in*/) override { Current = nullptr; } + void test_case_end(const doctest::CurrentTestCaseStats& in) override + { + if (!in.testCaseSuccess && Current) + { + FailedTests.push_back({Current->m_name, Current->m_file.c_str(), Current->m_line}); + } + Current = nullptr; + } void test_case_exception(const doctest::TestCaseException& /*in*/) override {} @@ -57,7 +96,16 @@ struct TestListener : public doctest::IReporter void test_case_skipped(const doctest::TestCaseData& /*in*/) override {} - const doctest::TestCaseData* Current = nullptr; + const doctest::TestCaseData* Current = nullptr; + std::chrono::steady_clock::time_point RunStart = {}; + + struct FailedTestInfo + { + std::string Name; + std::string File; + unsigned Line; + }; + std::vector FailedTests; }; struct TestRunner::Impl diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index d0402640b..b781a03a9 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -42,6 +42,7 @@ public: std::filesystem::path GetTestRootDir(std::string_view Path); inline bool IsInitialized() const { return m_IsInitialized; } inline bool IsTestEnvironment() const { 
return m_IsTestInstance; } + inline bool IsHubEnvironment() const { return m_IsHubInstance; } inline std::string_view GetServerClass() const { return m_ServerClass; } inline uint16_t GetNewPortNumber() { return m_NextPortNumber.fetch_add(1); } diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index ef2a4fda5..579ba450a 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -934,7 +934,8 @@ ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerAr CommandLine << " " << AdditionalServerArgs; } - SpawnServerInternal(ChildId, CommandLine, !IsTest, WaitTimeoutMs); + const bool OpenConsole = !IsTest && !m_Env.IsHubEnvironment(); + SpawnServerInternal(ChildId, CommandLine, OpenConsole, WaitTimeoutMs); } void -- cgit v1.2.3 From c32b6042dee8444f4e214f227005a657ec87531e Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 27 Feb 2026 21:22:00 +0100 Subject: add multirange requests to blob store (#795) * add multirange requests to blob store --- src/zenhttp/packageformat.cpp | 2 +- src/zenserver-test/buildstore-tests.cpp | 200 ++++++++++++++++++++- .../storage/buildstore/httpbuildstore.cpp | 114 ++++++++++-- 3 files changed, 295 insertions(+), 21 deletions(-) (limited to 'src') diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 708238224..9a80d07c8 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -581,7 +581,7 @@ ParsePackageMessage(IoBuffer Payload, std::function CompressedBlobsHashes; + std::vector CompressedBlobsHashes; + std::vector CompressedBlobsSizes; { ZenServerInstance Instance(TestEnv); @@ -51,6 +52,7 @@ TEST_CASE("buildstore.blobs") IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7); CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); 
IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); Payload.SetContentType(ZenContentType::kCompressedBinary); @@ -107,6 +109,7 @@ TEST_CASE("buildstore.blobs") IoBuffer Blob = CreateSemiRandomBlob(5713 + I * 7); CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); Payload.SetContentType(ZenContentType::kCompressedBinary); @@ -141,6 +144,201 @@ TEST_CASE("buildstore.blobs") CHECK(IoHash::HashBuffer(Decompressed) == RawHash); } } + + { + // Single-range Get + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}}; + + uint64_t RangeSizeSum = Ranges.front().second; + + HttpClient::KeyValueMap Headers; + + Headers.Entries.insert( + {"Range", fmt::format("bytes={}-{}", Ranges.front().first, Ranges.front().first + Ranges.front().second - 1)}); + + HttpClient::Response Result = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), Headers); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + CHECK_EQ(RangeSizeSum, Payload.GetSize()); + + HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Ranges.front().first, Ranges.front().second); + 
MemoryView RangeView = Payload.GetView(); + CHECK(ActualRange.EqualBytes(RangeView)); + } + } + + { + // Single-range Post + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + uint64_t RangeSizeSum = 0; + + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector> Ranges = {{BlobSize / 16 * 1, BlobSize / 2}}; + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair& Range : Ranges) + { + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + RangeSizeSum += Range.second; + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage); + + CbPackage ResponsePackage = ParsePackageMessage(Payload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + CHECK_EQ(RangeArray.Num(), Ranges.size()); + size_t RangeOffset = 0; + for (CbFieldView View : RangeArray) + { + CbObjectView Range = View.AsObjectView(); + CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first); + CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second); + RangeOffset++; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + REQUIRE(DataAttachment); + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize()); + + HttpClient::Response 
FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + + uint64_t Offset = 0; + for (const std::pair& Range : Ranges) + { + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second); + MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second); + CHECK(ActualRange.EqualBytes(RangeView)); + Offset += Range.second; + } + } + } + + { + // Multi-range + + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = + Instance.SpawnServerAndWaitUntilReady(fmt::format("--buildstore-enabled --system-dir {}", SystemRootPath)); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/builds/"); + + { + uint64_t RangeSizeSum = 0; + + const IoHash& RawHash = CompressedBlobsHashes.front(); + uint64_t BlobSize = CompressedBlobsSizes.front(); + + std::vector> Ranges = { + {BlobSize / 16 * 1, BlobSize / 20}, + {BlobSize / 16 * 3, BlobSize / 32}, + {BlobSize / 16 * 5, BlobSize / 16}, + {BlobSize - BlobSize / 16, BlobSize / 16 - 1}, + }; + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair& Range : Ranges) + { + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + RangeSizeSum += Range.second; + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + REQUIRE(Result); + IoBuffer Payload = Result.ResponsePayload; + REQUIRE(Payload.GetContentType() == ZenContentType::kCbPackage); + + CbPackage ResponsePackage = ParsePackageMessage(Payload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + 
CHECK_EQ(RangeArray.Num(), Ranges.size()); + size_t RangeOffset = 0; + for (CbFieldView View : RangeArray) + { + CbObjectView Range = View.AsObjectView(); + CHECK_EQ(Range["offset"sv].AsUInt64(), Ranges[RangeOffset].first); + CHECK_EQ(Range["length"sv].AsUInt64(), Ranges[RangeOffset].second); + RangeOffset++; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + REQUIRE(DataAttachment); + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + CHECK_EQ(RangeSizeSum, PayloadRanges.GetSize()); + + HttpClient::Response FullBlobResult = Client.Get(fmt::format("{}/{}/{}/blobs/{}", Namespace, Bucket, BuildId, RawHash), + HttpClient::Accept(ZenContentType::kCompressedBinary)); + REQUIRE(FullBlobResult); + + uint64_t Offset = 0; + for (const std::pair& Range : Ranges) + { + MemoryView ActualRange = FullBlobResult.ResponsePayload.GetView().Mid(Range.first, Range.second); + MemoryView RangeView = PayloadRanges.GetView().Mid(Offset, Range.second); + CHECK(ActualRange.EqualBytes(RangeView)); + Offset += Range.second; + } + } + } } namespace { diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index bf7afcc02..6ada085a5 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -71,7 +71,7 @@ HttpBuildStoreService::Initialize() m_Router.RegisterRoute( "{namespace}/{bucket}/{buildid}/blobs/{hash}", [this](HttpRouterRequest& Req) { GetBlobRequest(Req); }, - HttpVerb::kGet); + HttpVerb::kGet | HttpVerb::kPost); m_Router.RegisterRoute( "{namespace}/{bucket}/{buildid}/blobs/putBlobMetadata", @@ -161,14 +161,49 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) HttpContentType::kText, fmt::format("Invalid blob hash '{}'", Hash)); } - zen::HttpRanges Ranges; - bool HasRange = ServerRequest.TryGetRanges(Ranges); - if (Ranges.size() > 1) + + std::vector> OffsetAndLengthPairs; + if 
(ServerRequest.RequestVerb() == HttpVerb::kPost) { - // Only a single range is supported - return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, - HttpContentType::kText, - "Multiple ranges in blob request is not supported"); + CbObject RangePayload = ServerRequest.ReadPayloadObject(); + if (RangePayload) + { + CbArrayView RangesArray = RangePayload["ranges"sv].AsArrayView(); + OffsetAndLengthPairs.reserve(RangesArray.Num()); + for (CbFieldView FieldView : RangesArray) + { + CbObjectView RangeView = FieldView.AsObjectView(); + uint64_t RangeOffset = RangeView["offset"sv].AsUInt64(); + uint64_t RangeLength = RangeView["length"sv].AsUInt64(); + OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength)); + } + } + if (OffsetAndLengthPairs.empty()) + { + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Fetching blob without ranges must be done with the GET verb"); + } + } + else + { + HttpRanges Ranges; + bool HasRange = ServerRequest.TryGetRanges(Ranges); + if (HasRange) + { + if (Ranges.size() > 1) + { + // Only a single http range is supported, we have limited support for http multirange responses + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Multiple ranges in blob request is only supported for {} accept type", + ToString(HttpContentType::kCbPackage))); + } + const HttpRange& FirstRange = Ranges.front(); + OffsetAndLengthPairs.push_back(std::make_pair(FirstRange.Start, FirstRange.End - FirstRange.Start + 1)); + } } m_BuildStoreStats.BlobReadCount++; @@ -179,24 +214,65 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) HttpContentType::kText, fmt::format("Blob with hash '{}' could not be found", Hash)); } - // ZEN_INFO("Fetched blob {}. 
Size: {}", BlobHash, Blob.GetSize()); m_BuildStoreStats.BlobHitCount++; - if (HasRange) + + if (OffsetAndLengthPairs.empty()) { - const HttpRange& Range = Ranges.front(); - const uint64_t BlobSize = Blob.GetSize(); - const uint64_t MaxBlobSize = Range.Start < BlobSize ? BlobSize - Range.Start : 0; - const uint64_t RangeSize = Min(Range.End - Range.Start + 1, MaxBlobSize); - if (Range.Start + RangeSize > BlobSize) + return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); + } + + if (ServerRequest.AcceptContentType() == HttpContentType::kCbPackage) + { + const uint64_t BlobSize = Blob.GetSize(); + + CbPackage ResponsePackage; + std::vector RangeBuffers; + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + for (const std::pair& Range : OffsetAndLengthPairs) { - return ServerRequest.WriteResponse(HttpResponseCode::NoContent); + const uint64_t MaxBlobSize = Range.first < BlobSize ? BlobSize - Range.first : 0; + const uint64_t RangeSize = Min(Range.second, MaxBlobSize); + if (Range.first + RangeSize <= BlobSize) + { + RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize)); + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, RangeSize); + } + Writer.EndObject(); + } } - Blob = IoBuffer(Blob, Range.Start, RangeSize); - return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob); + Writer.EndArray(); + + CompositeBuffer Ranges(RangeBuffers); + CbAttachment PayloadAttachment(std::move(Ranges), BlobHash); + Writer.AddAttachment("payload", PayloadAttachment); + + CbObject HeaderObject = Writer.Save(); + + ResponsePackage.AddAttachment(PayloadAttachment); + ResponsePackage.SetObject(HeaderObject); + + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage); + uint64_t ResponseSize = RpcResponseBuffer.GetSize(); + ZEN_UNUSED(ResponseSize); + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, 
RpcResponseBuffer); } else { - return ServerRequest.WriteResponse(HttpResponseCode::OK, Blob.GetContentType(), Blob); + ZEN_ASSERT(OffsetAndLengthPairs.size() == 1); + const std::pair& OffsetAndLength = OffsetAndLengthPairs.front(); + const uint64_t BlobSize = Blob.GetSize(); + const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? BlobSize - OffsetAndLength.first : 0; + const uint64_t RangeSize = Min(OffsetAndLength.second, MaxBlobSize); + if (OffsetAndLength.first + RangeSize > BlobSize) + { + return ServerRequest.WriteResponse(HttpResponseCode::NoContent); + } + Blob = IoBuffer(Blob, OffsetAndLength.first, RangeSize); + return ServerRequest.WriteResponse(HttpResponseCode::OK, ZenContentType::kBinary, Blob); } } -- cgit v1.2.3 From c7e0efb9c12f4607d4bc6a844a3e5bd3272bd839 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Sat, 28 Feb 2026 15:36:13 +0100 Subject: test running / reporting improvements (#797) **CI/CD improvements (validate.yml):** - Add test reporter (`ue-foundation/test-reporter@v2`) for all three platforms, rendering JUnit test results directly in PR check runs - Add "Trust workspace" step on Windows to fix git safe.directory ownership issue with self-hosted runners - Clean stale report files before each test run to prevent false failures from leftover XML - Broaden `paths-ignore` to skip builds for non-code changes (`*.md`, `LICENSE`, `.gitignore`, `docs/**`) **Test improvements:** - Convert `CHECK` to `REQUIRE` in several test suites (projectstore, integration, http) for fail-fast behavior - Mark some tests with `doctest::skip()` for selective execution - Skip httpclient transport tests pending investigation - Add `--noskip` option to `xmake test` task - Add `--repeat=` option to `xmake test` task, to run tests repeatedly N times or until there is a failure **xmake test output improvements:** - Add totals row to test summary table - Right-justify numeric columns in summary table --- src/zenhttp/httpclient_test.cpp | 4 ++-- 
src/zenserver-test/buildstore-tests.cpp | 16 +++++++------- src/zenserver-test/cache-tests.cpp | 10 ++------- src/zenserver-test/hub-tests.cpp | 2 +- src/zenserver-test/projectstore-tests.cpp | 34 ++++++++++++++--------------- src/zenserver-test/workspace-tests.cpp | 4 ++-- src/zenstore/cache/structuredcachestore.cpp | 2 +- 7 files changed, 33 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 509b56371..91b1a3414 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -1079,7 +1079,7 @@ struct FaultTcpServer } }; -TEST_CASE("httpclient.transport-faults") +TEST_CASE("httpclient.transport-faults" * doctest::skip()) { SUBCASE("connection reset before response") { @@ -1217,7 +1217,7 @@ TEST_CASE("httpclient.transport-faults") } } -TEST_CASE("httpclient.transport-faults-post") +TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) { constexpr size_t kPostBodySize = 256 * 1024; diff --git a/src/zenserver-test/buildstore-tests.cpp b/src/zenserver-test/buildstore-tests.cpp index ef48b2362..7cd31db06 100644 --- a/src/zenserver-test/buildstore-tests.cpp +++ b/src/zenserver-test/buildstore-tests.cpp @@ -389,7 +389,7 @@ TEST_CASE("buildstore.metadata") HttpClient::Response Result = Client.Post(fmt::format("{}/{}/{}/blobs/getBlobMetadata", Namespace, Bucket, BuildId), Payload, HttpClient::Accept(ZenContentType::kCbObject)); - CHECK(Result); + REQUIRE(Result); std::vector ResultMetadatas; @@ -570,7 +570,7 @@ TEST_CASE("buildstore.cache") { std::vector Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount; I++) { CHECK(Exists[I].HasBody); @@ -609,7 +609,7 @@ TEST_CASE("buildstore.cache") { std::vector Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == 
BlobHashes.size()); for (size_t I = 0; I < BlobCount; I++) { CHECK(Exists[I].HasBody); @@ -617,7 +617,7 @@ TEST_CASE("buildstore.cache") } std::vector FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { @@ -638,7 +638,7 @@ TEST_CASE("buildstore.cache") { std::vector Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount * 2; I++) { CHECK(Exists[I].HasBody); @@ -649,7 +649,7 @@ TEST_CASE("buildstore.cache") CHECK_EQ(BlobCount, MetaDatas.size()); std::vector FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { @@ -672,7 +672,7 @@ TEST_CASE("buildstore.cache") CreateZenBuildStorageCache(Client, Stats, Namespace, Bucket, TempDir, GetTinyWorkerPool(EWorkloadType::Background))); std::vector Exists = Cache->BlobsExists(BuildId, BlobHashes); - CHECK(Exists.size() == BlobHashes.size()); + REQUIRE(Exists.size() == BlobHashes.size()); for (size_t I = 0; I < BlobCount * 2; I++) { CHECK(Exists[I].HasBody); @@ -691,7 +691,7 @@ TEST_CASE("buildstore.cache") CHECK_EQ(BlobCount, MetaDatas.size()); std::vector FetchedMetadatas = Cache->GetBlobMetadatas(BuildId, BlobHashes); - CHECK_EQ(BlobCount, FetchedMetadatas.size()); + REQUIRE_EQ(BlobCount, FetchedMetadatas.size()); for (size_t I = 0; I < BlobCount; I++) { diff --git a/src/zenserver-test/cache-tests.cpp b/src/zenserver-test/cache-tests.cpp index 0272d3797..745a89253 100644 --- a/src/zenserver-test/cache-tests.cpp +++ b/src/zenserver-test/cache-tests.cpp @@ -145,7 +145,7 @@ TEST_CASE("zcache.cbpackage") for (const zen::CbAttachment& LhsAttachment : LhsAttachments) { const zen::CbAttachment* RhsAttachment = 
Rhs.FindAttachment(LhsAttachment.GetHash()); - CHECK(RhsAttachment); + REQUIRE(RhsAttachment); zen::SharedBuffer LhsBuffer = LhsAttachment.AsCompressedBinary().Decompress(); CHECK(!LhsBuffer.IsNull()); @@ -1373,14 +1373,8 @@ TEST_CASE("zcache.rpc") } } -TEST_CASE("zcache.failing.upstream") +TEST_CASE("zcache.failing.upstream" * doctest::skip()) { - // This is an exploratory test that takes a long time to run, so lets skip it by default - if (true) - { - return; - } - using namespace std::literals; using namespace utils; diff --git a/src/zenserver-test/hub-tests.cpp b/src/zenserver-test/hub-tests.cpp index 42a5dcae4..bd85a5020 100644 --- a/src/zenserver-test/hub-tests.cpp +++ b/src/zenserver-test/hub-tests.cpp @@ -232,7 +232,7 @@ TEST_CASE("hub.lifecycle.children") TEST_SUITE_END(); -TEST_CASE("hub.consul.lifecycle") +TEST_CASE("hub.consul.lifecycle" * doctest::skip()) { zen::consul::ConsulProcess ConsulProc; ConsulProc.SpawnConsulAgent(); diff --git a/src/zenserver-test/projectstore-tests.cpp b/src/zenserver-test/projectstore-tests.cpp index 735aef159..487832405 100644 --- a/src/zenserver-test/projectstore-tests.cpp +++ b/src/zenserver-test/projectstore-tests.cpp @@ -71,7 +71,7 @@ TEST_CASE("project.basic") { auto Response = Http.Get("/prj/test"sv); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); CbObject ResponseObject = Response.AsObject(); @@ -92,7 +92,7 @@ TEST_CASE("project.basic") { auto Response = Http.Get(""sv); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); CbObject ResponseObject = Response.AsObject(); @@ -213,7 +213,7 @@ TEST_CASE("project.basic") auto Response = Http.Get(ChunkGetUri); REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); IoBuffer Data = Response.ResponsePayload; IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); @@ 
-235,13 +235,13 @@ TEST_CASE("project.basic") auto Response = Http.Get(ChunkGetUri, {{"Accept-Type", "application/x-ue-comp"}}); REQUIRE(Response); - CHECK(Response.StatusCode == HttpResponseCode::OK); + REQUIRE(Response.StatusCode == HttpResponseCode::OK); IoBuffer Data = Response.ResponsePayload; IoHash RawHash; uint64_t RawSize; CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Data), RawHash, RawSize); - CHECK(Compressed); + REQUIRE(Compressed); IoBuffer DataDecompressed = Compressed.Decompress().AsIoBuffer(); IoBuffer ReferenceData = IoBufferBuilder::MakeFromFile(RootPath / BinPath); CHECK(RawSize == ReferenceData.GetSize()); @@ -436,13 +436,13 @@ TEST_CASE("project.remote") HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}", ProjectName), ProjectPayload); - CHECK(Response); + REQUIRE(Response); }; auto MakeOplog = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName) { HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName), IoBuffer{}); - CHECK(Response); + REQUIRE(Response); }; auto MakeOp = [](std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName, const CbPackage& OpPackage) { @@ -453,7 +453,7 @@ TEST_CASE("project.remote") HttpClient Http{UrlBase}; HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/new", ProjectName, OplogName), Body); - CHECK(Response); + REQUIRE(Response); }; MakeProject(Servers.GetInstance(0).GetBaseUri(), "proj0"); @@ -504,7 +504,7 @@ TEST_CASE("project.remote") HttpClient::Response Response = Http.Post(fmt::format("/prj/{}/oplog/{}/rpc", Project, Oplog), Payload, {{"Accept", "application/x-ue-cbpkg"}}); - CHECK(Response); + REQUIRE(Response); CbPackage ResponsePackage = ParsePackageMessage(Response.ResponsePayload); CHECK(ResponsePackage.GetAttachments().size() == AttachmentHashes.size()); for (auto A : 
ResponsePackage.GetAttachments()) @@ -519,7 +519,7 @@ TEST_CASE("project.remote") HttpClient Http{Servers.GetInstance(ServerIndex).GetBaseUri()}; HttpClient::Response Response = Http.Get(fmt::format("/prj/{}/oplog/{}/entries", Project, Oplog)); - CHECK(Response); + REQUIRE(Response); IoBuffer Payload(Response.ResponsePayload); CbObject OplogResonse = LoadCompactBinaryObject(Payload); @@ -541,7 +541,7 @@ TEST_CASE("project.remote") auto HttpWaitForCompletion = [](ZenServerInstance& Server, const HttpClient::Response& Response) { REQUIRE(Response); const uint64_t JobId = ParseInt(Response.AsText()).value_or(0); - CHECK(JobId != 0); + REQUIRE(JobId != 0); HttpClient Http{Server.GetBaseUri()}; @@ -549,10 +549,10 @@ TEST_CASE("project.remote") { HttpClient::Response StatusResponse = Http.Get(fmt::format("/admin/jobs/{}", JobId), {{"Accept", ToString(ZenContentType::kCbObject)}}); - CHECK(StatusResponse); + REQUIRE(StatusResponse); CbObject ResponseObject = StatusResponse.AsObject(); std::string_view Status = ResponseObject["Status"sv].AsString(); - CHECK(Status != "Aborted"sv); + REQUIRE(Status != "Aborted"sv); if (Status == "Complete"sv) { return; @@ -887,16 +887,16 @@ TEST_CASE("project.rpcappendop") Project.AddString("project"sv, ""sv); Project.AddString("projectfile"sv, ""sv); HttpClient::Response Response = Client.Post(fmt::format("/prj/{}", ProjectName), Project.Save()); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); }; auto MakeOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) { HttpClient::Response Response = Client.Post(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName)); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); }; auto GetOplog = [](HttpClient& Client, std::string_view ProjectName, std::string_view OplogName) { HttpClient::Response Response 
= Client.Get(fmt::format("/prj/{}/oplog/{}", ProjectName, OplogName)); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); return Response.AsObject(); }; @@ -910,7 +910,7 @@ TEST_CASE("project.rpcappendop") } Request.EndArray(); // "ops" HttpClient::Response Response = Client.Post(fmt::format("/prj/{}/oplog/{}/rpc", ProjectName, OplogName), Request.Save()); - CHECK_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); + REQUIRE_MESSAGE(Response.IsSuccess(), Response.ErrorMessage("")); CbObjectView ResponsePayload = Response.AsPackage().GetObject(); CbArrayView NeedArray = ResponsePayload["need"sv].AsArrayView(); diff --git a/src/zenserver-test/workspace-tests.cpp b/src/zenserver-test/workspace-tests.cpp index 7595d790a..aedadf0c3 100644 --- a/src/zenserver-test/workspace-tests.cpp +++ b/src/zenserver-test/workspace-tests.cpp @@ -514,9 +514,9 @@ TEST_CASE("workspaces.share") } IoBuffer BatchResponse = Client.Post(fmt::format("/ws/{}/{}/batch", WorkspaceId, ShareId), BuildChunkBatchRequest(BatchEntries)).ResponsePayload; - CHECK(BatchResponse); + REQUIRE(BatchResponse); std::vector BatchResult = ParseChunkBatchResponse(BatchResponse); - CHECK(BatchResult.size() == Files.size()); + REQUIRE(BatchResult.size() == Files.size()); for (const RequestChunkEntry& Request : BatchEntries) { IoBuffer Result = BatchResult[Request.CorrelationId]; diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 4e8475293..d8a5755c5 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -1551,7 +1551,7 @@ TEST_CASE("cachestore.size") } } -TEST_CASE("cachestore.threadedinsert") // * doctest::skip(true)) +TEST_CASE("cachestore.threadedinsert" * doctest::skip()) { // for (uint32_t i = 0; i < 100; ++i) { -- cgit v1.2.3 From f796ee9e650d5f73844f862ed51a6de6bb33c219 Mon Sep 17 00:00:00 2001 From: Stefan 
Boberg Date: Sat, 28 Feb 2026 15:36:50 +0100 Subject: subprocess tracking using Jobs on Windows/hub (#796) This change introduces job object support on Windows to be able to more accurately track and limit resource usage on storage instances created by the hub service. It also ensures that all child instances can be torn down reliably on exit. Also made it so hub tests no longer pop up console windows while running. --- src/zencore/include/zencore/process.h | 33 ++++++++++ src/zencore/process.cpp | 89 ++++++++++++++++++++++++++ src/zenserver/hub/hubservice.cpp | 49 +++++++++++++- src/zenserver/hub/hubservice.h | 7 ++ src/zenutil/include/zenutil/zenserverprocess.h | 12 +++- src/zenutil/zenserverprocess.cpp | 17 ++++- 6 files changed, 199 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index c51163a68..809312c7b 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -9,6 +9,10 @@ namespace zen { +#if ZEN_PLATFORM_WINDOWS +class JobObject; +#endif + /** Basic process abstraction */ class ProcessHandle @@ -46,6 +50,7 @@ private: /** Basic process creation */ + struct CreateProcOptions { enum @@ -63,6 +68,9 @@ struct CreateProcOptions const std::filesystem::path* WorkingDirectory = nullptr; uint32_t Flags = 0; std::filesystem::path StdoutFile; +#if ZEN_PLATFORM_WINDOWS + JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed +#endif }; #if ZEN_PLATFORM_WINDOWS @@ -99,6 +107,31 @@ private: std::vector m_ProcessHandles; }; +#if ZEN_PLATFORM_WINDOWS +/** Windows Job Object wrapper + * + * When configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, the OS will + * terminate all assigned child processes when the job handle is closed + * (including abnormal termination of the owning process). This provides + * an OS-level guarantee against orphaned child processes. 
+ */ +class JobObject +{ +public: + JobObject(); + ~JobObject(); + JobObject(const JobObject&) = delete; + JobObject& operator=(const JobObject&) = delete; + + void Initialize(); + bool AssignProcess(void* ProcessHandle); + [[nodiscard]] bool IsValid() const; + +private: + void* m_JobHandle = nullptr; +}; +#endif // ZEN_PLATFORM_WINDOWS + bool IsProcessRunning(int pid); bool IsProcessRunning(int pid, std::error_code& OutEc); int GetCurrentProcessId(); diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 4a2668912..226a94050 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -490,6 +490,8 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; + const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid(); + DWORD CreationFlags = 0; if (Options.Flags & CreateProcOptions::Flag_NewConsole) { @@ -503,6 +505,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma { CreationFlags |= CREATE_NEW_PROCESS_GROUP; } + if (AssignToJob) + { + CreationFlags |= CREATE_SUSPENDED; + } const wchar_t* WorkingDir = nullptr; if (Options.WorkingDirectory != nullptr) @@ -571,6 +577,15 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma return nullptr; } + if (AssignToJob) + { + if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess)) + { + ZEN_WARN("Failed to assign newly created process to job object"); + } + ResumeThread(ProcessInfo.hThread); + } + CloseHandle(ProcessInfo.hThread); return ProcessInfo.hProcess; } @@ -644,6 +659,8 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C }; PROCESS_INFORMATION ProcessInfo = {}; + const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid(); + if (Options.Flags & CreateProcOptions::Flag_NewConsole) { CreateProcFlags |= CREATE_NEW_CONSOLE; @@ -652,6 
+669,10 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C { CreateProcFlags |= CREATE_NO_WINDOW; } + if (AssignToJob) + { + CreateProcFlags |= CREATE_SUSPENDED; + } ExtendableWideStringBuilder<256> CommandLineZ; CommandLineZ << CommandLine; @@ -679,6 +700,15 @@ CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view C return nullptr; } + if (AssignToJob) + { + if (!Options.AssignToJob->AssignProcess(ProcessInfo.hProcess)) + { + ZEN_WARN("Failed to assign newly created process to job object"); + } + ResumeThread(ProcessInfo.hThread); + } + CloseHandle(ProcessInfo.hThread); return ProcessInfo.hProcess; } @@ -845,6 +875,65 @@ ProcessMonitor::IsActive() const ////////////////////////////////////////////////////////////////////////// +#if ZEN_PLATFORM_WINDOWS +JobObject::JobObject() = default; + +JobObject::~JobObject() +{ + if (m_JobHandle) + { + CloseHandle(m_JobHandle); + m_JobHandle = nullptr; + } +} + +void +JobObject::Initialize() +{ + ZEN_ASSERT(m_JobHandle == nullptr, "JobObject already initialized"); + + m_JobHandle = CreateJobObjectW(nullptr, nullptr); + if (!m_JobHandle) + { + ZEN_WARN("Failed to create job object: {}", zen::GetLastError()); + return; + } + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION LimitInfo = {}; + LimitInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + + if (!SetInformationJobObject(m_JobHandle, JobObjectExtendedLimitInformation, &LimitInfo, sizeof(LimitInfo))) + { + ZEN_WARN("Failed to set job object limits: {}", zen::GetLastError()); + CloseHandle(m_JobHandle); + m_JobHandle = nullptr; + } +} + +bool +JobObject::AssignProcess(void* ProcessHandle) +{ + ZEN_ASSERT(m_JobHandle != nullptr, "JobObject not initialized"); + ZEN_ASSERT(ProcessHandle != nullptr, "ProcessHandle is null"); + + if (!AssignProcessToJobObject(m_JobHandle, ProcessHandle)) + { + ZEN_WARN("Failed to assign process to job object: {}", zen::GetLastError()); + return false; + } + + return true; 
+} + +bool +JobObject::IsValid() const +{ + return m_JobHandle != nullptr; +} +#endif // ZEN_PLATFORM_WINDOWS + +////////////////////////////////////////////////////////////////////////// + bool IsProcessRunning(int pid, std::error_code& OutEc) { diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index a00446a75..bf0e294c5 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -150,6 +151,10 @@ struct StorageServerInstance inline uint16_t GetBasePort() const { return m_ServerInstance.GetBasePort(); } +#if ZEN_PLATFORM_WINDOWS + void SetJobObject(JobObject* InJobObject) { m_JobObject = InJobObject; } +#endif + private: void WakeLocked(); RwLock m_Lock; @@ -161,6 +166,9 @@ private: std::filesystem::path m_TempDir; std::filesystem::path m_HydrationPath; ResourceMetrics m_ResourceMetrics; +#if ZEN_PLATFORM_WINDOWS + JobObject* m_JobObject = nullptr; +#endif void SpawnServerProcess(); @@ -191,6 +199,9 @@ StorageServerInstance::SpawnServerProcess() m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath()); m_ServerInstance.SetDataDir(m_BaseDir); +#if ZEN_PLATFORM_WINDOWS + m_ServerInstance.SetJobObject(m_JobObject); +#endif const uint16_t BasePort = m_ServerInstance.SpawnServerAndWaitUntilReady(); ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, BasePort); @@ -380,6 +391,21 @@ struct HttpHubService::Impl // flexibility, and to allow running multiple hubs on the same host if // necessary. 
m_RunEnvironment.SetNextPortNumber(21000); + +#if ZEN_PLATFORM_WINDOWS + if (m_UseJobObject) + { + m_JobObject.Initialize(); + if (m_JobObject.IsValid()) + { + ZEN_INFO("Job object initialized for hub service child process management"); + } + else + { + ZEN_WARN("Failed to initialize job object; child processes will not be auto-terminated on hub crash"); + } + } +#endif } void Cleanup() @@ -422,6 +448,12 @@ struct HttpHubService::Impl IsNewInstance = true; auto NewInstance = std::make_unique(m_RunEnvironment, ModuleId, m_FileHydrationPath, m_HydrationTempPath); +#if ZEN_PLATFORM_WINDOWS + if (m_JobObject.IsValid()) + { + NewInstance->SetJobObject(&m_JobObject); + } +#endif Instance = NewInstance.get(); m_Instances.emplace(std::string(ModuleId), std::move(NewInstance)); @@ -579,10 +611,15 @@ struct HttpHubService::Impl inline int GetInstanceLimit() { return m_InstanceLimit; } inline int GetMaxInstanceCount() { return m_MaxInstanceCount; } + bool m_UseJobObject = true; + private: - ZenServerEnvironment m_RunEnvironment; - std::filesystem::path m_FileHydrationPath; - std::filesystem::path m_HydrationTempPath; + ZenServerEnvironment m_RunEnvironment; + std::filesystem::path m_FileHydrationPath; + std::filesystem::path m_HydrationTempPath; +#if ZEN_PLATFORM_WINDOWS + JobObject m_JobObject; +#endif RwLock m_Lock; std::unordered_map> m_Instances; std::unordered_set m_DeprovisioningModules; @@ -817,6 +854,12 @@ HttpHubService::~HttpHubService() { } +void +HttpHubService::SetUseJobObject(bool Enable) +{ + m_Impl->m_UseJobObject = Enable; +} + const char* HttpHubService::BaseUri() const { diff --git a/src/zenserver/hub/hubservice.h b/src/zenserver/hub/hubservice.h index 1a5a8c57c..ef24bba69 100644 --- a/src/zenserver/hub/hubservice.h +++ b/src/zenserver/hub/hubservice.h @@ -28,6 +28,13 @@ public: void SetNotificationEndpoint(std::string_view UpstreamNotificationEndpoint, std::string_view InstanceId); + /** Enable or disable the use of a Windows Job Object for child process 
management. + * When enabled, all spawned child processes are assigned to a job object with + * JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, ensuring children are terminated if the hub + * crashes or is force-killed. Must be called before Initialize(). No-op on non-Windows. + */ + void SetUseJobObject(bool Enable); + private: HttpRequestRouter m_Router; diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index b781a03a9..954916fe2 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -97,9 +97,12 @@ struct ZenServerInstance inline int GetPid() const { return m_Process.Pid(); } inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; } void* GetProcessHandle() const { return m_Process.Handle(); } - bool IsRunning(); - bool Terminate(); - std::string GetLogOutput() const; +#if ZEN_PLATFORM_WINDOWS + void SetJobObject(JobObject* Job) { m_JobObject = Job; } +#endif + bool IsRunning(); + bool Terminate(); + std::string GetLogOutput() const; inline ServerMode GetServerMode() const { return m_ServerMode; } @@ -148,6 +151,9 @@ private: std::string m_Name; std::filesystem::path m_OutputCapturePath; std::filesystem::path m_ServerExecutablePath; +#if ZEN_PLATFORM_WINDOWS + JobObject* m_JobObject = nullptr; +#endif void CreateShutdownEvent(int BasePort); void SpawnServer(int BasePort, std::string_view AdditionalServerArgs, int WaitTimeoutMs); diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index 579ba450a..0f8ab223d 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -831,8 +831,15 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; const std::filesystem::path OutputPath = OpenConsole ? 
std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); - CreateProcOptions CreateOptions = {.WorkingDirectory = &CurrentDirectory, .Flags = CreationFlags, .StdoutFile = OutputPath}; - CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + .StdoutFile = OutputPath, +#if ZEN_PLATFORM_WINDOWS + .AssignToJob = m_JobObject, +#endif + }; + CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); #if ZEN_PLATFORM_WINDOWS if (!ChildPid) { @@ -841,6 +848,12 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { ZEN_DEBUG("Regular spawn failed - spawning elevated server"); CreateOptions.Flags |= CreateProcOptions::Flag_Elevated; + // ShellExecuteEx (used by the elevated path) does not support job object assignment + if (CreateOptions.AssignToJob) + { + ZEN_WARN("Elevated process spawn does not support job object assignment; child will not be auto-terminated on parent exit"); + CreateOptions.AssignToJob = nullptr; + } ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); } else -- cgit v1.2.3 From 4d01aaee0a45f4c9f96e8a4925eff696be98de8d Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Sun, 1 Mar 2026 12:40:20 +0100 Subject: added `--verbose` option to zenserver-test and `xmake test` (#798) * when `--verbose` is specified to zenserver-test, all child process output (typically, zenserver instances) is piped through to stdout. you can also pass `--verbose` to `xmake test` to accomplish the same thing. 
* this PR also consolidates all test runner `main` function logic (such as from zencore-test, zenhttp-test etc) into central implementation in zencore for consistency and ease of maintenance * also added extended utf8-tests including a fix to `Utf8ToWide()` --- src/zen/zen.cpp | 1 - src/zencompute-test/zencompute-test.cpp | 22 +--- src/zencore-test/zencore-test.cpp | 36 +------ src/zencore/include/zencore/testing.h | 2 + src/zencore/include/zencore/testutils.h | 27 +++++ src/zencore/string.cpp | 131 ++++++++++++++++++++++-- src/zencore/testing.cpp | 37 ++++++- src/zenhttp-test/zenhttp-test.cpp | 35 +------ src/zennet-test/zennet-test.cpp | 34 +----- src/zenremotestore-test/zenremotestore-test.cpp | 35 +------ src/zenserver-test/zenserver-test.cpp | 11 +- src/zenserver/main.cpp | 1 - src/zenstore-test/zenstore-test.cpp | 34 +----- src/zentelemetry-test/zentelemetry-test.cpp | 34 +----- src/zentest-appstub/zentest-appstub.cpp | 1 - src/zenutil-test/zenutil-test.cpp | 34 +----- src/zenutil/include/zenutil/zenserverprocess.h | 10 +- src/zenutil/zenserverprocess.cpp | 13 +-- 18 files changed, 228 insertions(+), 270 deletions(-) (limited to 'src') diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index dc37cb56b..ba8a76bc3 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -56,7 +56,6 @@ #include "progressbar.h" #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include #endif diff --git a/src/zencompute-test/zencompute-test.cpp b/src/zencompute-test/zencompute-test.cpp index 237812e12..60aaeab1d 100644 --- a/src/zencompute-test/zencompute-test.cpp +++ b/src/zencompute-test/zencompute-test.cpp @@ -1,31 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
#include -#include -#include -#include +#include -#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC -# include -# include -# include -#endif - -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -#endif +#include int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { #if ZEN_WITH_TESTS - zen::zencompute_forcelinktests(); - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zencompute-test", zen::zencompute_forcelinktests); #else return 0; #endif diff --git a/src/zencore-test/zencore-test.cpp b/src/zencore-test/zencore-test.cpp index 68fc940ee..3d9a79283 100644 --- a/src/zencore-test/zencore-test.cpp +++ b/src/zencore-test/zencore-test.cpp @@ -1,47 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. -// zencore-test.cpp : Defines the entry point for the console application. -// - -#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zencore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zencore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zencore-test", zen::zencore_forcelinktests); #else return 0; #endif diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h index a00ee3166..43bdbbffe 100644 --- a/src/zencore/include/zencore/testing.h +++ 
b/src/zencore/include/zencore/testing.h @@ -59,6 +59,8 @@ private: return Runner.Run(); \ }() +int RunTestMain(int argc, char* argv[], const char* traceName, void (*forceLink)()); + } // namespace zen::testing #endif diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h index e2a4f8346..2a789d18f 100644 --- a/src/zencore/include/zencore/testutils.h +++ b/src/zencore/include/zencore/testutils.h @@ -59,6 +59,33 @@ struct TrueType static const bool Enabled = true; }; +namespace utf8test { + + // 2-byte UTF-8 (Latin extended) + static constexpr const char kLatin[] = u8"café_résumé"; + static constexpr const wchar_t kLatinW[] = L"café_résumé"; + + // 2-byte UTF-8 (Cyrillic) + static constexpr const char kCyrillic[] = u8"данные"; + static constexpr const wchar_t kCyrillicW[] = L"данные"; + + // 3-byte UTF-8 (CJK) + static constexpr const char kCJK[] = u8"日本語"; + static constexpr const wchar_t kCJKW[] = L"日本語"; + + // Mixed scripts + static constexpr const char kMixed[] = u8"zen_éд日"; + static constexpr const wchar_t kMixedW[] = L"zen_éд日"; + + // 4-byte UTF-8 (supplementary plane) — string tests only, NOT filesystem + static constexpr const char kEmoji[] = u8"📦"; + static constexpr const wchar_t kEmojiW[] = L"📦"; + + // BMP-only test strings suitable for filesystem use + static constexpr const char* kFilenameSafe[] = {kLatin, kCyrillic, kCJK, kMixed}; + +} // namespace utf8test + } // namespace zen #endif // ZEN_WITH_TESTS diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index a9aed6309..ab1c7de58 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -184,7 +185,21 @@ Utf8ToWide(const std::u8string_view& Str8, WideStringBuilderBase& OutString) if (!ByteCount) { +#if ZEN_SIZEOF_WCHAR_T == 2 + if (CurrentOutChar > 0xFFFF) + { + // Supplementary plane: emit a UTF-16 surrogate pair + uint32_t Adjusted = uint32_t(CurrentOutChar - 0x10000); 
+ OutString.Append(wchar_t(0xD800 + (Adjusted >> 10))); + OutString.Append(wchar_t(0xDC00 + (Adjusted & 0x3FF))); + } + else + { + OutString.Append(wchar_t(CurrentOutChar)); + } +#else OutString.Append(wchar_t(CurrentOutChar)); +#endif CurrentOutChar = 0; } } @@ -967,33 +982,131 @@ TEST_CASE("ExtendableWideStringBuilder") TEST_CASE("utf8") { + using namespace utf8test; + SUBCASE("utf8towide") { - // TODO: add more extensive testing here - this covers a very small space - WideStringBuilder<32> wout; Utf8ToWide(u8"abcdefghi", wout); CHECK(StringEquals(L"abcdefghi", wout.c_str())); wout.Reset(); + Utf8ToWide(u8"abc\xC3\xA4\xC3\xB6\xC3\xBC", wout); + CHECK(StringEquals(L"abc\u00E4\u00F6\u00FC", wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kLatin), wout); + CHECK(StringEquals(kLatinW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kCyrillic), wout); + CHECK(StringEquals(kCyrillicW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kCJK), wout); + CHECK(StringEquals(kCJKW, wout.c_str())); + + wout.Reset(); + Utf8ToWide(std::string_view(kMixed), wout); + CHECK(StringEquals(kMixedW, wout.c_str())); - Utf8ToWide(u8"abc���", wout); - CHECK(StringEquals(L"abc���", wout.c_str())); + wout.Reset(); + Utf8ToWide(std::string_view(kEmoji), wout); + CHECK(StringEquals(kEmojiW, wout.c_str())); } SUBCASE("widetoutf8") { - // TODO: add more extensive testing here - this covers a very small space - - StringBuilder<32> out; + StringBuilder<64> out; WideToUtf8(L"abcdefghi", out); CHECK(StringEquals("abcdefghi", out.c_str())); out.Reset(); + WideToUtf8(kLatinW, out); + CHECK(StringEquals(kLatin, out.c_str())); - WideToUtf8(L"abc���", out); - CHECK(StringEquals(u8"abc���", out.c_str())); + out.Reset(); + WideToUtf8(kCyrillicW, out); + CHECK(StringEquals(kCyrillic, out.c_str())); + + out.Reset(); + WideToUtf8(kCJKW, out); + CHECK(StringEquals(kCJK, out.c_str())); + + out.Reset(); + WideToUtf8(kMixedW, out); + CHECK(StringEquals(kMixed, 
out.c_str())); + + out.Reset(); + WideToUtf8(kEmojiW, out); + CHECK(StringEquals(kEmoji, out.c_str())); + } + + SUBCASE("roundtrip") + { + // UTF-8 -> Wide -> UTF-8 identity + const char* Utf8Strings[] = {kLatin, kCyrillic, kCJK, kMixed, kEmoji}; + for (const char* Utf8Str : Utf8Strings) + { + ExtendableWideStringBuilder<64> Wide; + Utf8ToWide(std::string_view(Utf8Str), Wide); + + ExtendableStringBuilder<64> Back; + WideToUtf8(std::wstring_view(Wide.c_str()), Back); + CHECK(StringEquals(Utf8Str, Back.c_str())); + } + + // Wide -> UTF-8 -> Wide identity + const wchar_t* WideStrings[] = {kLatinW, kCyrillicW, kCJKW, kMixedW, kEmojiW}; + for (const wchar_t* WideStr : WideStrings) + { + ExtendableStringBuilder<64> Utf8; + WideToUtf8(std::wstring_view(WideStr), Utf8); + + ExtendableWideStringBuilder<64> Back; + Utf8ToWide(std::string_view(Utf8.c_str()), Back); + CHECK(StringEquals(WideStr, Back.c_str())); + } + + // Empty string round-trip + { + ExtendableWideStringBuilder<8> Wide; + Utf8ToWide(std::string_view(""), Wide); + CHECK(Wide.Size() == 0); + + ExtendableStringBuilder<8> Narrow; + WideToUtf8(std::wstring_view(L""), Narrow); + CHECK(Narrow.Size() == 0); + } + } + + SUBCASE("IsValidUtf8") + { + // Valid inputs + CHECK(IsValidUtf8("")); + CHECK(IsValidUtf8("hello world")); + CHECK(IsValidUtf8(kLatin)); + CHECK(IsValidUtf8(kCyrillic)); + CHECK(IsValidUtf8(kCJK)); + CHECK(IsValidUtf8(kMixed)); + CHECK(IsValidUtf8(kEmoji)); + + // Invalid: truncated 2-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xC3", 1))); + + // Invalid: truncated 3-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xE6\x97", 2))); + + // Invalid: truncated 4-byte sequence + CHECK(!IsValidUtf8(std::string_view("\xF0\x9F\x93", 3))); + + // Invalid: bad start byte + CHECK(!IsValidUtf8(std::string_view("\xFF", 1))); + CHECK(!IsValidUtf8(std::string_view("\xFE", 1))); + + // Invalid: overlong encoding of '/' (U+002F) + CHECK(!IsValidUtf8(std::string_view("\xC0\xAF", 2))); } } diff --git 
a/src/zencore/testing.cpp b/src/zencore/testing.cpp index ef8fb0480..6000bd95c 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -1,18 +1,23 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#define ZEN_TEST_WITH_RUNNER 1 + #include "zencore/testing.h" + +#include "zencore/filesystem.h" #include "zencore/logging.h" +#include "zencore/process.h" +#include "zencore/trace.h" #if ZEN_WITH_TESTS # include +# include # include # include # include # include -# include - namespace zen::testing { using namespace std::literals; @@ -149,6 +154,34 @@ TestRunner::Run() return m_Impl->Session.run(); } +int +RunTestMain(int argc, char* argv[], [[maybe_unused]] const char* traceName, void (*forceLink)()) +{ +# if ZEN_PLATFORM_WINDOWS + setlocale(LC_ALL, "en_us.UTF8"); +# endif + + forceLink(); + +# if ZEN_PLATFORM_LINUX + zen::IgnoreChildSignals(); +# endif + +# if ZEN_WITH_TRACE + zen::TraceInit(traceName); + zen::TraceOptions TraceCommandlineOptions; + if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) + { + TraceConfigure(TraceCommandlineOptions); + } +# endif + + zen::logging::InitializeLogging(); + zen::MaximizeOpenFileCount(); + + return ZEN_RUN_TESTS(argc, argv); +} + } // namespace zen::testing #endif // ZEN_WITH_TESTS diff --git a/src/zenhttp-test/zenhttp-test.cpp b/src/zenhttp-test/zenhttp-test.cpp index c18759beb..b4b406ac8 100644 --- a/src/zenhttp-test/zenhttp-test.cpp +++ b/src/zenhttp-test/zenhttp-test.cpp @@ -1,44 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include -#include -#include -#include +#include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif +#include int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenhttp_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenhttp-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenhttp-test", zen::zenhttp_forcelinktests); #else return 0; #endif diff --git a/src/zennet-test/zennet-test.cpp b/src/zennet-test/zennet-test.cpp index bc3b8e8e9..1283eb820 100644 --- a/src/zennet-test/zennet-test.cpp +++ b/src/zennet-test/zennet-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zennet_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zennet-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zennet-test", zen::zennet_forcelinktests); #else return 0; #endif diff --git a/src/zenremotestore-test/zenremotestore-test.cpp b/src/zenremotestore-test/zenremotestore-test.cpp index 5db185041..dc47c5aed 100644 --- a/src/zenremotestore-test/zenremotestore-test.cpp +++ b/src/zenremotestore-test/zenremotestore-test.cpp @@ -1,46 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include -#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenremotestore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenremotestore-test", zen::zenremotestore_forcelinktests); #else return 0; #endif diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index 4120dec1a..c7ce633d3 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -4,7 +4,6 @@ #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include "zenserver-test.h" # include @@ -97,6 +96,7 @@ main(int argc, char** argv) // somehow in the future std::string ServerClass; + bool Verbose = false; for (int i = 1; i < argc; ++i) { @@ -107,10 +107,19 @@ main(int argc, char** argv) ServerClass = argv[++i]; } } + else if (argv[i] == "--verbose"sv) + { + Verbose = true; + } } zen::tests::TestEnv.InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass); + if (Verbose) + { + zen::tests::TestEnv.SetPassthroughOutput(true); + } + ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir); zen::testing::TestRunner Runner; diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 571dd3b4f..c764cbde6 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -41,7 +41,6 @@ // in some shared code into the 
executable #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include #endif diff --git a/src/zenstore-test/zenstore-test.cpp b/src/zenstore-test/zenstore-test.cpp index c055dbb64..875373a9d 100644 --- a/src/zenstore-test/zenstore-test.cpp +++ b/src/zenstore-test/zenstore-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenstore_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenstore-test", zen::zenstore_forcelinktests); #else return 0; #endif diff --git a/src/zentelemetry-test/zentelemetry-test.cpp b/src/zentelemetry-test/zentelemetry-test.cpp index 83fd549db..5a2ac74de 100644 --- a/src/zentelemetry-test/zentelemetry-test.cpp +++ b/src/zentelemetry-test/zentelemetry-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zentelemetry_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenstore-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zentelemetry-test", zen::zentelemetry_forcelinktests); #else return 0; #endif diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 926580d96..67fbef532 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -9,7 +9,6 @@ #include #if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 # include #endif diff --git a/src/zenutil-test/zenutil-test.cpp b/src/zenutil-test/zenutil-test.cpp index f5cfd5a72..e2b6ac9bd 100644 --- a/src/zenutil-test/zenutil-test.cpp +++ b/src/zenutil-test/zenutil-test.cpp @@ -1,45 +1,15 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include -#include -#include +#include #include #include -#if ZEN_WITH_TESTS -# define ZEN_TEST_WITH_RUNNER 1 -# include -# include -#endif - int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { -#if ZEN_PLATFORM_WINDOWS - setlocale(LC_ALL, "en_us.UTF8"); -#endif // ZEN_PLATFORM_WINDOWS - #if ZEN_WITH_TESTS - zen::zenutil_forcelinktests(); - -# if ZEN_PLATFORM_LINUX - zen::IgnoreChildSignals(); -# endif - -# if ZEN_WITH_TRACE - zen::TraceInit("zenutil-test"); - zen::TraceOptions TraceCommandlineOptions; - if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) - { - TraceConfigure(TraceCommandlineOptions); - } -# endif // ZEN_WITH_TRACE - - zen::logging::InitializeLogging(); - zen::MaximizeOpenFileCount(); - - return ZEN_RUN_TESTS(argc, argv); + return zen::testing::RunTestMain(argc, argv, "zenutil-test", zen::zenutil_forcelinktests); #else return 0; #endif diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 954916fe2..e81b154e8 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -46,6 +46,9 @@ public: inline std::string_view GetServerClass() const { return m_ServerClass; } inline uint16_t GetNewPortNumber() { return m_NextPortNumber.fetch_add(1); } + void SetPassthroughOutput(bool Enable) { m_PassthroughOutput = Enable; } + bool IsPassthroughOutput() const { return m_PassthroughOutput; } + // The defaults will work for a single root process only. For hierarchical // setups (e.g., hub managing storage servers), we need to be able to // allocate distinct child IDs and ports to avoid overlap/conflicts. 
@@ -55,9 +58,10 @@ public: private: std::filesystem::path m_ProgramBaseDir; std::filesystem::path m_ChildProcessBaseDir; - bool m_IsInitialized = false; - bool m_IsTestInstance = false; - bool m_IsHubInstance = false; + bool m_IsInitialized = false; + bool m_IsTestInstance = false; + bool m_IsHubInstance = false; + bool m_PassthroughOutput = false; std::string m_ServerClass; std::atomic_uint16_t m_NextPortNumber{20000}; }; diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index 0f8ab223d..e127a92d7 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -829,12 +829,13 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); const std::filesystem::path Executable = m_ServerExecutablePath.empty() ? (BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL) : m_ServerExecutablePath; - const std::filesystem::path OutputPath = - OpenConsole ? std::filesystem::path{} : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); - CreateProcOptions CreateOptions = { - .WorkingDirectory = &CurrentDirectory, - .Flags = CreationFlags, - .StdoutFile = OutputPath, + const std::filesystem::path OutputPath = (OpenConsole || m_Env.IsPassthroughOutput()) + ? std::filesystem::path{} + : std::filesystem::temp_directory_path() / ("zenserver_" + m_Name + ".log"); + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + .StdoutFile = OutputPath, #if ZEN_PLATFORM_WINDOWS .AssignToJob = m_JobObject, #endif -- cgit v1.2.3 From d604351cb5dc3032a7cb8c84d6ad5f1480325e5c Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 2 Mar 2026 09:37:14 +0100 Subject: Add test suites (#799) Makes all test cases part of a test suite. Test suites are named after the module and the name of the file containing the implementation of the test. 
* This allows for better and more predictable filtering of which test cases to run which should also be able to reduce the time CI spends in tests since it can filter on the tests for that particular module. Also improves `xmake test` behaviour: * instead of an explicit list of projects just enumerate the test projects which are available based on build system state * also introduces logic to avoid running `xmake config` unnecessarily which would invalidate the existing build and do lots of unnecessary work since dependencies were invalidated by the updated config * also invokes build only for the chosen test targets As a bonus, also adds `xmake sln --open` which allows opening IDE after generation of solution/xmake project is done. --- src/zencore/base64.cpp | 4 ++ src/zencore/basicfile.cpp | 4 ++ src/zencore/blake3.cpp | 4 ++ src/zencore/callstack.cpp | 4 ++ src/zencore/compactbinary.cpp | 8 +--- src/zencore/compactbinarybuilder.cpp | 4 ++ src/zencore/compactbinaryjson.cpp | 4 ++ src/zencore/compactbinarypackage.cpp | 4 ++ src/zencore/compactbinaryvalidation.cpp | 4 ++ src/zencore/compactbinaryyaml.cpp | 4 ++ src/zencore/compositebuffer.cpp | 5 +++ src/zencore/compress.cpp | 4 ++ src/zencore/crypto.cpp | 4 ++ src/zencore/filesystem.cpp | 4 ++ src/zencore/include/zencore/testing.h | 7 ++-- src/zencore/intmath.cpp | 4 ++ src/zencore/iobuffer.cpp | 4 ++ src/zencore/jobqueue.cpp | 4 ++ src/zencore/logging.cpp | 4 ++ src/zencore/md5.cpp | 4 ++ src/zencore/memoryview.cpp | 4 ++ src/zencore/mpscqueue.cpp | 2 + src/zencore/parallelwork.cpp | 4 ++ src/zencore/refcount.cpp | 4 ++ src/zencore/sha1.cpp | 4 ++ src/zencore/sharedbuffer.cpp | 4 ++ src/zencore/stream.cpp | 4 ++ src/zencore/string.cpp | 4 ++ src/zencore/testing.cpp | 47 +++++++++++++++++----- src/zencore/uid.cpp | 4 ++ src/zencore/workthreadpool.cpp | 4 ++ src/zencore/zencore.cpp | 2 +- src/zenhttp/clients/httpclientcommon.cpp | 4 ++ src/zenhttp/httpclient.cpp | 4 ++ src/zenhttp/httpclient_test.cpp | 4 ++ 
src/zenhttp/httpserver.cpp | 4 ++ src/zenhttp/packageformat.cpp | 4 ++ src/zenhttp/security/passwordsecurity.cpp | 5 +++ src/zenhttp/servers/wstest.cpp | 4 ++ src/zennet/statsdclient.cpp | 4 ++ src/zenremotestore/builds/buildmanifest.cpp | 4 ++ src/zenremotestore/builds/buildsavedstate.cpp | 4 ++ .../builds/buildstorageoperations.cpp | 4 ++ src/zenremotestore/chunking/chunkblock.cpp | 4 ++ src/zenremotestore/chunking/chunkedcontent.cpp | 4 ++ src/zenremotestore/chunking/chunkedfile.cpp | 4 ++ src/zenremotestore/chunking/chunkingcache.cpp | 4 ++ src/zenremotestore/filesystemutils.cpp | 4 ++ src/zenserver-test/buildstore-tests.cpp | 4 ++ src/zenserver-test/cache-tests.cpp | 4 ++ src/zenserver-test/cacherequests.cpp | 4 ++ src/zenserver-test/function-tests.cpp | 4 ++ src/zenserver-test/hub-tests.cpp | 6 +-- src/zenserver-test/projectstore-tests.cpp | 4 ++ src/zenserver-test/workspace-tests.cpp | 4 ++ src/zenserver-test/zenserver-test.cpp | 5 +++ src/zenstore/blockstore.cpp | 4 ++ src/zenstore/buildstore/buildstore.cpp | 4 ++ src/zenstore/cache/cachepolicy.cpp | 5 +++ src/zenstore/cache/structuredcachestore.cpp | 4 ++ src/zenstore/cas.cpp | 4 ++ src/zenstore/compactcas.cpp | 4 ++ src/zenstore/filecas.cpp | 4 ++ src/zenstore/gc.cpp | 4 ++ src/zenstore/projectstore.cpp | 4 ++ src/zenstore/workspaces.cpp | 4 ++ src/zentelemetry/otlptrace.cpp | 4 ++ src/zentelemetry/stats.cpp | 4 ++ src/zenutil/config/commandlineoptions.cpp | 4 ++ src/zenutil/rpcrecording.cpp | 2 +- src/zenutil/wildcard.cpp | 4 ++ 71 files changed, 311 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp index fdf5f2d66..96e121799 100644 --- a/src/zencore/base64.cpp +++ b/src/zencore/base64.cpp @@ -180,6 +180,8 @@ template bool Base64::Decode(const wchar_t* Source, uint32_t Length, ui using namespace std::string_literals; +TEST_SUITE_BEGIN("core.base64"); + TEST_CASE("Base64") { auto EncodeString = [](std::string_view Input) -> std::string { @@ 
-290,6 +292,8 @@ TEST_CASE("Base64") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/basicfile.cpp b/src/zencore/basicfile.cpp index bd4d119fb..9dcf7663a 100644 --- a/src/zencore/basicfile.cpp +++ b/src/zencore/basicfile.cpp @@ -888,6 +888,8 @@ WriteToTempFile(CompositeBuffer&& Buffer, const std::filesystem::path& Path) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.basicfile"); + TEST_CASE("BasicFile") { ScopedCurrentDirectoryChange _; @@ -1081,6 +1083,8 @@ TEST_CASE("BasicFileBuffer") } } +TEST_SUITE_END(); + void basicfile_forcelink() { diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp index 054f0d3a0..123918de5 100644 --- a/src/zencore/blake3.cpp +++ b/src/zencore/blake3.cpp @@ -200,6 +200,8 @@ BLAKE3Stream::GetHash() // return text; // } +TEST_SUITE_BEGIN("core.blake3"); + TEST_CASE("BLAKE3") { SUBCASE("Basics") @@ -237,6 +239,8 @@ TEST_CASE("BLAKE3") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/callstack.cpp b/src/zencore/callstack.cpp index 8aa1111bf..ee0b0625a 100644 --- a/src/zencore/callstack.cpp +++ b/src/zencore/callstack.cpp @@ -260,6 +260,8 @@ GetCallstackRaw(void* CaptureBuffer, int FramesToSkip, int FramesToCapture) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.callstack"); + TEST_CASE("Callstack.Basic") { void* Addresses[4]; @@ -272,6 +274,8 @@ TEST_CASE("Callstack.Basic") } } +TEST_SUITE_END(); + void callstack_forcelink() { diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp index b43cc18f1..9c81305d0 100644 --- a/src/zencore/compactbinary.cpp +++ b/src/zencore/compactbinary.cpp @@ -1512,6 +1512,8 @@ uson_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinary"); + TEST_CASE("guid") { using namespace std::literals; @@ -1704,8 +1706,6 @@ TEST_CASE("uson.datetime") ////////////////////////////////////////////////////////////////////////// -TEST_SUITE_BEGIN("core.datetime"); - TEST_CASE("core.datetime.compare") { DateTime T1(2000, 12, 13); @@ -1732,10 +1732,6 @@ 
TEST_CASE("core.datetime.add") CHECK(dT + T1 - T2 == dT1); } -TEST_SUITE_END(); - -TEST_SUITE_BEGIN("core.timespan"); - TEST_CASE("core.timespan.compare") { TimeSpan T1(1000); diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp index 63c0b9c5c..a9ba30750 100644 --- a/src/zencore/compactbinarybuilder.cpp +++ b/src/zencore/compactbinarybuilder.cpp @@ -710,6 +710,8 @@ usonbuilder_forcelink() // return ""; // } +TEST_SUITE_BEGIN("core.compactbinarybuilder"); + TEST_CASE("usonbuilder.object") { using namespace std::literals; @@ -1530,6 +1532,8 @@ TEST_CASE("usonbuilder.stream") CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None); } } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compactbinaryjson.cpp b/src/zencore/compactbinaryjson.cpp index abbec360a..da560a449 100644 --- a/src/zencore/compactbinaryjson.cpp +++ b/src/zencore/compactbinaryjson.cpp @@ -654,6 +654,8 @@ cbjson_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinaryjson"); + TEST_CASE("uson.json") { using namespace std::literals; @@ -872,6 +874,8 @@ TEST_CASE("json.uson") } } +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS } // namespace zen diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp index ffe64f2e9..56a292ca6 100644 --- a/src/zencore/compactbinarypackage.cpp +++ b/src/zencore/compactbinarypackage.cpp @@ -805,6 +805,8 @@ usonpackage_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinarypackage"); + TEST_CASE("usonpackage") { using namespace std::literals; @@ -1343,6 +1345,8 @@ TEST_CASE("usonpackage.invalidpackage") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp index d7292f405..3e78f8ef1 100644 --- a/src/zencore/compactbinaryvalidation.cpp +++ b/src/zencore/compactbinaryvalidation.cpp @@ -753,10 +753,14 @@ usonvalidation_forcelink() { } 
+TEST_SUITE_BEGIN("core.compactbinaryvalidation"); + TEST_CASE("usonvalidation") { SUBCASE("Basic") {} } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compactbinaryyaml.cpp b/src/zencore/compactbinaryyaml.cpp index b308af418..b7f2c55df 100644 --- a/src/zencore/compactbinaryyaml.cpp +++ b/src/zencore/compactbinaryyaml.cpp @@ -412,6 +412,8 @@ cbyaml_forcelink() { } +TEST_SUITE_BEGIN("core.compactbinaryyaml"); + TEST_CASE("uson.yaml") { using namespace std::literals; @@ -524,6 +526,8 @@ mixed_seq: )"sv); } } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp index 252ac9045..ed2b16384 100644 --- a/src/zencore/compositebuffer.cpp +++ b/src/zencore/compositebuffer.cpp @@ -297,6 +297,9 @@ CompositeBuffer::IterateRange(uint64_t Offset, } #if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("core.compositebuffer"); + TEST_CASE("CompositeBuffer Null") { CompositeBuffer Buffer; @@ -462,6 +465,8 @@ TEST_CASE("CompositeBuffer Composite") TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2); } +TEST_SUITE_END(); + void compositebuffer_forcelink() { diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp index 25ed0fc46..6aa0adce0 100644 --- a/src/zencore/compress.cpp +++ b/src/zencore/compress.cpp @@ -2420,6 +2420,8 @@ private: #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.compress"); + TEST_CASE("CompressedBuffer") { uint8_t Zeroes[1024]{}; @@ -2967,6 +2969,8 @@ TEST_CASE("CompressedBufferReader") } } +TEST_SUITE_END(); + void compress_forcelink() { diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp index 09eebb6ae..049854b42 100644 --- a/src/zencore/crypto.cpp +++ b/src/zencore/crypto.cpp @@ -449,6 +449,8 @@ crypto_forcelink() { } +TEST_SUITE_BEGIN("core.crypto"); + TEST_CASE("crypto.bits") { using CryptoBits256Bit = CryptoBits<256>; @@ -500,6 +502,8 @@ TEST_CASE("crypto.aes") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git 
a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 03398860b..9885b2ada 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -3309,6 +3309,8 @@ filesystem_forcelink() { } +TEST_SUITE_BEGIN("core.filesystem"); + TEST_CASE("filesystem") { using namespace std::filesystem; @@ -3603,6 +3605,8 @@ TEST_CASE("SharedMemory") CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h index 43bdbbffe..8410216c4 100644 --- a/src/zencore/include/zencore/testing.h +++ b/src/zencore/include/zencore/testing.h @@ -43,8 +43,9 @@ public: TestRunner(); ~TestRunner(); - int ApplyCommandLine(int argc, char const* const* argv); - int Run(); + void SetDefaultSuiteFilter(const char* Pattern); + int ApplyCommandLine(int Argc, char const* const* Argv); + int Run(); private: struct Impl; @@ -59,7 +60,7 @@ private: return Runner.Run(); \ }() -int RunTestMain(int argc, char* argv[], const char* traceName, void (*forceLink)()); +int RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)()); } // namespace zen::testing #endif diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp index 32f82b486..fedf76edc 100644 --- a/src/zencore/intmath.cpp +++ b/src/zencore/intmath.cpp @@ -19,6 +19,8 @@ intmath_forcelink() { } +TEST_SUITE_BEGIN("core.intmath"); + TEST_CASE("intmath") { CHECK(FloorLog2(0x00) == 0); @@ -66,6 +68,8 @@ TEST_CASE("intmath") CHECK(ByteSwap(uint64_t(0x214d'6172'7469'6e21ull)) == 0x216e'6974'7261'4d21ull); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index 1c31d6620..c47c54981 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -719,6 +719,8 @@ iobuffer_forcelink() { } +TEST_SUITE_BEGIN("core.iobuffer"); + TEST_CASE("IoBuffer") { zen::IoBuffer buffer1; @@ -756,6 +758,8 @@ 
TEST_CASE("IoBuffer.mmap") # endif } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp index 75c1be42b..35724b07a 100644 --- a/src/zencore/jobqueue.cpp +++ b/src/zencore/jobqueue.cpp @@ -460,6 +460,8 @@ jobqueue_forcelink() { } +TEST_SUITE_BEGIN("core.jobqueue"); + TEST_CASE("JobQueue") { std::unique_ptr Queue(MakeJobQueue(2, "queue")); @@ -580,6 +582,8 @@ TEST_CASE("JobQueue") } JobsLatch.Wait(); } + +TEST_SUITE_END(); #endif } // namespace zen diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index e79c4b41c..e960a2729 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -540,6 +540,8 @@ logging_forcelink() using namespace std::literals; +TEST_SUITE_BEGIN("core.logging"); + TEST_CASE("simple.bread") { ExtendableStringBuilder<256> Crumbs; @@ -588,6 +590,8 @@ TEST_CASE("simple.bread") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp index 4ec145697..3baee91c2 100644 --- a/src/zencore/md5.cpp +++ b/src/zencore/md5.cpp @@ -437,6 +437,8 @@ md5_forcelink() // return md5text; // } +TEST_SUITE_BEGIN("core.md5"); + TEST_CASE("MD5") { using namespace std::literals; @@ -458,6 +460,8 @@ TEST_CASE("MD5") CHECK(Output.compare(Buffer)); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/memoryview.cpp b/src/zencore/memoryview.cpp index 1f6a6996c..1654b1766 100644 --- a/src/zencore/memoryview.cpp +++ b/src/zencore/memoryview.cpp @@ -18,6 +18,8 @@ namespace zen { #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.memoryview"); + TEST_CASE("MemoryView") { { @@ -35,6 +37,8 @@ TEST_CASE("MemoryView") CHECK(MakeMemoryView({1.0f, 1.2f}).GetSize() == 8); } +TEST_SUITE_END(); + void memory_forcelink() { diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp index 29c76c3ca..f749f1c90 100644 --- a/src/zencore/mpscqueue.cpp +++ b/src/zencore/mpscqueue.cpp @@ -8,6 +8,7 @@ namespace zen { #if ZEN_WITH_TESTS && 0 
+TEST_SUITE_BEGIN("core.mpscqueue"); TEST_CASE("mpsc") { MpscQueue Queue; @@ -15,6 +16,7 @@ TEST_CASE("mpsc") std::optional Value = Queue.Dequeue(); CHECK_EQ(Value, "hello"); } +TEST_SUITE_END(); #endif void diff --git a/src/zencore/parallelwork.cpp b/src/zencore/parallelwork.cpp index d86d5815f..94696f479 100644 --- a/src/zencore/parallelwork.cpp +++ b/src/zencore/parallelwork.cpp @@ -157,6 +157,8 @@ ParallelWork::RethrowErrors() #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.parallelwork"); + TEST_CASE("parallellwork.nowork") { std::atomic AbortFlag; @@ -255,6 +257,8 @@ TEST_CASE("parallellwork.limitqueue") Work.Wait(); } +TEST_SUITE_END(); + void parallellwork_forcelink() { diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp index a6a86ee12..f19afe715 100644 --- a/src/zencore/refcount.cpp +++ b/src/zencore/refcount.cpp @@ -33,6 +33,8 @@ refcount_forcelink() { } +TEST_SUITE_BEGIN("core.refcount"); + TEST_CASE("RefPtr") { RefPtr Ref; @@ -60,6 +62,8 @@ TEST_CASE("RefPtr") CHECK(IsDestroyed == true); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/sha1.cpp b/src/zencore/sha1.cpp index 3ee74d7d8..807ae4c30 100644 --- a/src/zencore/sha1.cpp +++ b/src/zencore/sha1.cpp @@ -373,6 +373,8 @@ sha1_forcelink() // return sha1text; // } +TEST_SUITE_BEGIN("core.sha1"); + TEST_CASE("SHA1") { uint8_t sha1_empty[20] = {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, @@ -438,6 +440,8 @@ TEST_CASE("SHA1") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp index 78efb9d42..8dc6d49d8 100644 --- a/src/zencore/sharedbuffer.cpp +++ b/src/zencore/sharedbuffer.cpp @@ -152,10 +152,14 @@ sharedbuffer_forcelink() { } +TEST_SUITE_BEGIN("core.sharedbuffer"); + TEST_CASE("SharedBuffer") { } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp index a800ce121..de67303a4 100644 --- a/src/zencore/stream.cpp +++ 
b/src/zencore/stream.cpp @@ -79,6 +79,8 @@ BufferReader::Serialize(void* V, int64_t Length) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.stream"); + TEST_CASE("binary.writer.span") { BinaryWriter Writer; @@ -91,6 +93,8 @@ TEST_CASE("binary.writer.span") CHECK(memcmp(Result.GetData(), "apa banan", 9) == 0); } +TEST_SUITE_END(); + void stream_forcelink() { diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index ab1c7de58..27635a86c 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -546,6 +546,8 @@ UrlDecode(std::string_view InUrl) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.string"); + TEST_CASE("url") { using namespace std::literals; @@ -1222,6 +1224,8 @@ TEST_CASE("string") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 6000bd95c..0bae139bd 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -128,18 +128,24 @@ TestRunner::~TestRunner() { } +void +TestRunner::SetDefaultSuiteFilter(const char* Pattern) +{ + m_Impl->Session.setOption("test-suite", Pattern); +} + int -TestRunner::ApplyCommandLine(int argc, char const* const* argv) +TestRunner::ApplyCommandLine(int Argc, char const* const* Argv) { - m_Impl->Session.applyCommandLine(argc, argv); + m_Impl->Session.applyCommandLine(Argc, Argv); - for (int i = 1; i < argc; ++i) + for (int i = 1; i < Argc; ++i) { - if (argv[i] == "--debug"sv) + if (Argv[i] == "--debug"sv) { zen::logging::SetLogLevel(zen::logging::level::Debug); } - else if (argv[i] == "--verbose"sv) + else if (Argv[i] == "--verbose"sv) { zen::logging::SetLogLevel(zen::logging::level::Trace); } @@ -155,20 +161,20 @@ TestRunner::Run() } int -RunTestMain(int argc, char* argv[], [[maybe_unused]] const char* traceName, void (*forceLink)()) +RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink)()) { # if ZEN_PLATFORM_WINDOWS setlocale(LC_ALL, "en_us.UTF8"); # endif - forceLink(); + ForceLink(); # if ZEN_PLATFORM_LINUX 
zen::IgnoreChildSignals(); # endif # if ZEN_WITH_TRACE - zen::TraceInit(traceName); + zen::TraceInit(ExecutableName); zen::TraceOptions TraceCommandlineOptions; if (GetTraceOptionsFromCommandline(TraceCommandlineOptions)) { @@ -179,7 +185,30 @@ RunTestMain(int argc, char* argv[], [[maybe_unused]] const char* traceName, void zen::logging::InitializeLogging(); zen::MaximizeOpenFileCount(); - return ZEN_RUN_TESTS(argc, argv); + TestRunner Runner; + + // Derive default suite filter from ExecutableName: "zencore-test" -> "core.*" + if (ExecutableName) + { + std::string_view Name = ExecutableName; + if (Name.starts_with("zen")) + { + Name.remove_prefix(3); + } + if (Name.ends_with("-test")) + { + Name.remove_suffix(5); + } + if (!Name.empty()) + { + std::string Filter(Name); + Filter += ".*"; + Runner.SetDefaultSuiteFilter(Filter.c_str()); + } + } + + Runner.ApplyCommandLine(Argc, Argv); + return Runner.Run(); } } // namespace zen::testing diff --git a/src/zencore/uid.cpp b/src/zencore/uid.cpp index d7636f2ad..971683721 100644 --- a/src/zencore/uid.cpp +++ b/src/zencore/uid.cpp @@ -156,6 +156,8 @@ Oid::FromMemory(const void* Ptr) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("core.uid"); + TEST_CASE("Oid") { SUBCASE("Basic") @@ -185,6 +187,8 @@ TEST_CASE("Oid") } } +TEST_SUITE_END(); + void uid_forcelink() { diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp index cb84bbe06..1cb338c66 100644 --- a/src/zencore/workthreadpool.cpp +++ b/src/zencore/workthreadpool.cpp @@ -354,6 +354,8 @@ workthreadpool_forcelink() using namespace std::literals; +TEST_SUITE_BEGIN("core.workthreadpool"); + TEST_CASE("threadpool.basic") { WorkerThreadPool Threadpool{1}; @@ -368,6 +370,8 @@ TEST_CASE("threadpool.basic") CHECK_THROWS(FutureThrow.get()); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index 4ff79edc7..d82474705 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -285,7 +285,7 @@ 
zencore_forcelinktests() namespace zen { -TEST_SUITE_BEGIN("core.assert"); +TEST_SUITE_BEGIN("core.zencore"); TEST_CASE("Assert.Default") { diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index c016e1c3c..248ae9d70 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -597,6 +597,8 @@ namespace testutil { } // namespace testutil +TEST_SUITE_BEGIN("http.httpclientcommon"); + TEST_CASE("BufferedReadFileStream") { ScopedTemporaryDirectory TmpDir; @@ -787,5 +789,7 @@ TEST_CASE("MultipartBoundaryParser") } } +TEST_SUITE_END(); + } // namespace zen #endif diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 1cfddb366..f94c58581 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -430,6 +430,8 @@ MeasureLatency(HttpClient& Client, std::string_view Url) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("responseformat") { using namespace std::literals; @@ -839,6 +841,8 @@ TEST_CASE("httpclient.password") AsioServer->RequestExit(); } } +TEST_SUITE_END(); + void httpclient_forcelink() { diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 91b1a3414..52bf149a7 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -257,6 +257,8 @@ struct TestServerFixture ////////////////////////////////////////////////////////////////////////// // Tests +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("httpclient.verbs") { TestServerFixture Fixture; @@ -1352,6 +1354,8 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) } } +TEST_SUITE_END(); + void httpclient_test_forcelink() { diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 3cefa0ad8..2facd8401 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -1322,6 +1322,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref v1 = 
ParseCommandLine("c:\\my\\exe.exe \"quoted arg\" \"one\",two,\"three\\\""); @@ -235,5 +237,7 @@ TEST_CASE("CommandLine") CHECK_EQ(v3Stripped[5], std::string("--build-part-name=win64")); } +TEST_SUITE_END(); + #endif } // namespace zen diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp index 54f27dee7..28a0091cb 100644 --- a/src/zenutil/rpcrecording.cpp +++ b/src/zenutil/rpcrecording.cpp @@ -1119,7 +1119,7 @@ rpcrecord_forcelink() { } -TEST_SUITE_BEGIN("rpc.recording"); +TEST_SUITE_BEGIN("util.rpcrecording"); TEST_CASE("rpc.record") { diff --git a/src/zenutil/wildcard.cpp b/src/zenutil/wildcard.cpp index 7a44c0498..7f2f77780 100644 --- a/src/zenutil/wildcard.cpp +++ b/src/zenutil/wildcard.cpp @@ -118,6 +118,8 @@ wildcard_forcelink() { } +TEST_SUITE_BEGIN("util.wildcard"); + TEST_CASE("Wildcard") { CHECK(MatchWildcard("*.*", "normal.txt", true)); @@ -151,5 +153,7 @@ TEST_CASE("Wildcard") CHECK(MatchWildcard("*.d", "dir/path.d", true)); } +TEST_SUITE_END(); + #endif } // namespace zen -- cgit v1.2.3 From 1558b202663d9d18f87b384110891b190ad24ea2 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Tue, 3 Mar 2026 13:17:38 +0100 Subject: fix objectstore uri path parsing (#801) * add objectstore tests * in http router, for last matcher, test if it matches the remaining part of the uri --- src/zenhttp/httpserver.cpp | 87 +++++++++++++++++++++++++++----- src/zenserver-test/objectstore-tests.cpp | 74 +++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 14 deletions(-) create mode 100644 src/zenserver-test/objectstore-tests.cpp (limited to 'src') diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 2facd8401..d798c46d9 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -746,6 +746,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { if (UriPattern[i] == '}') { + if (i == PatternStart) + { + throw std::runtime_error(fmt::format("matcher pattern is empty in URI 
pattern '{}'", UriPattern)); + } std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { @@ -911,8 +915,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) CapturedSegments.emplace_back(Uri); - for (int MatcherIndex : Matchers) + for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++) { + int MatcherIndex = Matchers[MatcherOffset]; if (UriPos >= UriLen) { IsMatch = false; @@ -922,9 +927,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (MatcherIndex < 0) { // Literal match - int LitIndex = -MatcherIndex - 1; - const std::string& LitStr = m_Literals[LitIndex]; - size_t LitLen = LitStr.length(); + int LitIndex = -MatcherIndex - 1; + std::string_view LitStr = m_Literals[LitIndex]; + size_t LitLen = LitStr.length(); if (Uri.substr(UriPos, LitLen) == LitStr) { @@ -940,9 +945,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { // Matcher function size_t SegmentStart = UriPos; - while (UriPos < UriLen && Uri[UriPos] != '/') + + if (MatcherOffset == (Matchers.size() - 1)) + { + // Last matcher, use the remaining part of the uri + UriPos = UriLen; + } + else { - ++UriPos; + while (UriPos < UriLen && Uri[UriPos] != '/') + { + ++UriPos; + } } std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); @@ -1429,20 +1443,33 @@ TEST_CASE("http.common") SUBCASE("router-matcher") { - bool HandledA = false; - bool HandledAA = false; - bool HandledAB = false; - bool HandledAandB = false; + bool HandledA = false; + bool HandledAA = false; + bool HandledAB = false; + bool HandledAandB = false; + bool HandledAandPath = false; std::vector Captures; auto Reset = [&] { - HandledA = HandledAA = HandledAB = HandledAandB = false; + HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false; Captures.clear(); }; TestHttpService Service; HttpRequestRouter r; - 
r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); - r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); + + r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; }); + r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; }); + static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); }); + + r.RegisterRoute( + "path/{a}/{path}", + [&](auto& Req) { + HandledAandPath = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + r.RegisterRoute( "{a}", [&](auto& Req) { @@ -1471,7 +1498,6 @@ TEST_CASE("http.common") Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); - { Reset(); TestHttpServerRequest req{Service, "ab"sv}; @@ -1479,6 +1505,7 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "ab"sv); @@ -1491,6 +1518,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1504,6 +1532,7 @@ TEST_CASE("http.common") CHECK(!HandledAA); CHECK(!HandledAB); CHECK(HandledAandB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1516,6 +1545,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); } { @@ -1525,6 +1555,35 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); 
CHECK(!HandledAB); + CHECK(!HandledAandPath); + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "a123"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "simple_path.txt"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "directory/and/path.txt"sv); } } diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp new file mode 100644 index 000000000..f3db5fdf6 --- /dev/null +++ b/src/zenserver-test/objectstore-tests.cpp @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS +# include "zenserver-test.h" +# include +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::tests { + +using namespace std::literals; + +TEST_SUITE_BEGIN("server.objectstore"); + +TEST_CASE("objectstore.blobs") +{ + std::string_view Bucket = "bkt"sv; + + std::vector CompressedBlobsHashes; + std::vector BlobsSizes; + std::vector CompressedBlobsSizes; + { + ZenServerInstance Instance(TestEnv); + + const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--objectstore-enabled")); + CHECK(PortNumber != 0); + + HttpClient Client(Instance.GetBaseUri() + "/obj/"); + + for (size_t I = 0; I < 5; I++) + { + IoBuffer Blob = CreateSemiRandomBlob(4711 + I * 7); + BlobsSizes.push_back(Blob.GetSize()); + CompressedBuffer CompressedBlob = CompressedBuffer::Compress(SharedBuffer(std::move(Blob))); + CompressedBlobsHashes.push_back(CompressedBlob.DecodeRawHash()); + CompressedBlobsSizes.push_back(CompressedBlob.GetCompressedSize()); + IoBuffer Payload = std::move(CompressedBlob).GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + + std::string ObjectPath = fmt::format("{}/{}.utoc", + CompressedBlobsHashes.back().ToHexString().substr(0, 2), + CompressedBlobsHashes.back().ToHexString()); + + HttpClient::Response Result = Client.Put(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath), Payload); + CHECK(Result); + } + + for (size_t I = 0; I < 5; I++) + { + std::string ObjectPath = + fmt::format("{}/{}.utoc", CompressedBlobsHashes[I].ToHexString().substr(0, 2), CompressedBlobsHashes[I].ToHexString()); + HttpClient::Response Result = Client.Get(fmt::format("bucket/{}/{}.utoc", Bucket, ObjectPath)); + CHECK(Result); + CHECK_EQ(Result.ResponsePayload.GetSize(), CompressedBlobsSizes[I]); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = + 
CompressedBuffer::FromCompressed(SharedBuffer(std::move(Result.ResponsePayload)), RawHash, RawSize); + CHECK(Compressed); + CHECK_EQ(RawHash, CompressedBlobsHashes[I]); + CHECK_EQ(RawSize, BlobsSizes[I]); + } + } +} + +TEST_SUITE_END(); + +} // namespace zen::tests +#endif -- cgit v1.2.3 From 463a0fde16b827c0ec44c9e88fe3c8c8098aa5ea Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Tue, 3 Mar 2026 20:49:01 +0100 Subject: use multi range requests (#800) - Improvement: `zen builds download` now uses multi-range requests for blocks to reduce download size - Improvement: `zen oplog-import` now uses partial block with multi-range requests for blocks to reduce download size - Improvement: Improved feedback in log/console during `zen oplog-import` - Improvement: `--allow-partial-block-requests` now defaults to `true` for `zen builds download` and `zen oplog-import` (was `mixed`) - Improvement: Improved range merging analysis when downloading partial blocks --- src/zen/cmds/builds_cmd.h | 2 +- src/zen/cmds/projectstore_cmd.h | 2 +- src/zen/cmds/workspaces_cmd.cpp | 2 +- src/zen/progressbar.cpp | 5 +- src/zenhttp/clients/httpclientcommon.cpp | 33 +- src/zenhttp/httpclient.cpp | 11 +- src/zenhttp/include/zenhttp/httpclient.h | 4 +- src/zenremotestore/builds/buildstoragecache.cpp | 72 +- .../builds/buildstorageoperations.cpp | 581 ++++++++++---- src/zenremotestore/builds/buildstorageutil.cpp | 17 - src/zenremotestore/builds/filebuildstorage.cpp | 39 + src/zenremotestore/builds/jupiterbuildstorage.cpp | 22 +- src/zenremotestore/chunking/chunkblock.cpp | 63 +- .../include/zenremotestore/builds/buildstorage.h | 21 +- .../zenremotestore/builds/buildstoragecache.h | 8 + .../zenremotestore/builds/buildstorageoperations.h | 11 +- .../zenremotestore/builds/buildstorageutil.h | 1 - .../include/zenremotestore/chunking/chunkblock.h | 31 +- .../zenremotestore/jupiter/jupitersession.h | 12 + .../projectstore/remoteprojectstore.h | 26 +- src/zenremotestore/jupiter/jupitersession.cpp | 
65 ++ .../projectstore/buildsremoteprojectstore.cpp | 100 ++- .../projectstore/fileremoteprojectstore.cpp | 68 +- .../projectstore/jupiterremoteprojectstore.cpp | 60 +- .../projectstore/remoteprojectstore.cpp | 879 ++++++++++++--------- .../projectstore/zenremoteprojectstore.cpp | 145 ++-- .../storage/buildstore/httpbuildstore.cpp | 24 +- 27 files changed, 1582 insertions(+), 722 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.h b/src/zen/cmds/builds_cmd.h index f5c44ab55..5c80beed5 100644 --- a/src/zen/cmds/builds_cmd.h +++ b/src/zen/cmds/builds_cmd.h @@ -71,7 +71,7 @@ private: bool m_AppendNewContent = false; uint8_t m_BlockReuseMinPercentLimit = 85; bool m_AllowMultiparts = true; - std::string m_AllowPartialBlockRequests = "mixed"; + std::string m_AllowPartialBlockRequests = "true"; AuthCommandLineOptions m_AuthOptions; diff --git a/src/zen/cmds/projectstore_cmd.h b/src/zen/cmds/projectstore_cmd.h index 17fd76e9f..1ba98b39e 100644 --- a/src/zen/cmds/projectstore_cmd.h +++ b/src/zen/cmds/projectstore_cmd.h @@ -210,7 +210,7 @@ private: bool m_BoostWorkerMemory = false; bool m_BoostWorkers = false; - std::string m_AllowPartialBlockRequests = "mixed"; + std::string m_AllowPartialBlockRequests = "true"; }; class SnapshotOplogCommand : public ProjectStoreCommand diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 2661ac9da..af265d898 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -815,7 +815,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** if (Results.size() != m_ChunkIds.size()) { throw std::runtime_error( - fmt::format("failed to get workspace share batch - invalid result count recevied (expected: {}, received: {}", + fmt::format("failed to get workspace share batch - invalid result count received (expected: {}, received: {}", m_ChunkIds.size(), Results.size())); } diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 
732f16e81..9467ed60d 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -207,8 +207,9 @@ ProgressBar::UpdateState(const State& NewState, bool DoLinebreak) size_t ProgressBarCount = (ProgressBarSize * PercentDone) / 100; uint64_t Completed = NewState.TotalCount - NewState.RemainingCount; uint64_t ETAElapsedMS = ElapsedTimeMS -= m_PausedMS; - uint64_t ETAMS = - (NewState.Status == State::EStatus::Running) && (PercentDone > 5) ? (ETAElapsedMS * NewState.RemainingCount) / Completed : 0; + uint64_t ETAMS = ((m_State.TotalCount == NewState.TotalCount) && (NewState.Status == State::EStatus::Running)) && (PercentDone > 5) + ? (ETAElapsedMS * NewState.RemainingCount) / Completed + : 0; uint32_t ConsoleColumns = TuiConsoleColumns(1024); diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 248ae9d70..9ded23375 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -394,31 +394,28 @@ namespace detail { { // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') + { + BoundaryName = BoundaryName.substr(1); + } if (!BoundaryName.empty()) { - size_t BoundaryEnd = std::string::npos; - while (BoundaryName[0] == ' ') - { - BoundaryName = BoundaryName.substr(1); - } - if (!BoundaryName.empty()) + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') { - if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) { - BoundaryEnd = BoundaryName.find('"', 1); - if (BoundaryEnd != std::string::npos) - { - BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 
1))); - return true; - } - } - else - { - BoundaryEnd = BoundaryName.find_first_of(" \r\n"); - BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); return true; } } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } } } } diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index f94c58581..281d512cf 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -107,17 +107,14 @@ HttpClientBase::GetAccessToken() std::vector> HttpClient::Response::GetRanges(std::span> OffsetAndLengthPairs) const { - std::vector> Result; - Result.reserve(OffsetAndLengthPairs.size()); if (Ranges.empty()) { - for (const std::pair& Range : OffsetAndLengthPairs) - { - Result.emplace_back(std::make_pair(Range.first, Range.second)); - } - return Result; + return {}; } + std::vector> Result; + Result.reserve(OffsetAndLengthPairs.size()); + auto BoundaryIt = Ranges.begin(); auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index f00bbec63..53be36b9a 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -190,10 +190,12 @@ public: HttpContentType ContentType; }; - // Ranges will map out all recevied ranges, both single and multi-range responses + // Ranges will map out all received ranges, both single and multi-range responses // If no range was requested Ranges will be empty std::vector Ranges; + // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges). 
+ // If the response was not a partial response, an empty vector will be returned std::vector> GetRanges(std::span> OffsetAndLengthPairs) const; // This contains any errors from the HTTP stack. It won't contain information on diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp index faa85f81b..53d33bd7e 100644 --- a/src/zenremotestore/builds/buildstoragecache.cpp +++ b/src/zenremotestore/builds/buildstoragecache.cpp @@ -151,7 +151,7 @@ public: auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); HttpClient::Response CacheResponse = - m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()), + m_HttpClient.Upload(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), Payload, ContentType); @@ -180,7 +180,7 @@ public: } CreateDirectories(m_TempFolderPath); HttpClient::Response CacheResponse = - m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash.ToHexString()), + m_HttpClient.Download(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), m_TempFolderPath, Headers); AddStatistic(CacheResponse); @@ -191,6 +191,74 @@ public: return {}; } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span> Ranges) override + { + ZEN_TRACE_CPU("ZenBuildStorageCache::GetBuildBlobRanges"); + + Stopwatch ExecutionTimer; + auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); + + CbObjectWriter Writer; + Writer.BeginArray("ranges"sv); + { + for (const std::pair& Range : Ranges) + { + Writer.BeginObject(); + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, Range.second); + } + Writer.EndObject(); + } + } + Writer.EndArray(); // ranges + + CreateDirectories(m_TempFolderPath); + HttpClient::Response 
CacheResponse = + m_HttpClient.Post(fmt::format("/builds/{}/{}/{}/blobs/{}", m_Namespace, m_Bucket, BuildId, RawHash), + Writer.Save(), + HttpClient::Accept(ZenContentType::kCbPackage)); + AddStatistic(CacheResponse); + if (CacheResponse.IsSuccess()) + { + CbPackage ResponsePackage = ParsePackageMessage(CacheResponse.ResponsePayload); + CbObjectView ResponseObject = ResponsePackage.GetObject(); + + CbArrayView RangeArray = ResponseObject["ranges"sv].AsArrayView(); + + std::vector> ReceivedRanges; + ReceivedRanges.reserve(RangeArray.Num()); + + uint64_t OffsetInPayloadRanges = 0; + + for (CbFieldView View : RangeArray) + { + CbObjectView RangeView = View.AsObjectView(); + uint64_t Offset = RangeView["offset"sv].AsUInt64(); + uint64_t Length = RangeView["length"sv].AsUInt64(); + + const std::pair& Range = Ranges[ReceivedRanges.size()]; + + if (Offset != Range.first || Length != Range.second) + { + return {}; + } + ReceivedRanges.push_back(std::make_pair(OffsetInPayloadRanges, Length)); + OffsetInPayloadRanges += Length; + } + + const CbAttachment* DataAttachment = ResponsePackage.FindAttachment(RawHash); + if (DataAttachment) + { + SharedBuffer PayloadRanges = DataAttachment->AsBinary(); + return BuildBlobRanges{.PayloadBuffer = PayloadRanges.AsIoBuffer(), .Ranges = std::move(ReceivedRanges)}; + } + } + return {}; + } + virtual void PutBlobMetadatas(const Oid& BuildId, std::span BlobHashes, std::span MetaDatas) override { ZEN_ASSERT(!IsFlushed); diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index 5deb00707..f4b4d592b 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -38,6 +38,7 @@ ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_WITH_TESTS # include # include +# include # include #endif // ZEN_WITH_TESTS @@ -883,12 +884,14 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) BlobsExistsResult ExistsResult; 
{ - ChunkBlockAnalyser BlockAnalyser(m_LogOutput, - m_BlockDescriptions, - ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, - .IsVerbose = m_Options.IsVerbose, - .HostLatencySec = m_Storage.BuildStorageLatencySec, - .HostHighSpeedLatencySec = m_Storage.CacheLatencySec}); + ChunkBlockAnalyser BlockAnalyser( + m_LogOutput, + m_BlockDescriptions, + ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, + .IsVerbose = m_Options.IsVerbose, + .HostLatencySec = m_Storage.BuildStorageLatencySec, + .HostHighSpeedLatencySec = m_Storage.CacheLatencySec, + .HostMaxRangeCountPerRequest = BuildStorageBase::MaxRangeCountPerRequest}); std::vector NeededBlocks = BlockAnalyser.GetNeeded( m_RemoteLookup.ChunkHashToChunkIndex, @@ -1027,15 +1030,13 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::All) { - BlockPartialDownloadModes.push_back(BlockExistInCache - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); + BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); } else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) { - BlockPartialDownloadModes.push_back(BlockExistInCache - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); + BlockPartialDownloadModes.push_back(BlockExistInCache ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact + : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) { @@ -1045,6 +1046,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } } } + ZEN_ASSERT(BlockPartialDownloadModes.size() == m_BlockDescriptions.size()); ChunkBlockAnalyser::BlockResult PartialBlocks = @@ -1356,90 +1358,105 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) &Work, &PartialBlocks, BlockRangeStartIndex = BlockRangeIndex, - RangeCount](std::atomic&) { + RangeCount = RangeCount](std::atomic&) { if (!m_AbortFlag) { ZEN_TRACE_CPU("Async_GetPartialBlockRanges"); FilteredDownloadedBytesPerSecond.Start(); - for (size_t BlockRangeIndex = BlockRangeStartIndex; BlockRangeIndex < BlockRangeStartIndex + RangeCount; - BlockRangeIndex++) - { - ZEN_TRACE_CPU("GetPartialBlock"); - - const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = PartialBlocks.BlockRanges[BlockRangeIndex]; - - DownloadPartialBlock( - BlockRange, - ExistsResult, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalRequestCount, - TotalPartWriteCount, - &FilteredDownloadedBytesPerSecond, - &FilteredWrittenBytesPerSecond, - &BlockRange](IoBuffer&& InMemoryBuffer, const std::filesystem::path& OnDiskPath) { - if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } + DownloadPartialBlock( + PartialBlocks.BlockRanges, + BlockRangeStartIndex, + RangeCount, + ExistsResult, + [this, + &RemoteChunkIndexNeedsCopyFromSourceFlags, + &SequenceIndexChunksLeftToWriteCounters, + &WritePartsComplete, + &WriteCache, + &Work, + TotalRequestCount, + TotalPartWriteCount, + &FilteredDownloadedBytesPerSecond, + &FilteredWrittenBytesPerSecond, + &PartialBlocks](IoBuffer&& InMemoryBuffer, + const std::filesystem::path& 
OnDiskPath, + size_t BlockRangeStartIndex, + std::span> OffsetAndLengths) { + if (m_DownloadStats.RequestsCompleteCount == TotalRequestCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } - if (!m_AbortFlag) - { - Work.ScheduleWork( - m_IOWorkerPool, - [this, - &RemoteChunkIndexNeedsCopyFromSourceFlags, - &SequenceIndexChunksLeftToWriteCounters, - &WritePartsComplete, - &WriteCache, - &Work, - TotalPartWriteCount, - &FilteredWrittenBytesPerSecond, - &BlockRange, - BlockChunkPath = std::filesystem::path(OnDiskPath), - BlockPartialBuffer = std::move(InMemoryBuffer)](std::atomic&) mutable { - if (!m_AbortFlag) - { - ZEN_TRACE_CPU("Async_WritePartialBlock"); + if (!m_AbortFlag) + { + Work.ScheduleWork( + m_IOWorkerPool, + [this, + &RemoteChunkIndexNeedsCopyFromSourceFlags, + &SequenceIndexChunksLeftToWriteCounters, + &WritePartsComplete, + &WriteCache, + &Work, + TotalPartWriteCount, + &FilteredWrittenBytesPerSecond, + &PartialBlocks, + BlockRangeStartIndex, + BlockChunkPath = std::filesystem::path(OnDiskPath), + BlockPartialBuffer = std::move(InMemoryBuffer), + OffsetAndLengths = std::vector>(OffsetAndLengths.begin(), + OffsetAndLengths.end())]( + std::atomic&) mutable { + if (!m_AbortFlag) + { + ZEN_TRACE_CPU("Async_WritePartialBlock"); - const uint32_t BlockIndex = BlockRange.BlockIndex; + const uint32_t BlockIndex = PartialBlocks.BlockRanges[BlockRangeStartIndex].BlockIndex; - const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; + const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - if (BlockChunkPath.empty()) - { - ZEN_ASSERT(BlockPartialBuffer); - } - else + if (BlockChunkPath.empty()) + { + ZEN_ASSERT(BlockPartialBuffer); + } + else + { + ZEN_ASSERT(!BlockPartialBuffer); + BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); + if (!BlockPartialBuffer) { - ZEN_ASSERT(!BlockPartialBuffer); - BlockPartialBuffer = IoBufferBuilder::MakeFromFile(BlockChunkPath); - if (!BlockPartialBuffer) - 
{ - throw std::runtime_error( - fmt::format("Could not open downloaded block {} from {}", - BlockDescription.BlockHash, - BlockChunkPath)); - } + throw std::runtime_error( + fmt::format("Could not open downloaded block {} from {}", + BlockDescription.BlockHash, + BlockChunkPath)); } + } + + FilteredWrittenBytesPerSecond.Start(); - FilteredWrittenBytesPerSecond.Start(); - - if (!WritePartialBlockChunksToCache( - BlockDescription, - SequenceIndexChunksLeftToWriteCounters, - Work, - CompositeBuffer(std::move(BlockPartialBuffer)), - BlockRange.ChunkBlockIndexStart, - BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount - 1, - RemoteChunkIndexNeedsCopyFromSourceFlags, - WriteCache)) + size_t RangeCount = OffsetAndLengths.size(); + + for (size_t PartialRangeIndex = 0; PartialRangeIndex < RangeCount; PartialRangeIndex++) + { + const std::pair& OffsetAndLength = + OffsetAndLengths[PartialRangeIndex]; + IoBuffer BlockRangeBuffer(BlockPartialBuffer, + OffsetAndLength.first, + OffsetAndLength.second); + + const ChunkBlockAnalyser::BlockRangeDescriptor& RangeDescriptor = + PartialBlocks.BlockRanges[BlockRangeStartIndex + PartialRangeIndex]; + + if (!WritePartialBlockChunksToCache(BlockDescription, + SequenceIndexChunksLeftToWriteCounters, + Work, + CompositeBuffer(std::move(BlockRangeBuffer)), + RangeDescriptor.ChunkBlockIndexStart, + RangeDescriptor.ChunkBlockIndexStart + + RangeDescriptor.ChunkBlockIndexCount - 1, + RemoteChunkIndexNeedsCopyFromSourceFlags, + WriteCache)) { std::error_code DummyEc; RemoveFile(BlockChunkPath, DummyEc); @@ -1447,28 +1464,27 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) fmt::format("Partial block {} is malformed", BlockDescription.BlockHash)); } - std::error_code Ec = TryRemoveFile(BlockChunkPath); - if (Ec) - { - ZEN_OPERATION_LOG_DEBUG(m_LogOutput, - "Failed removing file '{}', reason: ({}) {}", - BlockChunkPath, - Ec.value(), - Ec.message()); - } - WritePartsComplete++; if (WritePartsComplete == 
TotalPartWriteCount) { FilteredWrittenBytesPerSecond.Stop(); } } - }, - OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog - : WorkerThreadPool::EMode::EnableBacklog); - } - }); - } + std::error_code Ec = TryRemoveFile(BlockChunkPath); + if (Ec) + { + ZEN_OPERATION_LOG_DEBUG(m_LogOutput, + "Failed removing file '{}', reason: ({}) {}", + BlockChunkPath, + Ec.value(), + Ec.message()); + } + } + }, + OnDiskPath.empty() ? WorkerThreadPool::EMode::DisableBacklog + : WorkerThreadPool::EMode::EnableBacklog); + } + }); } }); BlockRangeIndex += RangeCount; @@ -3161,45 +3177,40 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde void BuildsOperationUpdateFolder::DownloadPartialBlock( - const ChunkBlockAnalyser::BlockRangeDescriptor BlockRange, - const BlobsExistsResult& ExistsResult, - std::function&& OnDownloaded) + std::span BlockRanges, + size_t BlockRangeStartIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + std::function> OffsetAndLengths)>&& OnDownloaded) { - const uint32_t BlockIndex = BlockRange.BlockIndex; + const uint32_t BlockIndex = BlockRanges[BlockRangeStartIndex].BlockIndex; const ChunkBlockDescription& BlockDescription = m_BlockDescriptions[BlockIndex]; - IoBuffer BlockBuffer; - if (m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash)) - { - BlockBuffer = - m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - } - if (!BlockBuffer) - { - BlockBuffer = - m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - } - if (!BlockBuffer) - { - throw std::runtime_error(fmt::format("Block {} is missing when fetching range {} -> {}", - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeStart + BlockRange.RangeLength)); - } - if (!m_AbortFlag) - { - uint64_t BlockSize = BlockBuffer.GetSize(); + auto 
ProcessDownload = [this]( + const ChunkBlockDescription& BlockDescription, + IoBuffer&& BlockRangeBuffer, + size_t BlockRangeStartIndex, + std::span> BlockOffsetAndLengths, + const std::function> OffsetAndLengths)>& OnDownloaded) { + uint64_t BlockRangeBufferSize = BlockRangeBuffer.GetSize(); m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += BlockSize; - m_DownloadStats.RequestsCompleteCount++; + m_DownloadStats.DownloadedBlockByteCount += BlockRangeBufferSize; + m_DownloadStats.RequestsCompleteCount += BlockOffsetAndLengths.size(); std::filesystem::path BlockChunkPath; // Check if the dowloaded block is file based and we can move it directly without rewriting it { IoBufferFileReference FileRef; - if (BlockBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && (FileRef.FileChunkSize == BlockSize)) + if (BlockRangeBuffer.GetFileReference(FileRef) && (FileRef.FileChunkOffset == 0) && + (FileRef.FileChunkSize == BlockRangeBufferSize)) { ZEN_TRACE_CPU("MoveTempPartialBlock"); @@ -3207,10 +3218,17 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( std::filesystem::path TempBlobPath = PathFromHandle(FileRef.FileHandle, Ec); if (!Ec) { - BlockBuffer.SetDeleteOnClose(false); - BlockBuffer = {}; - BlockChunkPath = m_TempBlockFolderPath / - fmt::format("{}_{:x}_{:x}", BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); + BlockRangeBuffer.SetDeleteOnClose(false); + BlockRangeBuffer = {}; + + IoHashStream RangeId; + for (const std::pair& Range : BlockOffsetAndLengths) + { + RangeId.Append(&Range.first, sizeof(uint64_t)); + RangeId.Append(&Range.second, sizeof(uint64_t)); + } + + BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); RenameFile(TempBlobPath, BlockChunkPath, Ec); if (Ec) { @@ -3218,27 +3236,137 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( // Re-open the temp file again BasicFile OpenTemp(TempBlobPath, 
BasicFile::Mode::kDelete); - BlockBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockSize, true); - BlockBuffer.SetDeleteOnClose(true); + BlockRangeBuffer = IoBuffer(IoBuffer::File, OpenTemp.Detach(), 0, BlockRangeBufferSize, true); + BlockRangeBuffer.SetDeleteOnClose(true); } } } } - if (BlockChunkPath.empty() && (BlockSize > m_Options.MaximumInMemoryPayloadSize)) + if (BlockChunkPath.empty() && (BlockRangeBufferSize > m_Options.MaximumInMemoryPayloadSize)) { ZEN_TRACE_CPU("WriteTempPartialBlock"); + + IoHashStream RangeId; + for (const std::pair& Range : BlockOffsetAndLengths) + { + RangeId.Append(&Range.first, sizeof(uint64_t)); + RangeId.Append(&Range.second, sizeof(uint64_t)); + } + // Could not be moved and rather large, lets store it on disk - BlockChunkPath = m_TempBlockFolderPath / - fmt::format("{}_{:x}_{:x}", BlockDescription.BlockHash, BlockRange.RangeStart, BlockRange.RangeLength); - TemporaryFile::SafeWriteFile(BlockChunkPath, BlockBuffer); - BlockBuffer = {}; + BlockChunkPath = m_TempBlockFolderPath / fmt::format("{}_{}", BlockDescription.BlockHash, RangeId.GetHash()); + TemporaryFile::SafeWriteFile(BlockChunkPath, BlockRangeBuffer); + BlockRangeBuffer = {}; } if (!m_AbortFlag) { - OnDownloaded(std::move(BlockBuffer), std::move(BlockChunkPath)); + OnDownloaded(std::move(BlockRangeBuffer), std::move(BlockChunkPath), BlockRangeStartIndex, BlockOffsetAndLengths); + } + }; + + std::vector> Ranges; + Ranges.reserve(BlockRangeCount); + for (size_t BlockRangeIndex = BlockRangeStartIndex; BlockRangeIndex < BlockRangeStartIndex + BlockRangeCount; BlockRangeIndex++) + { + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRanges[BlockRangeIndex]; + Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); + } + + if (m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash)) + { + BuildStorageCache::BuildBlobRanges RangeBuffers = + 
m_Storage.BuildCacheStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, Ranges); + if (RangeBuffers.PayloadBuffer) + { + if (!m_AbortFlag) + { + if (RangeBuffers.Ranges.size() != Ranges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + Ranges.size(), + BlockDescription.BlockHash, + RangeBuffers.Ranges.size())); + } + + std::vector> BlockOffsetAndLengths = std::move(RangeBuffers.Ranges); + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + BlockRangeStartIndex, + BlockOffsetAndLengths, + OnDownloaded); + } + return; } } + + const size_t MaxRangesPerRequestToJupiter = BuildStorageBase::MaxRangeCountPerRequest; + + size_t SubBlockRangeCount = BlockRangeCount; + size_t SubRangeCountComplete = 0; + std::span> RangesSpan(Ranges); + while (SubRangeCountComplete < SubBlockRangeCount) + { + if (m_AbortFlag) + { + break; + } + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, MaxRangesPerRequestToJupiter); + size_t SubRangeStartIndex = BlockRangeStartIndex + SubRangeCountComplete; + + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + BuildStorageBase::BuildBlobRanges RangeBuffers = + m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); + if (RangeBuffers.PayloadBuffer) + { + if (m_AbortFlag) + { + break; + } + if (RangeBuffers.Ranges.empty()) + { + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Upload to cache (if enabled) and use the whole payload for the remaining ranges + + if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + { + m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector{RangeBuffers.PayloadBuffer})); + } + + SubRangeCount = Ranges.size() - SubRangeCountComplete; + ProcessDownload(BlockDescription, + 
std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + OnDownloaded); + } + else + { + if (RangeBuffers.Ranges.size() != SubRanges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + SubRanges.size(), + BlockDescription.BlockHash, + RangeBuffers.Ranges.size())); + } + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + OnDownloaded); + } + } + else + { + throw std::runtime_error(fmt::format("Block {} is missing when fetching {} ranges", BlockDescription.BlockHash, SubRangeCount)); + } + + SubRangeCountComplete += SubRangeCount; + } } std::vector @@ -7083,16 +7211,31 @@ GetRemoteContent(OperationLogOutput& Output, // TODO: GetBlockDescriptions for all BlockRawHashes in one go - check for local block descriptions when we cache them { + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size()); + } + + Stopwatch GetBlockMetadataTimer; + bool AttemptFallback = false; OutBlockDescriptions = GetBlockDescriptions(Output, *Storage.BuildStorage, Storage.BuildCacheStorage.get(), BuildId, - BuildPartId, BlockRawHashes, AttemptFallback, IsQuiet, IsVerbose); + + if (!IsQuiet) + { + ZEN_OPERATION_LOG_INFO(Output, + "GetBlockMetadata for {} took {}. 
Found {} blocks", + BuildPartId, + NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), + OutBlockDescriptions.size()); + } } CalculateLocalChunkOrders(AbsoluteChunkOrders, @@ -7935,6 +8078,164 @@ TEST_CASE("buildstorageoperations.upload.multipart") } } +TEST_CASE("buildstorageoperations.partial.block.download" * doctest::skip(true)) +{ + const std::string OidcExecutableName = "OidcToken" ZEN_EXE_SUFFIX_LITERAL; + std::filesystem::path OidcTokenExePath = (GetRunningExecutablePath().parent_path() / OidcExecutableName).make_preferred(); + + HttpClientSettings ClientSettings{ + .LogCategory = "httpbuildsclient", + .AccessTokenProvider = + httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, "https://jupiter.devtools.epicgames.com", true, false, false), + .AssumeHttp2 = false, + .AllowResume = true, + .RetryCount = 0, + .Verbose = false}; + + HttpClient HttpClient("https://euc.jupiter.devtools.epicgames.com", ClientSettings); + + const std::string_view Namespace = "fortnite.oplog"; + const std::string_view Bucket = "fortnitegame.staged-build.fortnite-main.ps4-client"; + const Oid BuildId = Oid::FromHexString("09a76ea92ad301d4724fafad"); + + { + HttpClient::Response Response = HttpClient.Get(fmt::format("/api/v2/builds/{}/{}/{}", Namespace, Bucket, BuildId), + HttpClient::Accept(ZenContentType::kCbObject)); + CbValidateError ValidateResult = CbValidateError::None; + CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(Response.ResponsePayload), ValidateResult); + REQUIRE(ValidateResult == CbValidateError::None); + } + + std::vector BlockDescriptions; + { + CbObjectWriter Request; + + Request.BeginArray("blocks"sv); + { + Request.AddHash(IoHash::FromHexString("7c353ed782675a5e8f968e61e51fc797ecdc2882")); + } + Request.EndArray(); + + IoBuffer Payload = Request.Save().GetBuffer().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCbObject); + + HttpClient::Response BlockDescriptionsResponse = + 
HttpClient.Post(fmt::format("/api/v2/builds/{}/{}/{}/blocks/getBlockMetadata", Namespace, Bucket, BuildId), + Payload, + HttpClient::Accept(ZenContentType::kCbObject)); + REQUIRE(BlockDescriptionsResponse.IsSuccess()); + + CbValidateError ValidateResult = CbValidateError::None; + CbObject Object = ValidateAndReadCompactBinaryObject(IoBuffer(BlockDescriptionsResponse.ResponsePayload), ValidateResult); + REQUIRE(ValidateResult == CbValidateError::None); + + { + CbArrayView BlocksArray = Object["blocks"sv].AsArrayView(); + for (CbFieldView Block : BlocksArray) + { + ChunkBlockDescription Description = ParseChunkBlockDescription(Block.AsObjectView()); + BlockDescriptions.emplace_back(std::move(Description)); + } + } + } + + REQUIRE(!BlockDescriptions.empty()); + + const IoHash BlockHash = BlockDescriptions.back().BlockHash; + + const ChunkBlockDescription& BlockDescription = BlockDescriptions.front(); + REQUIRE(!BlockDescription.ChunkRawHashes.empty()); + REQUIRE(!BlockDescription.ChunkCompressedLengths.empty()); + + std::vector> ChunkOffsetAndSizes; + uint64_t Offset = gsl::narrow(CompressedBuffer::GetHeaderSizeForNoneEncoder() + BlockDescription.HeaderSize); + + for (uint32_t ChunkCompressedSize : BlockDescription.ChunkCompressedLengths) + { + ChunkOffsetAndSizes.push_back(std::make_pair(Offset, ChunkCompressedSize)); + Offset += ChunkCompressedSize; + } + + ScopedTemporaryDirectory SourceFolder; + + auto Validate = [&](std::span ChunkIndexesToFetch) { + std::vector> Ranges; + for (uint32_t ChunkIndex : ChunkIndexesToFetch) + { + Ranges.push_back(ChunkOffsetAndSizes[ChunkIndex]); + } + + HttpClient::KeyValueMap Headers; + if (!Ranges.empty()) + { + ExtendableStringBuilder<512> SB; + for (const std::pair& R : Ranges) + { + if (SB.Size() > 0) + { + SB << ", "; + } + SB << R.first << "-" << R.first + R.second - 1; + } + Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())}); + } + + HttpClient::Response GetBlobRangesResponse = HttpClient.Download( + 
fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect=false", Namespace, Bucket, BuildId, BlockHash), + SourceFolder.Path(), + Headers); + + REQUIRE(GetBlobRangesResponse.IsSuccess()); + MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); + + std::vector> PayloadRanges = GetBlobRangesResponse.GetRanges(Ranges); + if (PayloadRanges.empty()) + { + // We got the whole blob, use the ranges as is + PayloadRanges = Ranges; + } + + REQUIRE(PayloadRanges.size() == Ranges.size()); + + for (uint32_t RangeIndex = 0; RangeIndex < PayloadRanges.size(); RangeIndex++) + { + const std::pair& PayloadRange = PayloadRanges[RangeIndex]; + + CHECK_EQ(PayloadRange.second, Ranges[RangeIndex].second); + + IoBuffer ChunkPayload(GetBlobRangesResponse.ResponsePayload, PayloadRange.first, PayloadRange.second); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(ChunkPayload), RawHash, RawSize); + CHECK(CompressedChunk); + CHECK_EQ(RawHash, BlockDescription.ChunkRawHashes[ChunkIndexesToFetch[RangeIndex]]); + CHECK_EQ(RawSize, BlockDescription.ChunkRawLengths[ChunkIndexesToFetch[RangeIndex]]); + } + }; + + { + // Single + std::vector ChunkIndexesToFetch{uint32_t(BlockDescription.ChunkCompressedLengths.size() / 2)}; + Validate(ChunkIndexesToFetch); + } + { + // Many + std::vector ChunkIndexesToFetch; + for (uint32_t Index = 0; Index < BlockDescription.ChunkCompressedLengths.size() / 16; Index++) + { + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7)); + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 1)); + ChunkIndexesToFetch.push_back(uint32_t(BlockDescription.ChunkCompressedLengths.size() / 6 + Index * 7 + 3)); + } + Validate(ChunkIndexesToFetch); + } + + { + // First and last + std::vector ChunkIndexesToFetch{0, uint32_t(BlockDescription.ChunkCompressedLengths.size() - 1)}; + 
Validate(ChunkIndexesToFetch); + } +} TEST_SUITE_END(); void diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp index b249d7d52..d65f18b9a 100644 --- a/src/zenremotestore/builds/buildstorageutil.cpp +++ b/src/zenremotestore/builds/buildstorageutil.cpp @@ -251,7 +251,6 @@ GetBlockDescriptions(OperationLogOutput& Output, BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, - const Oid& BuildPartId, std::span BlockRawHashes, bool AttemptFallback, bool IsQuiet, @@ -259,13 +258,6 @@ GetBlockDescriptions(OperationLogOutput& Output, { using namespace std::literals; - if (!IsQuiet) - { - ZEN_OPERATION_LOG_INFO(Output, "Fetching metadata for {} blocks", BlockRawHashes.size()); - } - - Stopwatch GetBlockMetadataTimer; - std::vector UnorderedList; tsl::robin_map BlockDescriptionLookup; if (OptionalCacheStorage && !BlockRawHashes.empty()) @@ -355,15 +347,6 @@ GetBlockDescriptions(OperationLogOutput& Output, } } - if (!IsQuiet) - { - ZEN_OPERATION_LOG_INFO(Output, - "GetBlockMetadata for {} took {}. 
Found {} blocks", - BuildPartId, - NiceTimeSpanMs(GetBlockMetadataTimer.GetElapsedTimeMs()), - Result.size()); - } - if (Result.size() != BlockRawHashes.size()) { std::string ErrorDescription = diff --git a/src/zenremotestore/builds/filebuildstorage.cpp b/src/zenremotestore/builds/filebuildstorage.cpp index 55e69de61..2f4904449 100644 --- a/src/zenremotestore/builds/filebuildstorage.cpp +++ b/src/zenremotestore/builds/filebuildstorage.cpp @@ -432,6 +432,45 @@ public: return IoBuffer{}; } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span> Ranges) override + { + ZEN_TRACE_CPU("FileBuildStorage::GetBuildBlobRanges"); + ZEN_UNUSED(BuildId); + ZEN_ASSERT(!Ranges.empty()); + + uint64_t ReceivedBytes = 0; + uint64_t SentBytes = Ranges.size() * 2 * 8; + + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer, SentBytes, ReceivedBytes); }); + + BuildBlobRanges Result; + + const std::filesystem::path BlockPath = GetBlobPayloadPath(RawHash); + if (IsFile(BlockPath)) + { + BasicFile File(BlockPath, BasicFile::Mode::kRead); + + uint64_t RangeOffset = Ranges.front().first; + uint64_t RangeBytes = Ranges.back().first + Ranges.back().second - RangeOffset; + Result.PayloadBuffer = IoBufferBuilder::MakeFromFileHandle(File.Detach(), RangeOffset, RangeBytes); + + Result.Ranges.reserve(Ranges.size()); + + for (const std::pair& Range : Ranges) + { + Result.Ranges.push_back(std::make_pair(Range.first - RangeOffset, Range.second)); + } + ReceivedBytes = Result.PayloadBuffer.GetSize(); + } + return Result; + } + virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp index 23d0ddd4c..8e16da1a9 100644 --- 
a/src/zenremotestore/builds/jupiterbuildstorage.cpp +++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp @@ -21,7 +21,7 @@ namespace zen { using namespace std::literals; namespace { - void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix) + [[noreturn]] void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix) { int Error = Result.ErrorCode < (int)HttpResponseCode::Continue ? Result.ErrorCode : 0; HttpResponseCode Status = @@ -295,6 +295,26 @@ public: return std::move(GetBuildBlobResult.Response); } + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span> Ranges) override + { + ZEN_TRACE_CPU("Jupiter::GetBuildBlob"); + + Stopwatch ExecutionTimer; + auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); + CreateDirectories(m_TempFolderPath); + + BuildBlobRangesResult GetBuildBlobResult = + m_Session.GetBuildBlob(m_Namespace, m_Bucket, BuildId, RawHash, m_TempFolderPath, Ranges); + AddStatistic(GetBuildBlobResult); + if (!GetBuildBlobResult.Success) + { + ThrowFromJupiterResult(GetBuildBlobResult, "Failed fetching build blob ranges"sv); + } + return BuildBlobRanges{.PayloadBuffer = std::move(GetBuildBlobResult.Response), .Ranges = std::move(GetBuildBlobResult.Ranges)}; + } + virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index 3a4e6011d..9c3fe8a0b 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -608,40 +608,49 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span if (PartialBlockDownloadMode != EPartialBlockDownloadMode::Exact && BlockRanges.size() > 1) { - // TODO: Once we have support in our http client to request multiple ranges in one request this - // logic would need to change as the per-request 
overhead would go away + const uint64_t MaxRangeCountPerRequest = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed + ? m_Options.HostHighSpeedMaxRangeCountPerRequest + : m_Options.HostMaxRangeCountPerRequest; - const double LatencySec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed - ? m_Options.HostHighSpeedLatencySec - : m_Options.HostLatencySec; - if (LatencySec > 0) + ZEN_ASSERT(MaxRangeCountPerRequest != 0); + + if (MaxRangeCountPerRequest != (uint64_t)-1) { - const uint64_t BytesPerSec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed - ? m_Options.HostHighSpeedBytesPerSec - : m_Options.HostSpeedBytesPerSec; + const uint64_t ExtraRequestCount = BlockRanges.size() / MaxRangeCountPerRequest; - const double ExtraRequestTimeSec = (BlockRanges.size() - 1) * LatencySec; - const uint64_t ExtraRequestTimeBytes = uint64_t(ExtraRequestTimeSec * BytesPerSec); + const double LatencySec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed + ? m_Options.HostHighSpeedLatencySec + : m_Options.HostLatencySec; + if (LatencySec > 0) + { + const uint64_t BytesPerSec = PartialBlockDownloadMode == EPartialBlockDownloadMode::MultiRangeHighSpeed + ? 
m_Options.HostHighSpeedBytesPerSec + : m_Options.HostSpeedBytesPerSec; - const uint64_t FullRangeSize = - BlockRanges.back().RangeStart + BlockRanges.back().RangeLength - BlockRanges.front().RangeStart; + const double ExtraRequestTimeSec = ExtraRequestCount * LatencySec; + const uint64_t ExtraRequestTimeBytes = uint64_t(ExtraRequestTimeSec * BytesPerSec); - if (ExtraRequestTimeBytes + RequestedSize >= FullRangeSize) - { - BlockRanges = std::vector{MergeBlockRanges(BlockRanges)}; + const uint64_t FullRangeSize = + BlockRanges.back().RangeStart + BlockRanges.back().RangeLength - BlockRanges.front().RangeStart; - if (m_Options.IsVerbose) + if (ExtraRequestTimeBytes + RequestedSize >= FullRangeSize) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, - "Merging {} chunks ({}) from block {} ({}) to single request (extra bytes {})", - NeededBlock.ChunkIndexes.size(), - NiceBytes(RequestedSize), - BlockDescription.BlockHash, - NiceBytes(TotalBlockSize), - NiceBytes(BlockRanges.front().RangeLength - RequestedSize)); + BlockRanges = std::vector{MergeBlockRanges(BlockRanges)}; + + if (m_Options.IsVerbose) + { + ZEN_OPERATION_LOG_INFO( + m_LogOutput, + "Merging {} chunks ({}) from block {} ({}) to single request (extra bytes {})", + NeededBlock.ChunkIndexes.size(), + NiceBytes(RequestedSize), + BlockDescription.BlockHash, + NiceBytes(TotalBlockSize), + NiceBytes(BlockRanges.front().RangeLength - RequestedSize)); + } + + RequestedSize = BlockRanges.front().RangeLength; } - - RequestedSize = BlockRanges.front().RangeLength; } } } @@ -730,7 +739,7 @@ ChunkBlockAnalyser::CalculatePartialBlockDownloads(std::span ZEN_OPERATION_LOG_INFO(m_LogOutput, "Analysis of partial block requests saves download of {} out of {}, {:.1f}% of possible {} using {} extra " - "requests. Completed in {}", + "ranges. 
Completed in {}", NiceBytes(ActualSkippedSize), NiceBytes(AllBlocksTotalBlocksSize), PercentOfIdealPartialSkippedSize, diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h index 85dabc59f..ce3da41c1 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h @@ -53,15 +53,26 @@ public: std::function&& Transmitter, std::function&& OnSentBytes) = 0; - virtual IoBuffer GetBuildBlob(const Oid& BuildId, - const IoHash& RawHash, - uint64_t RangeOffset = 0, - uint64_t RangeBytes = (uint64_t)-1) = 0; + virtual IoBuffer GetBuildBlob(const Oid& BuildId, + const IoHash& RawHash, + uint64_t RangeOffset = 0, + uint64_t RangeBytes = (uint64_t)-1) = 0; + + static constexpr size_t MaxRangeCountPerRequest = 128u; + + struct BuildBlobRanges + { + IoBuffer PayloadBuffer; + std::vector> Ranges; + }; + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span> Ranges) = 0; virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t ChunkSize, std::function&& OnReceive, - std::function&& OnComplete) = 0; + std::function&& OnComplete) = 0; [[nodiscard]] virtual bool PutBlockMetadata(const Oid& BuildId, const IoHash& BlockRawHash, const CbObject& MetaData) = 0; virtual CbObject FindBlocks(const Oid& BuildId, uint64_t MaxBlockCount) = 0; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h index f25ce5b5e..67c93480b 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h @@ -37,6 +37,14 @@ public: const IoHash& RawHash, uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) = 0; + struct BuildBlobRanges + { + IoBuffer PayloadBuffer; + 
std::vector> Ranges; + }; + virtual BuildBlobRanges GetBuildBlobRanges(const Oid& BuildId, + const IoHash& RawHash, + std::span> Ranges) = 0; virtual void PutBlobMetadatas(const Oid& BuildId, std::span BlobHashes, std::span MetaDatas) = 0; virtual std::vector GetBlobMetadatas(const Oid& BuildId, std::span BlobHashes) = 0; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 31733569e..875b8593b 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -263,9 +263,14 @@ private: ParallelWork& Work, std::function&& OnDownloaded); - void DownloadPartialBlock(const ChunkBlockAnalyser::BlockRangeDescriptor BlockRange, - const BlobsExistsResult& ExistsResult, - std::function&& OnDownloaded); + void DownloadPartialBlock(std::span BlockRanges, + size_t BlockRangeIndex, + size_t BlockRangeCount, + const BlobsExistsResult& ExistsResult, + std::function> OffsetAndLengths)>&& OnDownloaded); std::vector WriteLocalChunkToCache(CloneQueryInterface* CloneQuery, const CopyChunkData& CopyData, diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h index 4b85d8f1e..764a24e61 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h @@ -45,7 +45,6 @@ std::vector GetBlockDescriptions(OperationLogOutput& Out BuildStorageBase& Storage, BuildStorageCache* OptionalCacheStorage, const Oid& BuildId, - const Oid& BuildPartId, std::span BlockRawHashes, bool AttemptFallback, bool IsQuiet, diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index 5a17ef79c..7aae1442e 100644 --- 
a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -82,12 +82,14 @@ class ChunkBlockAnalyser public: struct Options { - bool IsQuiet = false; - bool IsVerbose = false; - double HostLatencySec = -1.0; - double HostHighSpeedLatencySec = -1.0; - uint64_t HostSpeedBytesPerSec = (1u * 1024u * 1024u * 1024u) / 8u; // 1GBit - uint64_t HostHighSpeedBytesPerSec = (2u * 1024u * 1024u * 1024u) / 8u; // 2GBit + bool IsQuiet = false; + bool IsVerbose = false; + double HostLatencySec = -1.0; + double HostHighSpeedLatencySec = -1.0; + uint64_t HostSpeedBytesPerSec = (1u * 1024u * 1024u * 1024u) / 8u; // 1GBit + uint64_t HostHighSpeedBytesPerSec = (2u * 1024u * 1024u * 1024u) / 8u; // 2GBit + uint64_t HostMaxRangeCountPerRequest = (uint64_t)-1; + uint64_t HostHighSpeedMaxRangeCountPerRequest = (uint64_t)-1; // No limit }; ChunkBlockAnalyser(OperationLogOutput& LogOutput, std::span BlockDescriptions, const Options& Options); @@ -137,14 +139,15 @@ private: static constexpr uint16_t FullBlockRangePercentLimit = 98; - static constexpr BlockRangeLimit ForceMergeLimits[] = {{.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 1}, - {.SizePercent = 90, .MaxRangeCount = 4}, - {.SizePercent = 85, .MaxRangeCount = 16}, - {.SizePercent = 80, .MaxRangeCount = 32}, - {.SizePercent = 75, .MaxRangeCount = 48}, - {.SizePercent = 70, .MaxRangeCount = 64}, - {.SizePercent = 4, .MaxRangeCount = 82}, - {.SizePercent = 0, .MaxRangeCount = 96}}; + static constexpr BlockRangeLimit ForceMergeLimits[] = {{.SizePercent = FullBlockRangePercentLimit, .MaxRangeCount = 8}, + {.SizePercent = 90, .MaxRangeCount = 16}, + {.SizePercent = 85, .MaxRangeCount = 32}, + {.SizePercent = 80, .MaxRangeCount = 48}, + {.SizePercent = 75, .MaxRangeCount = 64}, + {.SizePercent = 70, .MaxRangeCount = 92}, + {.SizePercent = 50, .MaxRangeCount = 128}, + {.SizePercent = 4, .MaxRangeCount = 192}, + {.SizePercent = 0, .MaxRangeCount = 
256}}; BlockRangeDescriptor MergeBlockRanges(std::span Ranges); std::optional> MakeOptionalBlockRangeVector(uint64_t TotalBlockSize, diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h index eaf6962fd..8721bc37f 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h +++ b/src/zenremotestore/include/zenremotestore/jupiter/jupitersession.h @@ -56,6 +56,11 @@ struct FinalizeBuildPartResult : JupiterResult std::vector Needs; }; +struct BuildBlobRangesResult : JupiterResult +{ + std::vector> Ranges; +}; + /** * Context for performing Jupiter operations * @@ -135,6 +140,13 @@ public: uint64_t Offset = 0, uint64_t Size = (uint64_t)-1); + BuildBlobRangesResult GetBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + std::filesystem::path TempFolderPath, + std::span> Ranges); + JupiterResult PutMultipartBuildBlob(std::string_view Namespace, std::string_view BucketId, const Oid& BuildId, diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 152c02ee2..2cf10c664 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -84,6 +84,12 @@ public: std::vector HasBody; }; + struct LoadAttachmentRangesResult : public Result + { + IoBuffer Bytes; + std::vector> Ranges; + }; + struct RemoteStoreInfo { bool CreateBlocks; @@ -127,15 +133,21 @@ public: virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) = 0; virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) = 0; - struct AttachmentRange + enum ESourceMode { - uint64_t Offset = 0; - uint64_t Bytes = (uint64_t)-1; - - inline operator bool() const { return Offset != 0 || 
Bytes != (uint64_t)-1; } + kAny, + kCacheOnly, + kHostOnly }; - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) = 0; - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) = 0; + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode = ESourceMode::kAny) = 0; + + static constexpr size_t MaxRangeCountPerRequest = 128u; + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span> Ranges, + ESourceMode SourceMode = ESourceMode::kAny) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode = ESourceMode::kAny) = 0; virtual void Flush() = 0; }; diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index 1bc6564ce..52f9eb678 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -852,6 +852,71 @@ JupiterSession::GetBuildBlob(std::string_view Namespace, return detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv); } +BuildBlobRangesResult +JupiterSession::GetBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + std::filesystem::path TempFolderPath, + std::span> Ranges) +{ + HttpClient::KeyValueMap Headers; + if (!Ranges.empty()) + { + ExtendableStringBuilder<512> SB; + for (const std::pair& R : Ranges) + { + if (SB.Size() > 0) + { + SB << ", "; + } + SB << R.first << "-" << R.first + R.second - 1; + } + Headers.Entries.insert({"Range", fmt::format("bytes={}", SB.ToView())}); + } + std::string Url = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}?supportsRedirect={}", + Namespace, + BucketId, + BuildId, + Hash.ToHexString(), + m_AllowRedirect ? 
"true"sv : "false"sv); + + HttpClient::Response Response = m_HttpClient.Download(Url, TempFolderPath, Headers); + if (Response.StatusCode == HttpResponseCode::RangeNotSatisfiable && Ranges.size() > 1) + { + // Requests to Jupiter that is not served via nginx (content not stored locally in the file system) can not serve multi-range + // requests (asp.net limitation) This rejection is not implemented as of 2026-03-02, it is in the backlog (@joakim.lindqvist) + // If we encounter this error we fall back to a single range which covers all the requested ranges + uint64_t RangeStart = Ranges.front().first; + uint64_t RangeEnd = Ranges.back().first + Ranges.back().second - 1; + Headers.Entries.insert_or_assign("Range", fmt::format("bytes={}-{}", RangeStart, RangeEnd)); + Response = m_HttpClient.Download(Url, TempFolderPath, Headers); + } + if (Response.IsSuccess()) + { + // If we get a redirect to S3 or a non-Jupiter endpoint the content type will not be correct, validate it and set it + if (m_AllowRedirect && (Response.ResponsePayload.GetContentType() == HttpContentType::kBinary)) + { + IoHash ValidateRawHash; + uint64_t ValidateRawSize = 0; + if (!Headers.Entries.contains("Range")) + { + ZEN_ASSERT_SLOW(CompressedBuffer::ValidateCompressedHeader(Response.ResponsePayload, + ValidateRawHash, + ValidateRawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)); + ZEN_ASSERT_SLOW(ValidateRawHash == Hash); + ZEN_ASSERT_SLOW(ValidateRawSize > 0); + ZEN_UNUSED(ValidateRawHash, ValidateRawSize); + Response.ResponsePayload.SetContentType(ZenContentType::kCompressedBinary); + } + } + } + BuildBlobRangesResult Result = {detail::ConvertResponse(Response, "JupiterSession::GetBuildBlob"sv)}; + Result.Ranges = Response.GetRanges(Ranges); + return Result; +} + JupiterResult JupiterSession::PutBlockMetadata(std::string_view Namespace, std::string_view BucketId, diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp 
b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index c42373e4d..3400cdbf5 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -478,7 +478,6 @@ public: *m_BuildStorage, m_BuildCacheStorage.get(), m_BuildId, - m_OplogBuildPartId, BlockHashes, /*AttemptFallback*/ false, /*IsQuiet*/ false, @@ -549,7 +548,7 @@ public: return Result; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override { ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); @@ -559,25 +558,90 @@ public: try { - if (m_BuildCacheStorage) + if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) { - IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash, Range.Offset, Range.Bytes); + IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash); if (CachedBlob) { Result.Bytes = std::move(CachedBlob); } } - if (!Result.Bytes) + if (!Result.Bytes && SourceMode != ESourceMode::kCacheOnly) { - Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash, Range.Offset, Range.Bytes); + Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) { - if (!Range) + m_BuildCacheStorage->PutBuildBlob(m_BuildId, + RawHash, + Result.Bytes.GetContentType(), + CompositeBuffer(SharedBuffer(Result.Bytes))); + } + } + } + catch (const HttpClientError& Ex) + { + Result.ErrorCode = MakeErrorCode(Ex); + Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + RawHash, + Ex.what()); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed getting blob {}/{}/{}/{}/{}. 
Reason: '{}'", + m_BuildStorageHttp.GetBaseUri(), + m_Namespace, + m_Bucket, + m_BuildId, + RawHash, + Ex.what()); + } + + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span> Ranges, + ESourceMode SourceMode) override + { + LoadAttachmentRangesResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); + + try + { + if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) + { + BuildStorageCache::BuildBlobRanges BlobRanges = m_BuildCacheStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); + if (BlobRanges.PayloadBuffer) + { + Result.Bytes = std::move(BlobRanges.PayloadBuffer); + Result.Ranges = std::move(BlobRanges.Ranges); + } + } + if (!Result.Bytes && SourceMode != ESourceMode::kCacheOnly) + { + BuildStorageBase::BuildBlobRanges BlobRanges = m_BuildStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); + if (BlobRanges.PayloadBuffer) + { + Result.Bytes = std::move(BlobRanges.PayloadBuffer); + Result.Ranges = std::move(BlobRanges.Ranges); + + if (Result.Ranges.empty()) { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - RawHash, - Result.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(Result.Bytes))); + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3/Replicated + // Upload to cache (if enabled) + if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) + { + m_BuildCacheStorage->PutBuildBlob(m_BuildId, + RawHash, + Result.Bytes.GetContentType(), + CompositeBuffer(SharedBuffer(Result.Bytes))); + } } } } @@ -585,28 +649,32 @@ public: catch (const HttpClientError& Ex) { Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. 
Reason: '{}'", + Ranges.size(), m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, m_BuildId, + RawHash, Ex.what()); } catch (const std::exception& Ex) { Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Failed listing known blocks for {}/{}/{}/{}. Reason: '{}'", + Result.Reason = fmt::format("Failed getting {} ranges for blob {}/{}/{}/{}/{}. Reason: '{}'", + Ranges.size(), m_BuildStorageHttp.GetBaseUri(), m_Namespace, m_Bucket, m_BuildId, + RawHash, Ex.what()); } return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override { LoadAttachmentsResult Result; Stopwatch Timer; @@ -614,7 +682,7 @@ public: std::vector AttachmentsLeftToFind = RawHashes; - if (m_BuildCacheStorage) + if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) { std::vector ExistCheck = m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes); if (ExistCheck.size() == RawHashes.size()) @@ -648,7 +716,7 @@ public: for (const IoHash& Hash : AttachmentsLeftToFind) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; diff --git a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp index ec7fb7bbc..f950fd46c 100644 --- a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp @@ -228,28 +228,62 @@ public: return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) 
override { - Stopwatch Timer; - LoadAttachmentResult Result; - std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); - if (!IsFile(ChunkPath)) + Stopwatch Timer; + LoadAttachmentResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); - Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - return Result; + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!IsFile(ChunkPath)) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); + Result.Reason = + fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; + } + { + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); + Result.Bytes = ChunkFile.ReadAll(); + } } + AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span> Ranges, + ESourceMode SourceMode) override + { + Stopwatch Timer; + LoadAttachmentRangesResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - BasicFile ChunkFile; - ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); - if (Range) + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!IsFile(ChunkPath)) { - Result.Bytes = ChunkFile.ReadRange(Range.Offset, Range.Bytes); + Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); + Result.Reason = + fmt::format("Failed loading oplog attachment from '{}'. 
Reason: 'The file does not exist'", ChunkPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; } - else { - Result.Bytes = ChunkFile.ReadAll(); + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); + + uint64_t Start = Ranges.front().first; + uint64_t Length = Ranges.back().first + Ranges.back().second - Ranges.front().first; + + Result.Bytes = ChunkFile.ReadRange(Start, Length); + Result.Ranges.reserve(Ranges.size()); + for (const std::pair& Range : Ranges) + { + Result.Ranges.push_back(std::make_pair(Range.first - Start, Range.second)); + } } } AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); @@ -257,13 +291,13 @@ public: return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override { Stopwatch Timer; LoadAttachmentsResult Result; for (const IoHash& Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); if (ChunkResult.ErrorCode) { ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; diff --git a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp index f8179831c..514484f30 100644 --- a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp @@ -223,34 +223,62 @@ public: return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override { - JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); - 
JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); - AddStats(GetResult); - - LoadAttachmentResult Result{ConvertResult(GetResult), std::move(GetResult.Response)}; - if (GetResult.ErrorCode) + LoadAttachmentResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", - m_JupiterClient->ServiceUrl(), - m_Namespace, - RawHash, - Result.Reason); + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + AddStats(GetResult); + + Result = {ConvertResult(GetResult), std::move(GetResult.Response)}; + if (GetResult.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", + m_JupiterClient->ServiceUrl(), + m_Namespace, + RawHash, + Result.Reason); + } } - if (!Result.ErrorCode && Range) + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span> Ranges, + ESourceMode SourceMode) override + { + LoadAttachmentRangesResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - Result.Bytes = IoBuffer(Result.Bytes, Range.Offset, Range.Bytes); + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + AddStats(GetResult); + + Result = LoadAttachmentRangesResult{ConvertResult(GetResult), std::move(GetResult.Response)}; + if (GetResult.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. 
Reason: '{}'", + m_JupiterClient->ServiceUrl(), + m_Namespace, + RawHash, + Result.Reason); + } + else + { + Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); + } } return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override { LoadAttachmentsResult Result; for (const IoHash& Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, {}); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 2a9da6f58..1882f599a 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -339,9 +339,10 @@ namespace remotestore_impl { uint64_t ChunkSize = It.second.GetCompressedSize(); Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); } - ZEN_INFO("Loaded {} bulk attachments in {}", - Chunks.size(), - NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000))); + remotestore_impl::ReportMessage(OptionalContext, + fmt::format("Loaded {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000)))); if (RemoteResult.IsError()) { return; @@ -446,7 +447,7 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash, {}); + RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); if (BlockResult.ErrorCode) { ReportMessage(OptionalContext, @@ -506,50 +507,100 @@ namespace remotestore_impl { IoHash RawHash; uint64_t RawSize; CompressedBuffer Compressed = 
CompressedBuffer::FromCompressed(SharedBuffer(Bytes), RawHash, RawSize); + + std::string ErrorString; + if (!Compressed) { - if (RetriesLeft > 0) + ErrorString = + fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash); + } + else if (RawHash != BlockHash) + { + ErrorString = fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash); + } + else if (CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); !BlockPayload) + { + ErrorString = fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash); + } + else + { + uint64_t PotentialSize = 0; + uint64_t UsedSize = 0; + uint64_t BlockSize = BlockPayload.GetSize(); + + uint64_t BlockHeaderSize = 0; + + bool StoreChunksOK = IterateChunkBlock( + BlockPayload.Flatten(), + [&AllNeededPartialChunkHashesLookup, + &ChunkDownloadedFlags, + &WriteAttachmentBuffers, + &WriteRawHashes, + &Info, + &PotentialSize](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(AttachmentRawHash); + if (ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) + { + WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); + IoHash RawHash; + uint64_t RawSize; + ZEN_ASSERT(CompressedBuffer::ValidateCompressedHeader( + WriteAttachmentBuffers.back(), + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)); + ZEN_ASSERT(RawHash == AttachmentRawHash); + WriteRawHashes.emplace_back(AttachmentRawHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + } + }, + BlockHeaderSize); + + if (!StoreChunksOK) { - ReportMessage( - OptionalContext, - fmt::format( - "Block attachment {} is malformed, can't parse as compressed binary, retrying download", - BlockHash)); - return DownloadAndSaveBlock(ChunkStore, - RemoteStore, 
- IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, - AttachmentsDownloadLatch, - AttachmentsWriteLatch, - RemoteResult, - Info, - LoadAttachmentsTimer, - DownloadStartMS, - BlockHash, - AllNeededPartialChunkHashesLookup, - ChunkDownloadedFlags, - RetriesLeft - 1); + ErrorString = fmt::format("Invalid format for block {}", BlockHash); + } + else + { + if (!WriteAttachmentBuffers.empty()) + { + std::vector Results = + ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + for (size_t Index = 0; Index < Results.size(); Index++) + { + const CidStore::InsertResult& Result = Results[Index]; + if (Result.New) + { + Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); + Info.AttachmentsStored.fetch_add(1); + UsedSize += WriteAttachmentBuffers[Index].GetSize(); + } + } + if (UsedSize < BlockSize) + { + ZEN_DEBUG("Used {} (skipping {}) out of {} for block {} ({} %) (use of matching {}%)", + NiceBytes(UsedSize), + NiceBytes(BlockSize - UsedSize), + NiceBytes(BlockSize), + BlockHash, + (100 * UsedSize) / BlockSize, + PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); + } + } } - ReportMessage( - OptionalContext, - fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash)); - RemoteResult.SetError( - gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} is malformed, can't parse as compressed binary", BlockHash), - {}); - return; } - CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); - if (!BlockPayload) + + if (!ErrorString.empty()) { if (RetriesLeft > 0) { - ReportMessage( - OptionalContext, - fmt::format("Block attachment {} is malformed, can't decompress payload, retrying download", - BlockHash)); + ReportMessage(OptionalContext, fmt::format("{}, retrying download", ErrorString)); + return DownloadAndSaveBlock(ChunkStore, RemoteStore, IgnoreMissingAttachments, @@ -567,94 +618,12 @@ namespace remotestore_impl { ChunkDownloadedFlags, RetriesLeft - 1); } - ReportMessage(OptionalContext, - fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash)); - RemoteResult.SetError( - gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} is malformed, can't decompress payload", BlockHash), - {}); - return; - } - if (RawHash != BlockHash) - { - ReportMessage(OptionalContext, - fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash)); - RemoteResult.SetError( - gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Block attachment {} has mismatching raw hash ({})", BlockHash, RawHash), - {}); - return; - } - - uint64_t PotentialSize = 0; - uint64_t UsedSize = 0; - uint64_t BlockSize = BlockPayload.GetSize(); - - uint64_t BlockHeaderSize = 0; - - bool StoreChunksOK = IterateChunkBlock( - BlockPayload.Flatten(), - [&AllNeededPartialChunkHashesLookup, - &ChunkDownloadedFlags, - &WriteAttachmentBuffers, - &WriteRawHashes, - &Info, - &PotentialSize](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { - auto 
ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(AttachmentRawHash); - if (ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) - { - bool Expected = false; - if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) - { - WriteAttachmentBuffers.emplace_back(Chunk.GetCompressed().Flatten().AsIoBuffer()); - IoHash RawHash; - uint64_t RawSize; - ZEN_ASSERT( - CompressedBuffer::ValidateCompressedHeader(WriteAttachmentBuffers.back(), - RawHash, - RawSize, - /*OutOptionalTotalCompressedSize*/ nullptr)); - ZEN_ASSERT(RawHash == AttachmentRawHash); - WriteRawHashes.emplace_back(AttachmentRawHash); - PotentialSize += WriteAttachmentBuffers.back().GetSize(); - } - } - }, - BlockHeaderSize); - - if (!StoreChunksOK) - { - ReportMessage(OptionalContext, - fmt::format("Block attachment {} has invalid format ({}): {}", - BlockHash, - RemoteResult.GetError(), - RemoteResult.GetErrorReason())); - RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Invalid format for block {}", BlockHash), - {}); - return; - } - - if (!WriteAttachmentBuffers.empty()) - { - auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); - for (size_t Index = 0; Index < Results.size(); Index++) + else { - const auto& Result = Results[Index]; - if (Result.New) - { - Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); - Info.AttachmentsStored.fetch_add(1); - UsedSize += WriteAttachmentBuffers[Index].GetSize(); - } + ReportMessage(OptionalContext, ErrorString); + RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), ErrorString, {}); + return; } - ZEN_DEBUG("Used {} (matching {}) out of {} for block {} ({} %) (use of matching {}%)", - NiceBytes(UsedSize), - NiceBytes(PotentialSize), - NiceBytes(BlockSize), - BlockHash, - (100 * UsedSize) / BlockSize, - PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); } } catch (const std::exception& Ex) @@ -676,6 +645,119 @@ namespace remotestore_impl { WorkerThreadPool::EMode::EnableBacklog); }; + bool DownloadPartialBlock(RemoteProjectStore& RemoteStore, + bool IgnoreMissingAttachments, + JobContext* OptionalContext, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + double& DownloadTimeSeconds, + const ChunkBlockDescription& BlockDescription, + bool BlockExistsInCache, + std::span BlockRangeDescriptors, + size_t BlockRangeIndexStart, + size_t BlockRangeCount, + std::function> OffsetAndLengths)>&& OnDownloaded) + { + std::vector> Ranges; + Ranges.reserve(BlockRangeDescriptors.size()); + for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; BlockRangeIndex++) + { + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRangeDescriptors[BlockRangeIndex]; + Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); + } + + if (BlockExistsInCache) + { + RemoteProjectStore::LoadAttachmentRangesResult BlockResult = + RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, Ranges, RemoteProjectStore::ESourceMode::kCacheOnly); + DownloadTimeSeconds += BlockResult.ElapsedSeconds; + if (RemoteResult.IsError()) + { + return false; + } + if (!BlockResult.ErrorCode && BlockResult.Bytes) + { + if (BlockResult.Ranges.size() != Ranges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + Ranges.size(), + BlockDescription.BlockHash, + BlockResult.Ranges.size())); + } + OnDownloaded(std::move(BlockResult.Bytes), BlockRangeIndexStart, BlockResult.Ranges); + return true; + } + } + + const size_t MaxRangesPerRequestToJupiter = RemoteProjectStore::MaxRangeCountPerRequest; + + size_t SubBlockRangeCount = BlockRangeCount; + size_t SubRangeCountComplete = 0; + std::span> RangesSpan(Ranges); + while (SubRangeCountComplete < SubBlockRangeCount) + { + if 
(RemoteResult.IsError()) + { + break; + } + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, MaxRangesPerRequestToJupiter); + size_t SubRangeStartIndex = BlockRangeIndexStart + SubRangeCountComplete; + + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + RemoteProjectStore::LoadAttachmentRangesResult BlockResult = + RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, SubRanges, RemoteProjectStore::ESourceMode::kHostOnly); + DownloadTimeSeconds += BlockResult.ElapsedSeconds; + if (RemoteResult.IsError()) + { + return false; + } + if (BlockResult.ErrorCode || !BlockResult.Bytes) + { + ReportMessage(OptionalContext, + fmt::format("Failed to download {} ranges from block attachment '{}' ({}): {}", + SubRanges.size(), + BlockDescription.BlockHash, + BlockResult.ErrorCode, + BlockResult.Reason)); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + return false; + } + } + else + { + if (BlockResult.Ranges.empty()) + { + // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 + // Use the whole payload for the remaining ranges + SubRangeCount = Ranges.size() - SubRangeCountComplete; + OnDownloaded(std::move(BlockResult.Bytes), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount)); + } + else + { + if (BlockResult.Ranges.size() != SubRanges.size()) + { + throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", + SubRanges.size(), + BlockDescription.BlockHash, + BlockResult.Ranges.size())); + } + OnDownloaded(std::move(BlockResult.Bytes), SubRangeStartIndex, BlockResult.Ranges); + } + } + + SubRangeCountComplete += SubRangeCount; + } + return true; + } + void DownloadAndSavePartialBlock(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, bool IgnoreMissingAttachments, @@ -689,6 +771,7 @@ namespace 
remotestore_impl { Stopwatch& LoadAttachmentsTimer, std::atomic_uint64_t& DownloadStartMS, const ChunkBlockDescription& BlockDescription, + bool BlockExistsInCache, std::span BlockRangeDescriptors, size_t BlockRangeIndexStart, size_t BlockRangeCount, @@ -710,13 +793,14 @@ namespace remotestore_impl { &DownloadStartMS, IgnoreMissingAttachments, OptionalContext, - RetriesLeft, BlockDescription, + BlockExistsInCache, BlockRangeDescriptors, BlockRangeIndexStart, BlockRangeCount, &AllNeededPartialChunkHashesLookup, - ChunkDownloadedFlags]() { + ChunkDownloadedFlags, + RetriesLeft]() { ZEN_TRACE_CPU("DownloadBlockRanges"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -728,230 +812,240 @@ namespace remotestore_impl { double DownloadElapsedSeconds = 0; uint64_t DownloadedBytes = 0; - for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; - BlockRangeIndex++) - { - if (RemoteResult.IsError()) - { - return; - } - - const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = BlockRangeDescriptors[BlockRangeIndex]; + bool Success = DownloadPartialBlock( + RemoteStore, + IgnoreMissingAttachments, + OptionalContext, + RemoteResult, + Info, + DownloadElapsedSeconds, + BlockDescription, + BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeIndexStart, + BlockRangeCount, + [&](IoBuffer&& Buffer, + size_t BlockRangeStartIndex, + std::span> OffsetAndLengths) { + uint64_t BlockPartSize = Buffer.GetSize(); + DownloadedBytes += BlockPartSize; + + Info.AttachmentBlockRangeBytesDownloaded.fetch_add(BlockPartSize); + Info.AttachmentBlocksRangesDownloaded++; + + AttachmentsWriteLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&AttachmentsWriteLatch, + &ChunkStore, + &RemoteStore, + &NetworkWorkerPool, + &WorkerPool, + &AttachmentsDownloadLatch, + &RemoteResult, + &Info, + &LoadAttachmentsTimer, + &DownloadStartMS, + IgnoreMissingAttachments, + OptionalContext, + BlockDescription, + 
BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeStartIndex, + &AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + RetriesLeft, + BlockPayload = std::move(Buffer), + OffsetAndLengths = + std::vector>(OffsetAndLengths.begin(), OffsetAndLengths.end())]() { + auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); + try + { + ZEN_ASSERT(BlockPayload.Size() > 0); - RemoteProjectStore::LoadAttachmentResult BlockResult = - RemoteStore.LoadAttachment(BlockDescription.BlockHash, - {.Offset = BlockRange.RangeStart, .Bytes = BlockRange.RangeLength}); - if (BlockResult.ErrorCode) - { - ReportMessage(OptionalContext, - fmt::format("Failed to download block attachment '{}' range {},{} ({}): {}", - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeLength, - BlockResult.ErrorCode, - BlockResult.Reason)); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) - { - RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); - } - return; - } - if (RemoteResult.IsError()) - { - return; - } - uint64_t BlockPartSize = BlockResult.Bytes.GetSize(); - if (BlockPartSize != BlockRange.RangeLength) - { - std::string ErrorString = - fmt::format("Failed to download block attachment '{}' range {},{}, got {} bytes ({}): {}", - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeLength, - BlockPartSize, - RemoteResult.GetError(), - RemoteResult.GetErrorReason()); - - ReportMessage(OptionalContext, ErrorString); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) - { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), - "Mismatching block part range received", - ErrorString); - } - return; - } - Info.AttachmentBlocksRangesDownloaded.fetch_add(1); + size_t RangeCount = OffsetAndLengths.size(); + for (size_t RangeOffset = 0; RangeOffset < RangeCount; RangeOffset++) + { + if (RemoteResult.IsError()) + { + return; + } - 
DownloadElapsedSeconds += BlockResult.ElapsedSeconds; - DownloadedBytes += BlockPartSize; + const ChunkBlockAnalyser::BlockRangeDescriptor& BlockRange = + BlockRangeDescriptors[BlockRangeStartIndex + RangeOffset]; + const std::pair& OffsetAndLength = OffsetAndLengths[RangeOffset]; + IoBuffer BlockRangeBuffer(BlockPayload, OffsetAndLength.first, OffsetAndLength.second); - Info.AttachmentBlockRangeBytesDownloaded.fetch_add(BlockPartSize); + std::vector WriteAttachmentBuffers; + std::vector WriteRawHashes; - AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&AttachmentsDownloadLatch, - &AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, - &RemoteResult, - &Info, - &LoadAttachmentsTimer, - &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, - RetriesLeft, - BlockDescription, - BlockRange, - &AllNeededPartialChunkHashesLookup, - ChunkDownloadedFlags, - BlockPayload = std::move(BlockResult.Bytes)]() { - auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); - if (RemoteResult.IsError()) - { - return; - } - try - { - ZEN_ASSERT(BlockPayload.Size() > 0); - std::vector WriteAttachmentBuffers; - std::vector WriteRawHashes; + uint64_t PotentialSize = 0; + uint64_t UsedSize = 0; + uint64_t BlockPartSize = BlockRangeBuffer.GetSize(); - uint64_t PotentialSize = 0; - uint64_t UsedSize = 0; - uint64_t BlockPartSize = BlockPayload.GetSize(); + uint32_t OffsetInBlock = 0; + for (uint32_t ChunkBlockIndex = BlockRange.ChunkBlockIndexStart; + ChunkBlockIndex < BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount; + ChunkBlockIndex++) + { + if (RemoteResult.IsError()) + { + break; + } - uint32_t OffsetInBlock = 0; - for (uint32_t ChunkBlockIndex = BlockRange.ChunkBlockIndexStart; - ChunkBlockIndex < BlockRange.ChunkBlockIndexStart + BlockRange.ChunkBlockIndexCount; - ChunkBlockIndex++) - { - const uint32_t ChunkCompressedSize = 
BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; - const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; + const uint32_t ChunkCompressedSize = + BlockDescription.ChunkCompressedLengths[ChunkBlockIndex]; + const IoHash& ChunkHash = BlockDescription.ChunkRawHashes[ChunkBlockIndex]; - if (auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(ChunkHash); - ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) - { - bool Expected = false; - if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, true)) - { - IoHash VerifyChunkHash; - uint64_t VerifyChunkSize; - CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed( - SharedBuffer(IoBuffer(BlockPayload, OffsetInBlock, ChunkCompressedSize)), - VerifyChunkHash, - VerifyChunkSize); - if (!CompressedChunk) + if (auto ChunkIndexIt = AllNeededPartialChunkHashesLookup.find(ChunkHash); + ChunkIndexIt != AllNeededPartialChunkHashesLookup.end()) { - std::string ErrorString = fmt::format( - "Chunk at {},{} in block attachment '{}' is not a valid compressed buffer", - OffsetInBlock, - ChunkCompressedSize, - BlockDescription.BlockHash); - ReportMessage(OptionalContext, ErrorString); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!ChunkDownloadedFlags[ChunkIndexIt->second]) { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), - "Malformed chunk block", - ErrorString); + IoHash VerifyChunkHash; + uint64_t VerifyChunkSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed( + SharedBuffer(IoBuffer(BlockRangeBuffer, OffsetInBlock, ChunkCompressedSize)), + VerifyChunkHash, + VerifyChunkSize); + + std::string ErrorString; + + if (!CompressedChunk) + { + ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' is not a valid compressed buffer", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash); + } + else if (VerifyChunkHash != ChunkHash) + { + ErrorString = 
fmt::format( + "Chunk at {},{} in block attachment '{}' has mismatching hash, expected " + "{}, got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + ChunkHash, + VerifyChunkHash); + } + else if (VerifyChunkSize != BlockDescription.ChunkRawLengths[ChunkBlockIndex]) + { + ErrorString = fmt::format( + "Chunk at {},{} in block attachment '{}' has mismatching raw size, " + "expected {}, " + "got {}", + OffsetInBlock, + ChunkCompressedSize, + BlockDescription.BlockHash, + BlockDescription.ChunkRawLengths[ChunkBlockIndex], + VerifyChunkSize); + } + + if (!ErrorString.empty()) + { + if (RetriesLeft > 0) + { + ReportMessage(OptionalContext, + fmt::format("{}, retrying download", ErrorString)); + return DownloadAndSavePartialBlock(ChunkStore, + RemoteStore, + IgnoreMissingAttachments, + OptionalContext, + NetworkWorkerPool, + WorkerPool, + AttachmentsDownloadLatch, + AttachmentsWriteLatch, + RemoteResult, + Info, + LoadAttachmentsTimer, + DownloadStartMS, + BlockDescription, + BlockExistsInCache, + BlockRangeDescriptors, + BlockRangeStartIndex, + RangeCount, + AllNeededPartialChunkHashesLookup, + ChunkDownloadedFlags, + RetriesLeft - 1); + } + + ReportMessage(OptionalContext, ErrorString); + Info.MissingAttachmentCount.fetch_add(1); + if (!IgnoreMissingAttachments) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), + "Malformed chunk block", + ErrorString); + } + } + else + { + bool Expected = false; + if (ChunkDownloadedFlags[ChunkIndexIt->second].compare_exchange_strong(Expected, + true)) + { + WriteAttachmentBuffers.emplace_back( + CompressedChunk.GetCompressed().Flatten().AsIoBuffer()); + WriteRawHashes.emplace_back(ChunkHash); + PotentialSize += WriteAttachmentBuffers.back().GetSize(); + } + } } - continue; } - if (VerifyChunkHash != ChunkHash) + OffsetInBlock += ChunkCompressedSize; + } + + if (!WriteAttachmentBuffers.empty()) + { + std::vector Results = + ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + for 
(size_t Index = 0; Index < Results.size(); Index++) { - std::string ErrorString = fmt::format( - "Chunk at {},{} in block attachment '{}' has mismatching hash, expected {}, got {}", - OffsetInBlock, - ChunkCompressedSize, - BlockDescription.BlockHash, - ChunkHash, - VerifyChunkHash); - ReportMessage(OptionalContext, ErrorString); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + const CidStore::InsertResult& Result = Results[Index]; + if (Result.New) { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), - "Malformed chunk block", - ErrorString); + Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); + Info.AttachmentsStored.fetch_add(1); + UsedSize += WriteAttachmentBuffers[Index].GetSize(); } - continue; } - if (VerifyChunkSize != BlockDescription.ChunkRawLengths[ChunkBlockIndex]) + if (UsedSize < BlockPartSize) { - std::string ErrorString = fmt::format( - "Chunk at {},{} in block attachment '{}' has mismatching raw size, expected {}, " - "got {}", - OffsetInBlock, - ChunkCompressedSize, + ZEN_DEBUG( + "Used {} (skipping {}) out of {} for block {} range {}, {} ({} %) (use of matching " + "{}%)", + NiceBytes(UsedSize), + NiceBytes(BlockPartSize - UsedSize), + NiceBytes(BlockPartSize), BlockDescription.BlockHash, - BlockDescription.ChunkRawLengths[ChunkBlockIndex], - VerifyChunkSize); - ReportMessage(OptionalContext, ErrorString); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) - { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), - "Malformed chunk block", - ErrorString); - } - continue; + BlockRange.RangeStart, + BlockRange.RangeLength, + (100 * UsedSize) / BlockPartSize, + PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); } - - WriteAttachmentBuffers.emplace_back(CompressedChunk.GetCompressed().Flatten().AsIoBuffer()); - WriteRawHashes.emplace_back(ChunkHash); - PotentialSize += WriteAttachmentBuffers.back().GetSize(); } } - OffsetInBlock += ChunkCompressedSize; } - - if (!WriteAttachmentBuffers.empty()) + catch (const std::exception& Ex) { - auto Results = ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); - for (size_t Index = 0; Index < Results.size(); Index++) - { - const auto& Result = Results[Index]; - if (Result.New) - { - Info.AttachmentBytesStored.fetch_add(WriteAttachmentBuffers[Index].GetSize()); - Info.AttachmentsStored.fetch_add(1); - UsedSize += WriteAttachmentBuffers[Index].GetSize(); - } - } - ZEN_DEBUG("Used {} (matching {}) out of {} for block {} range {}, {} ({} %) (use of matching {}%)", - NiceBytes(UsedSize), - NiceBytes(PotentialSize), - NiceBytes(BlockPartSize), - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeLength, - (100 * UsedSize) / BlockPartSize, - PotentialSize > 0 ? 
(UsedSize * 100) / PotentialSize : 0); + RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), + fmt::format("Failed saving {} ranges from block attachment {}", + OffsetAndLengths.size(), + BlockDescription.BlockHash), + Ex.what()); } - } - catch (const std::exception& Ex) - { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Failed save block attachment {} range {}, {}", - BlockDescription.BlockHash, - BlockRange.RangeStart, - BlockRange.RangeLength), - Ex.what()); - } - }, - WorkerThreadPool::EMode::EnableBacklog); + }, + WorkerThreadPool::EMode::EnableBacklog); + }); + if (Success) + { + ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})", + BlockRangeCount, + BlockDescription.BlockHash, + NiceTimeSpanMs(static_cast(DownloadElapsedSeconds * 1000)), + NiceBytes(DownloadedBytes)); } - - ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})", - BlockRangeCount, - BlockDescription.BlockHash, - NiceTimeSpanMs(static_cast(DownloadElapsedSeconds * 1000)), - NiceBytes(DownloadedBytes)); } catch (const std::exception& Ex) { @@ -1002,7 +1096,7 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash, {}); + RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); if (AttachmentResult.ErrorCode) { ReportMessage(OptionalContext, @@ -3115,6 +3209,12 @@ ParseOplogContainer( std::unordered_set NeededAttachments; { CbArrayView OpsArray = OutOplogSection["ops"sv].AsArrayView(); + + size_t OpCount = OpsArray.Num(); + size_t OpsCompleteCount = 0; + + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Scanning {} ops for attachments", OpCount)); + for (CbFieldView OpEntry : OpsArray) { OpEntry.IterateAttachments([&](CbFieldView FieldView) { 
NeededAttachments.insert(FieldView.AsAttachment()); }); @@ -3124,6 +3224,16 @@ ParseOplogContainer( .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, .Reason = "Operation cancelled"}; } + OpsCompleteCount++; + if ((OpsCompleteCount & 4095) == 0) + { + remotestore_impl::ReportProgress( + OptionalContext, + "Scanning oplog"sv, + fmt::format("{} attachments found, {} ops remaining...", NeededAttachments.size(), OpCount - OpsCompleteCount), + OpCount, + OpCount - OpsCompleteCount); + } } } { @@ -3151,13 +3261,27 @@ ParseOplogContainer( { ChunkedInfo Chunked = ReadChunkedInfo(ChunkedFileView); + size_t NeededChunkAttachmentCount = 0; + OnReferencedAttachments(Chunked.ChunkHashes); - NeededAttachments.insert(Chunked.ChunkHashes.begin(), Chunked.ChunkHashes.end()); + for (const IoHash& ChunkHash : Chunked.ChunkHashes) + { + if (!HasAttachment(ChunkHash)) + { + if (NeededAttachments.insert(ChunkHash).second) + { + NeededChunkAttachmentCount++; + } + } + } OnChunkedAttachment(Chunked); - ZEN_INFO("Requesting chunked attachment '{}' ({}) built from {} chunks", - Chunked.RawHash, - NiceBytes(Chunked.RawSize), - Chunked.ChunkHashes.size()); + + remotestore_impl::ReportMessage(OptionalContext, + fmt::format("Requesting chunked attachment '{}' ({}) built from {} chunks, need {} chunks", + Chunked.RawHash, + NiceBytes(Chunked.RawSize), + Chunked.ChunkHashes.size(), + NeededChunkAttachmentCount)); } } if (remotestore_impl::IsCancelled(OptionalContext)) @@ -3490,8 +3614,16 @@ LoadOplog(CidStore& ChunkStore, std::vector DownloadedViaLegacyChunkFlag(AllNeededChunkHashes.size(), false); ChunkBlockAnalyser::BlockResult PartialBlocksResult; + remotestore_impl::ReportMessage(OptionalContext, fmt::format("Fetching descriptions for {} blocks", BlockHashes.size())); + RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions = RemoteStore.GetBlockDescriptions(BlockHashes); - std::vector BlocksWithDescription; + + remotestore_impl::ReportMessage(OptionalContext, + 
fmt::format("GetBlockDescriptions took {}. Found {} blocks", + NiceTimeSpanMs(uint64_t(BlockDescriptions.ElapsedSeconds * 1000)), + BlockDescriptions.Blocks.size())); + + std::vector BlocksWithDescription; BlocksWithDescription.reserve(BlockDescriptions.Blocks.size()); for (const ChunkBlockDescription& BlockDescription : BlockDescriptions.Blocks) { @@ -3547,6 +3679,7 @@ LoadOplog(CidStore& ChunkStore, if (!AllNeededChunkHashes.empty()) { std::vector PartialBlockDownloadModes; + std::vector BlockExistsInCache; if (PartialBlockRequestMode == EPartialBlockRequestMode::Off) { @@ -3558,21 +3691,25 @@ LoadOplog(CidStore& ChunkStore, RemoteStore.AttachmentExistsInCache(BlocksWithDescription); if (CacheExistsResult.ErrorCode != 0 || CacheExistsResult.HasBody.size() != BlocksWithDescription.size()) { - CacheExistsResult.HasBody.resize(BlocksWithDescription.size(), false); + BlockExistsInCache.resize(BlocksWithDescription.size(), false); + } + else + { + BlockExistsInCache = std::move(CacheExistsResult.HasBody); } PartialBlockDownloadModes.reserve(BlocksWithDescription.size()); - for (bool ExistsInCache : CacheExistsResult.HasBody) + for (bool ExistsInCache : BlockExistsInCache) { if (PartialBlockRequestMode == EPartialBlockRequestMode::All) { - PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); } else if (PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) { - PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + PartialBlockDownloadModes.push_back(ExistsInCache ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } else if (PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) @@ -3584,13 +3721,14 @@ LoadOplog(CidStore& ChunkStore, } ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size()); - - ChunkBlockAnalyser PartialAnalyser(*LogOutput, - BlockDescriptions.Blocks, - ChunkBlockAnalyser::Options{.IsQuiet = false, - .IsVerbose = false, - .HostLatencySec = HostLatencySec, - .HostHighSpeedLatencySec = CacheLatencySec}); + ChunkBlockAnalyser PartialAnalyser( + *LogOutput, + BlockDescriptions.Blocks, + ChunkBlockAnalyser::Options{.IsQuiet = false, + .IsVerbose = false, + .HostLatencySec = HostLatencySec, + .HostHighSpeedLatencySec = CacheLatencySec, + .HostMaxRangeCountPerRequest = RemoteProjectStore::MaxRangeCountPerRequest}); std::vector NeededBlocks = PartialAnalyser.GetNeeded(AllNeededPartialChunkHashesLookup, @@ -3641,12 +3779,13 @@ LoadOplog(CidStore& ChunkStore, LoadAttachmentsTimer, DownloadStartMS, BlockDescriptions.Blocks[CurrentBlockRange.BlockIndex], + BlockExistsInCache[CurrentBlockRange.BlockIndex], PartialBlocksResult.BlockRanges, BlockRangeIndex, RangeCount, AllNeededPartialChunkHashesLookup, ChunkDownloadedFlags, - 3); + /* RetriesLeft*/ 3); BlockRangeIndex += RangeCount; } @@ -3668,12 +3807,23 @@ LoadOplog(CidStore& ChunkStore, { PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load(); } - remotestore_impl::ReportProgress( - OptionalContext, - "Loading attachments"sv, - fmt::format("{} remaining. 
{}", Remaining, remotestore_impl::GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), - AttachmentCount.load(), - Remaining); + + uint64_t AttachmentsDownloaded = + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); + uint64_t AttachmentBytesDownloaded = Info.AttachmentBlockBytesDownloaded.load() + Info.AttachmentBlockRangeBytesDownloaded.load() + + Info.AttachmentBytesDownloaded.load(); + + remotestore_impl::ReportProgress(OptionalContext, + "Loading attachments"sv, + fmt::format("{} ({}) downloaded, {} ({}) stored, {} remaining. {}", + AttachmentsDownloaded, + NiceBytes(AttachmentBytesDownloaded), + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Remaining, + remotestore_impl::GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), + AttachmentCount.load(), + Remaining); } if (DownloadStartMS != (uint64_t)-1) { @@ -3700,11 +3850,12 @@ LoadOplog(CidStore& ChunkStore, RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); } } - remotestore_impl::ReportProgress(OptionalContext, - "Writing attachments"sv, - fmt::format("{} remaining.", Remaining), - AttachmentCount.load(), - Remaining); + remotestore_impl::ReportProgress( + OptionalContext, + "Writing attachments"sv, + fmt::format("{} ({}), {} remaining.", Info.AttachmentsStored.load(), NiceBytes(Info.AttachmentBytesStored.load()), Remaining), + AttachmentCount.load(), + Remaining); } if (AttachmentCount.load() > 0) @@ -3867,18 +4018,20 @@ LoadOplog(CidStore& ChunkStore, TmpFile.Close(); TmpBuffer = IoBufferBuilder::MakeFromTemporaryFile(TempFileName); } + uint64_t TmpBufferSize = TmpBuffer.GetSize(); CidStore::InsertResult InsertResult = ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace); if (InsertResult.New) { - Info.AttachmentBytesStored.fetch_add(TmpBuffer.GetSize()); + Info.AttachmentBytesStored.fetch_add(TmpBufferSize); 
Info.AttachmentsStored.fetch_add(1); } - ZEN_INFO("Dechunked attachment {} ({}) in {}", - Chunked.RawHash, - NiceBytes(Chunked.RawSize), - NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + remotestore_impl::ReportMessage(OptionalContext, + fmt::format("Dechunked attachment {} ({}) in {}", + Chunked.RawHash, + NiceBytes(Chunked.RawSize), + NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); } catch (const std::exception& Ex) { diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp index b4c1156ac..ef82c45e0 100644 --- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp @@ -157,55 +157,59 @@ public: return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override { - std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); - - CbObject Request; + LoadAttachmentsResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - CbObjectWriter RequestWriter; - RequestWriter.AddString("method"sv, "getchunks"sv); - RequestWriter.BeginObject("Request"sv); + std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); + + CbObject Request; { - RequestWriter.BeginArray("Chunks"sv); + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "getchunks"sv); + RequestWriter.BeginObject("Request"sv); { - for (const IoHash& RawHash : RawHashes) + RequestWriter.BeginArray("Chunks"sv); { - RequestWriter.BeginObject(); + for (const IoHash& RawHash : RawHashes) { - RequestWriter.AddHash("RawHash", RawHash); + RequestWriter.BeginObject(); + { + RequestWriter.AddHash("RawHash", RawHash); + } + RequestWriter.EndObject(); } - RequestWriter.EndObject(); } + RequestWriter.EndArray(); // "chunks" } - RequestWriter.EndArray(); // "chunks" + 
RequestWriter.EndObject(); + Request = RequestWriter.Save(); } - RequestWriter.EndObject(); - Request = RequestWriter.Save(); - } - HttpClient::Response Response = m_Client.Post(LoadRequest, Request, HttpClient::Accept(ZenContentType::kCbPackage)); - AddStats(Response); + HttpClient::Response Response = m_Client.Post(LoadRequest, Request, HttpClient::Accept(ZenContentType::kCbPackage)); + AddStats(Response); - LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)}; - if (Result.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. Reason: '{}'", - RawHashes.size(), - m_ProjectStoreUrl, - m_Project, - m_Oplog, - Result.Reason); - } - else - { - CbPackage Package = Response.AsPackage(); - std::span Attachments = Package.GetAttachments(); - Result.Chunks.reserve(Attachments.size()); - for (const CbAttachment& Attachment : Attachments) + Result = LoadAttachmentsResult{ConvertResult(Response)}; + if (Result.ErrorCode) { - Result.Chunks.emplace_back( - std::pair{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); + Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. 
Reason: '{}'", + RawHashes.size(), + m_ProjectStoreUrl, + m_Project, + m_Oplog, + Result.Reason); + } + else + { + CbPackage Package = Response.AsPackage(); + std::span Attachments = Package.GetAttachments(); + Result.Chunks.reserve(Attachments.size()); + for (const CbAttachment& Attachment : Attachments) + { + Result.Chunks.emplace_back( + std::pair{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); + } } } return Result; @@ -260,32 +264,59 @@ public: return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; } - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, const AttachmentRange& Range) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override { - std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); - HttpClient::Response Response = - m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); - AddStats(Response); - - LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)}; - if (Result.ErrorCode) + LoadAttachmentResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", - m_ProjectStoreUrl, - m_Project, - m_Oplog, - RawHash, - Result.Reason); - } - if (!Result.ErrorCode && Range) - { - Result.Bytes = IoBuffer(Response.ResponsePayload, Range.Offset, Range.Bytes); + std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); + HttpClient::Response Response = + m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); + AddStats(Response); + + Result = LoadAttachmentResult{ConvertResult(Response)}; + if (Result.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. 
Reason: '{}'", + m_ProjectStoreUrl, + m_Project, + m_Oplog, + RawHash, + Result.Reason); + } + Result.Bytes = Response.ResponsePayload; + Result.Bytes.MakeOwned(); } - else + return Result; + } + + virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, + std::span> Ranges, + ESourceMode SourceMode) override + { + LoadAttachmentRangesResult Result; + if (SourceMode != ESourceMode::kCacheOnly) { - Result.Bytes = Response.ResponsePayload; + std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); + HttpClient::Response Response = + m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); + AddStats(Response); + + Result = LoadAttachmentRangesResult{ConvertResult(Response)}; + if (Result.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", + m_ProjectStoreUrl, + m_Project, + m_Oplog, + RawHash, + Result.Reason); + } + else + { + Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); + } } - Result.Bytes.MakeOwned(); return Result; } diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index 6ada085a5..459e044eb 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -233,16 +233,21 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) { const uint64_t MaxBlobSize = Range.first < BlobSize ? 
BlobSize - Range.first : 0; const uint64_t RangeSize = Min(Range.second, MaxBlobSize); - if (Range.first + RangeSize <= BlobSize) + Writer.BeginObject(); { - RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize)); - Writer.BeginObject(); + if (Range.first + RangeSize <= BlobSize) { + RangeBuffers.push_back(IoBuffer(Blob, Range.first, RangeSize)); Writer.AddInteger("offset"sv, Range.first); Writer.AddInteger("length"sv, RangeSize); } - Writer.EndObject(); + else + { + Writer.AddInteger("offset"sv, Range.first); + Writer.AddInteger("length"sv, 0); + } } + Writer.EndObject(); } Writer.EndArray(); @@ -262,7 +267,16 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) } else { - ZEN_ASSERT(OffsetAndLengthPairs.size() == 1); + if (OffsetAndLengthPairs.size() != 1) + { + // Only a single http range is supported, we have limited support for http multirange responses + m_BuildStoreStats.BadRequestCount++; + return ServerRequest.WriteResponse( + HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Multiple ranges in blob request is only supported for {} accept type", ToString(HttpContentType::kCbPackage))); + } + const std::pair& OffsetAndLength = OffsetAndLengthPairs.front(); const uint64_t BlobSize = Blob.GetSize(); const uint64_t MaxBlobSize = OffsetAndLength.first < BlobSize ? 
BlobSize - OffsetAndLength.first : 0; -- cgit v1.2.3 From b67dac7c093cc82b7e8f12f9eb57bfa34dfe26d8 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 4 Mar 2026 08:35:32 +0100 Subject: unity build fixes (#802) Various fixes to make cpp files build in unity build mode as an aside using Unity build doesn't really seem to work on Linux, unsure why but it leads to link-time issues --- src/zen/cmds/builds_cmd.cpp | 9 +++++---- src/zen/cmds/projectstore_cmd.cpp | 18 ++++++++++++++---- src/zen/cmds/wipe_cmd.cpp | 6 ++++-- src/zencore/include/zencore/compactbinaryfile.h | 1 + src/zencore/include/zencore/meta.h | 1 + src/zencore/include/zencore/varint.h | 1 + src/zencore/md5.cpp | 19 ++++++++++++++++++- src/zencore/xmake.lua | 1 + src/zenhttp/include/zenhttp/httpapiservice.h | 1 + src/zenremotestore/chunking/chunkblock.cpp | 8 ++++---- .../projectstore/remoteprojectstore.cpp | 6 +++--- src/zenserver/storage/storageconfig.h | 1 + src/zenstore/include/zenstore/buildstore/buildstore.h | 2 +- 13 files changed, 55 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index 5254ef3cf..ffdc5fe48 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -67,13 +67,11 @@ ZEN_THIRD_PARTY_INCLUDES_END static const bool DoExtraContentVerify = false; -#define ZEN_CLOUD_STORAGE "Cloud Storage" - namespace zen { using namespace std::literals; -namespace { +namespace builds_impl { static std::atomic AbortFlag = false; static std::atomic PauseFlag = false; @@ -270,6 +268,7 @@ namespace { static bool IsQuiet = false; static ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; +#undef ZEN_CONSOLE_VERBOSE #define ZEN_CONSOLE_VERBOSE(fmtstr, ...) 
\ if (IsVerbose) \ { \ @@ -2009,12 +2008,13 @@ namespace { ProgressBar::SetLogOperationProgress(ProgressMode, TaskSteps::Cleanup, TaskSteps::StepCount); } -} // namespace +} // namespace builds_impl ////////////////////////////////////////////////////////////////////////////////////////////////////// BuildsCommand::BuildsCommand() { + using namespace builds_impl; m_Options.add_options()("h,help", "Print help"); auto AddSystemOptions = [this](cxxopts::Options& Ops) { @@ -2655,6 +2655,7 @@ BuildsCommand::~BuildsCommand() = default; void BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace builds_impl; ZEN_UNUSED(GlobalOptions); signal(SIGINT, SignalCallbackHandler); diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index bedab3cfd..dfc6c1650 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -41,12 +41,10 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { +namespace projectstore_impl { using namespace std::literals; -#define ZEN_CLOUD_STORAGE "Cloud Storage" - void WriteAuthOptions(CbObjectWriter& Writer, std::string_view JupiterOpenIdProvider, std::string_view JupiterAccessToken, @@ -500,7 +498,7 @@ namespace { return {}; } -} // namespace +} // namespace projectstore_impl /////////////////////////////////////// @@ -522,6 +520,7 @@ DropProjectCommand::~DropProjectCommand() void DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -611,6 +610,7 @@ ProjectInfoCommand::~ProjectInfoCommand() void ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -697,6 +697,7 @@ CreateProjectCommand::~CreateProjectCommand() = default; void CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int 
argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); using namespace std::literals; @@ -766,6 +767,7 @@ CreateOplogCommand::~CreateOplogCommand() = default; void CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); using namespace std::literals; @@ -989,6 +991,7 @@ ExportOplogCommand::~ExportOplogCommand() void ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1495,6 +1498,7 @@ ImportOplogCommand::~ImportOplogCommand() void ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1788,6 +1792,7 @@ SnapshotOplogCommand::~SnapshotOplogCommand() void SnapshotOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; using namespace std::literals; ZEN_UNUSED(GlobalOptions); @@ -1852,6 +1857,7 @@ ProjectStatsCommand::~ProjectStatsCommand() void ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -1904,6 +1910,7 @@ ProjectOpDetailsCommand::~ProjectOpDetailsCommand() void ProjectOpDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2019,6 +2026,7 @@ OplogMirrorCommand::~OplogMirrorCommand() void OplogMirrorCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2286,6 +2294,7 @@ OplogValidateCommand::~OplogValidateCommand() void OplogValidateCommand::Run(const 
ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) @@ -2437,6 +2446,7 @@ OplogDownloadCommand::~OplogDownloadCommand() void OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace projectstore_impl; ZEN_UNUSED(GlobalOptions); if (!ParseOptions(argc, argv)) diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp index a5029e1c5..fd9e28a80 100644 --- a/src/zen/cmds/wipe_cmd.cpp +++ b/src/zen/cmds/wipe_cmd.cpp @@ -33,7 +33,7 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -namespace { +namespace wipe_impl { static std::atomic AbortFlag = false; static std::atomic PauseFlag = false; static bool IsVerbose = false; @@ -49,6 +49,7 @@ namespace { : GetMediumWorkerPool(EWorkloadType::Burst); } +#undef ZEN_CONSOLE_VERBOSE #define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ if (IsVerbose) \ { \ @@ -505,7 +506,7 @@ namespace { } return CleanWipe; } -} // namespace +} // namespace wipe_impl WipeCommand::WipeCommand() { @@ -532,6 +533,7 @@ WipeCommand::~WipeCommand() = default; void WipeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { + using namespace wipe_impl; ZEN_UNUSED(GlobalOptions); signal(SIGINT, SignalCallbackHandler); diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h index 00c37e941..33f3e7bea 100644 --- a/src/zencore/include/zencore/compactbinaryfile.h +++ b/src/zencore/include/zencore/compactbinaryfile.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include #include diff --git a/src/zencore/include/zencore/meta.h b/src/zencore/include/zencore/meta.h index 82eb5cc30..20ec4ac6f 100644 --- a/src/zencore/include/zencore/meta.h +++ b/src/zencore/include/zencore/meta.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once /* This file contains utility functions for meta programming * diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h index 9fe905f25..43ca14d38 100644 --- a/src/zencore/include/zencore/varint.h +++ b/src/zencore/include/zencore/varint.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include "intmath.h" diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp index 3baee91c2..83ed53fc8 100644 --- a/src/zencore/md5.cpp +++ b/src/zencore/md5.cpp @@ -342,6 +342,23 @@ Transform(uint32_t* buf, uint32_t* in) #undef G #undef H #undef I +#undef ROTATE_LEFT +#undef S11 +#undef S12 +#undef S13 +#undef S14 +#undef S21 +#undef S22 +#undef S23 +#undef S24 +#undef S31 +#undef S32 +#undef S33 +#undef S34 +#undef S41 +#undef S42 +#undef S43 +#undef S44 namespace zen { @@ -391,7 +408,7 @@ MD5::FromHexString(const char* string) { MD5 md5; - ParseHexBytes(string, 40, md5.Hash); + ParseHexBytes(string, 2 * sizeof md5.Hash, md5.Hash); return md5; } diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index 9a67175a0..2f81b7ec8 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -15,6 +15,7 @@ target('zencore') set_configdir("include/zencore") add_files("**.cpp") add_files("trace.cpp", {unity_ignored = true }) + add_files("testing.cpp", {unity_ignored = true }) if has_config("zenrpmalloc") then add_deps("rpmalloc") diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h index 0270973bf..2d384d1d8 100644 --- a/src/zenhttp/include/zenhttp/httpapiservice.h +++ b/src/zenhttp/include/zenhttp/httpapiservice.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once #include diff --git a/src/zenremotestore/chunking/chunkblock.cpp b/src/zenremotestore/chunking/chunkblock.cpp index 9c3fe8a0b..f80bfc2ba 100644 --- a/src/zenremotestore/chunking/chunkblock.cpp +++ b/src/zenremotestore/chunking/chunkblock.cpp @@ -1037,7 +1037,7 @@ ChunkBlockAnalyser::CalculateBlockRanges(uint32_t BlockIndex, #if ZEN_WITH_TESTS -namespace testutils { +namespace chunkblock_testutils { static std::vector> CreateAttachments( const std::span& Sizes, OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, @@ -1054,14 +1054,14 @@ namespace testutils { return Result; } -} // namespace testutils +} // namespace chunkblock_testutils TEST_SUITE_BEGIN("remotestore.chunkblock"); TEST_CASE("chunkblock.block") { using namespace std::literals; - using namespace testutils; + using namespace chunkblock_testutils; std::vector AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251, @@ -1089,7 +1089,7 @@ TEST_CASE("chunkblock.block") TEST_CASE("chunkblock.reuseblocks") { using namespace std::literals; - using namespace testutils; + using namespace chunkblock_testutils; std::vector> BlockAttachmentSizes( {std::vector{7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 1882f599a..570025b6d 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -4186,7 +4186,7 @@ RemoteProjectStore::~RemoteProjectStore() #if ZEN_WITH_TESTS -namespace testutils { +namespace projectstore_testutils { using namespace std::literals; static std::string OidAsString(const Oid& Id) @@ -4238,7 +4238,7 @@ namespace testutils { return Result; } -} // namespace testutils +} // namespace 
projectstore_testutils struct ExportForceDisableBlocksTrue_ForceTempBlocksFalse { @@ -4265,7 +4265,7 @@ TEST_CASE_TEMPLATE("project.store.export", ExportForceDisableBlocksFalse_ForceTempBlocksTrue) { using namespace std::literals; - using namespace testutils; + using namespace projectstore_testutils; ScopedTemporaryDirectory TempDir; ScopedTemporaryDirectory ExportDir; diff --git a/src/zenserver/storage/storageconfig.h b/src/zenserver/storage/storageconfig.h index b408b0c26..6124cae14 100644 --- a/src/zenserver/storage/storageconfig.h +++ b/src/zenserver/storage/storageconfig.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include "config/config.h" diff --git a/src/zenstore/include/zenstore/buildstore/buildstore.h b/src/zenstore/include/zenstore/buildstore/buildstore.h index 76cba05b9..bfc83ba0d 100644 --- a/src/zenstore/include/zenstore/buildstore/buildstore.h +++ b/src/zenstore/include/zenstore/buildstore/buildstore.h @@ -1,5 +1,5 @@ - // Copyright Epic Games, Inc. All Rights Reserved. 
+#pragma once #include -- cgit v1.2.3 From eafd4d78378c1a642445ed127fdbe51ac559d4e3 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 4 Mar 2026 09:40:49 +0100 Subject: HTTP improvements (#803) - Add GetTotalBytesReceived/GetTotalBytesSent to HttpServer with implementations in ASIO and http.sys backends - Add ExpectedErrorCodes to HttpClientSettings to suppress warn/info logs for anticipated HTTP error codes - Also fixes minor issues in `CprHttpClient::Download` --- src/zencore/process.cpp | 14 ++++++++++++ src/zencore/windows.cpp | 12 +++++----- src/zenhttp/clients/httpclientcpr.cpp | 38 +++++++++++++++++++++++--------- src/zenhttp/clients/httpclientcpr.h | 1 + src/zenhttp/include/zenhttp/httpclient.h | 5 +++++ src/zenhttp/include/zenhttp/httpserver.h | 4 ++++ src/zenhttp/servers/httpasio.cpp | 21 ++++++++++++++++++ src/zenhttp/servers/httpsys.cpp | 23 ++++++++++++++++++- 8 files changed, 99 insertions(+), 19 deletions(-) (limited to 'src') diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 226a94050..f657869dc 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -745,6 +746,8 @@ CreateProcElevated(const std::filesystem::path& Executable, std::string_view Com CreateProcResult CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) { + ZEN_TRACE_CPU("CreateProc"); + #if ZEN_PLATFORM_WINDOWS if (Options.Flags & CreateProcOptions::Flag_Unelevated) { @@ -776,6 +779,17 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine ZEN_UNUSED(Result); } + if (!Options.StdoutFile.empty()) + { + int Fd = open(Options.StdoutFile.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (Fd >= 0) + { + dup2(Fd, STDOUT_FILENO); + dup2(Fd, STDERR_FILENO); + close(Fd); + } + } + if (execv(Executable.c_str(), ArgV.data()) < 0) { ThrowLastError("Failed to exec() a new process image"); diff --git 
a/src/zencore/windows.cpp b/src/zencore/windows.cpp index d02fcd35e..87f854b90 100644 --- a/src/zencore/windows.cpp +++ b/src/zencore/windows.cpp @@ -12,14 +12,12 @@ namespace zen::windows { bool IsRunningOnWine() { - HMODULE NtDll = GetModuleHandleA("ntdll.dll"); + static bool s_Result = [] { + HMODULE NtDll = GetModuleHandleA("ntdll.dll"); + return NtDll && !!GetProcAddress(NtDll, "wine_get_version"); + }(); - if (NtDll) - { - return !!GetProcAddress(NtDll, "wine_get_version"); - } - - return false; + return s_Result; } FileMapping::FileMapping(_In_ FileMapping& orig) diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 90dcfacbb..14e40b02a 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace zen { @@ -164,6 +165,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri, { } +bool +CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + // Quiet + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + CprHttpClient::~CprHttpClient() { ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); @@ -193,11 +206,9 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, ResponseBuffer.SetContentType(ContentType); } - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - - if (!Quiet) + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + if (ShouldLogErrorCode(WorkResponseCode)) { ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); } @@ -371,8 +382,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); 
Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -410,8 +420,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -646,7 +655,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke ResponseBuffer.SetContentType(ContentType); } - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -929,6 +938,13 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF cpr::Response Response = DoWithRetry( m_SessionId, [&]() { + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + auto DownloadCallback = [&](std::string data, intptr_t) { if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) { @@ -969,7 +985,7 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF if (RangeStartPos != std::string::npos) { RangeStartPos++; - while (RangeValue[RangeStartPos] == ' ') + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') { RangeStartPos++; } @@ -991,7 +1007,7 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::optional 
RequestedRangeEnd = ParseInt(RangeString.substr(RangeSplitPos + 1)); if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) { - RequestedContentLength += RequestedRangeEnd.value() - 1; + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; } } RangeStartPos = RangeEnd; diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index cf2d3bd14..752d91add 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -155,6 +155,7 @@ private: std::function&& Func, std::function&& Validate = [](cpr::Response&) { return true; }); + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; bool ValidatePayload(cpr::Response& Response, std::unique_ptr& PayloadFile); HttpClient::Response CommonResponse(std::string_view SessionId, diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 53be36b9a..d87082d10 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -13,6 +13,7 @@ #include #include #include +#include namespace zen { @@ -58,6 +59,10 @@ struct HttpClientSettings Oid SessionId = Oid::Zero; bool Verbose = false; uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u; + + /// HTTP status codes that are expected and should not be logged as warnings. + /// 404 is always treated as expected regardless of this list. + std::vector ExpectedErrorCodes; }; class HttpClientError : public std::runtime_error diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index fee932daa..62c080a7b 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -230,6 +230,10 @@ public: */ std::string_view GetExternalHost() const { return m_ExternalHost; } + /** Returns total bytes received and sent across all connections since server start. 
*/ + virtual uint64_t GetTotalBytesReceived() const { return 0; } + virtual uint64_t GetTotalBytesSent() const { return 0; } + private: std::vector m_KnownServices; int m_EffectivePort = 0; diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 8c2dcd116..c4d9ee777 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -528,6 +528,9 @@ public: RwLock m_Lock; std::vector m_UriHandlers; + + std::atomic m_TotalBytesReceived{0}; + std::atomic m_TotalBytesSent{0}; }; /** @@ -1043,6 +1046,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } + m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), @@ -1096,6 +1101,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, return; } + m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, RequestNumber, @@ -2053,6 +2060,8 @@ public: virtual void OnRequestExit() override; virtual void OnClose() override; virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; private: Event m_ShutdownEvent; @@ -2150,6 +2159,18 @@ HttpAsioServer::OnGetExternalHost() const } } +uint64_t +HttpAsioServer::GetTotalBytesReceived() const +{ + return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpAsioServer::GetTotalBytesSent() const +{ + return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpAsioServer::OnRun(bool IsInteractive) { diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 23d57af57..a48f1d316 100644 --- 
a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -88,6 +88,8 @@ class HttpSysServerRequest; class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -102,6 +104,8 @@ public: virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnClose() override; virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); @@ -149,6 +153,9 @@ private: RwLock m_RequestFilterLock; std::atomic m_HttpRequestFilter = nullptr; + + std::atomic m_TotalBytesReceived{0}; + std::atomic m_TotalBytesSent{0}; }; } // namespace zen @@ -591,7 +598,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -2123,6 +2130,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2401,6 +2410,18 @@ HttpSysServer::OnGetExternalHost() const } } +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { -- cgit v1.2.3 From 6e51634c31cfbe6ad99e27bcefe7ec3bd06dd5c5 Mon Sep 17 00:00:00 2001 From: Dan 
Engelbrecht Date: Wed, 4 Mar 2026 13:58:26 +0100 Subject: IterateChunks callback is multithreaded - make sure AttachmentsSize can handle it (#804) --- src/zenserver/storage/cache/httpstructuredcache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index 72f29d14e..00151f79e 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -654,7 +654,7 @@ HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Reque auto NewEnd = std::unique(AllAttachments.begin(), AllAttachments.end()); AllAttachments.erase(NewEnd, AllAttachments.end()); - uint64_t AttachmentsSize = 0; + std::atomic<uint64_t> AttachmentsSize = 0; m_CidStore.IterateChunks( AllAttachments, @@ -746,7 +746,7 @@ HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, ResponseWriter << "Size" << ValuesSize; ResponseWriter << "AttachmentCount" << ContentStats.Attachments.size(); - uint64_t AttachmentsSize = 0; + std::atomic<uint64_t> AttachmentsSize = 0; WorkerThreadPool& WorkerPool = GetMediumWorkerPool(EWorkloadType::Background); -- cgit v1.2.3 From 0763d09a81e5a1d3df11763a7ec75e7860c9510a Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 4 Mar 2026 14:13:46 +0100 Subject: compute orchestration (#763) - Added local process runners for Linux/Wine, Mac with some sandboxing support - Horde & Nomad provisioning for development and testing - Client session queues with lifecycle management (active/draining/cancelled), automatic retry with configurable limits, and manual reschedule API - Improved web UI for orchestrator, compute, and hub dashboards with WebSocket push updates - Some security hardening - Improved scalability and `zen exec` command Still experimental - compute support is disabled by default --- src/zen/cmds/exec_cmd.cpp | 753 ++++++- src/zen/cmds/exec_cmd.h |
8 +- src/zencompute/CLAUDE.md | 232 ++ src/zencompute/actionrecorder.cpp | 258 --- src/zencompute/actionrecorder.h | 91 - src/zencompute/cloudmetadata.cpp | 1010 +++++++++ src/zencompute/computeservice.cpp | 2236 ++++++++++++++++++++ src/zencompute/functionrunner.cpp | 112 - src/zencompute/functionrunner.h | 207 -- src/zencompute/functionservice.cpp | 957 --------- src/zencompute/httpcomputeservice.cpp | 1643 ++++++++++++++ src/zencompute/httpfunctionservice.cpp | 709 ------- src/zencompute/httporchestrator.cpp | 621 +++++- src/zencompute/include/zencompute/cloudmetadata.h | 151 ++ src/zencompute/include/zencompute/computeservice.h | 262 +++ .../include/zencompute/functionservice.h | 132 -- .../include/zencompute/httpcomputeservice.h | 54 + .../include/zencompute/httpfunctionservice.h | 73 - .../include/zencompute/httporchestrator.h | 81 +- src/zencompute/include/zencompute/mockimds.h | 102 + .../include/zencompute/orchestratorservice.h | 177 ++ .../include/zencompute/recordingreader.h | 4 +- src/zencompute/include/zencompute/zencompute.h | 4 + src/zencompute/localrunner.cpp | 722 ------- src/zencompute/localrunner.h | 100 - src/zencompute/orchestratorservice.cpp | 710 +++++++ src/zencompute/recording/actionrecorder.cpp | 258 +++ src/zencompute/recording/actionrecorder.h | 91 + src/zencompute/recording/recordingreader.cpp | 335 +++ src/zencompute/recordingreader.cpp | 335 --- src/zencompute/remotehttprunner.cpp | 457 ---- src/zencompute/remotehttprunner.h | 80 - src/zencompute/runners/deferreddeleter.cpp | 336 +++ src/zencompute/runners/deferreddeleter.h | 68 + src/zencompute/runners/functionrunner.cpp | 365 ++++ src/zencompute/runners/functionrunner.h | 214 ++ src/zencompute/runners/linuxrunner.cpp | 734 +++++++ src/zencompute/runners/linuxrunner.h | 44 + src/zencompute/runners/localrunner.cpp | 674 ++++++ src/zencompute/runners/localrunner.h | 138 ++ src/zencompute/runners/macrunner.cpp | 491 +++++ src/zencompute/runners/macrunner.h | 43 + 
src/zencompute/runners/remotehttprunner.cpp | 618 ++++++ src/zencompute/runners/remotehttprunner.h | 100 + src/zencompute/runners/windowsrunner.cpp | 460 ++++ src/zencompute/runners/windowsrunner.h | 53 + src/zencompute/runners/winerunner.cpp | 237 +++ src/zencompute/runners/winerunner.h | 37 + src/zencompute/testing/mockimds.cpp | 205 ++ src/zencompute/timeline/workertimeline.cpp | 430 ++++ src/zencompute/timeline/workertimeline.h | 169 ++ src/zencompute/xmake.lua | 10 + src/zencompute/zencompute.cpp | 9 + src/zencore/include/zencore/system.h | 36 +- src/zencore/system.cpp | 336 ++- src/zenhorde/hordeagent.cpp | 297 +++ src/zenhorde/hordeagent.h | 77 + src/zenhorde/hordeagentmessage.cpp | 340 +++ src/zenhorde/hordeagentmessage.h | 161 ++ src/zenhorde/hordebundle.cpp | 619 ++++++ src/zenhorde/hordebundle.h | 49 + src/zenhorde/hordeclient.cpp | 382 ++++ src/zenhorde/hordecomputebuffer.cpp | 454 ++++ src/zenhorde/hordecomputebuffer.h | 136 ++ src/zenhorde/hordecomputechannel.cpp | 37 + src/zenhorde/hordecomputechannel.h | 32 + src/zenhorde/hordecomputesocket.cpp | 204 ++ src/zenhorde/hordecomputesocket.h | 79 + src/zenhorde/hordeconfig.cpp | 89 + src/zenhorde/hordeprovisioner.cpp | 367 ++++ src/zenhorde/hordetransport.cpp | 169 ++ src/zenhorde/hordetransport.h | 71 + src/zenhorde/hordetransportaes.cpp | 425 ++++ src/zenhorde/hordetransportaes.h | 52 + src/zenhorde/include/zenhorde/hordeclient.h | 116 + src/zenhorde/include/zenhorde/hordeconfig.h | 62 + src/zenhorde/include/zenhorde/hordeprovisioner.h | 110 + src/zenhorde/include/zenhorde/zenhorde.h | 9 + src/zenhorde/xmake.lua | 22 + src/zenhttp/include/zenhttp/httpserver.h | 1 + src/zenhttp/servers/httpasio.cpp | 25 +- src/zenhttp/servers/httpsys.cpp | 25 +- src/zennomad/include/zennomad/nomadclient.h | 77 + src/zennomad/include/zennomad/nomadconfig.h | 65 + src/zennomad/include/zennomad/nomadprocess.h | 78 + src/zennomad/include/zennomad/nomadprovisioner.h | 107 + src/zennomad/include/zennomad/zennomad.h | 9 + 
src/zennomad/nomadclient.cpp | 366 ++++ src/zennomad/nomadconfig.cpp | 91 + src/zennomad/nomadprocess.cpp | 354 ++++ src/zennomad/nomadprovisioner.cpp | 264 +++ src/zennomad/xmake.lua | 10 + .../builds/buildstorageoperations.cpp | 2 +- src/zenremotestore/chunking/chunkingcache.cpp | 8 +- src/zenserver-test/compute-tests.cpp | 1700 +++++++++++++++ src/zenserver-test/function-tests.cpp | 38 - src/zenserver-test/logging-tests.cpp | 257 +++ src/zenserver-test/nomad-tests.cpp | 126 ++ src/zenserver-test/xmake.lua | 7 +- src/zenserver/compute/computeserver.cpp | 725 ++++++- src/zenserver/compute/computeserver.h | 111 +- src/zenserver/compute/computeservice.cpp | 100 - src/zenserver/compute/computeservice.h | 36 - src/zenserver/frontend/html.zip | Bin 238188 -> 319315 bytes src/zenserver/frontend/html/404.html | 486 +++++ src/zenserver/frontend/html/compute.html | 991 --------- src/zenserver/frontend/html/compute/banner.js | 321 +++ src/zenserver/frontend/html/compute/compute.html | 1072 ++++++++++ src/zenserver/frontend/html/compute/hub.html | 310 +++ src/zenserver/frontend/html/compute/index.html | 1 + src/zenserver/frontend/html/compute/nav.js | 79 + .../frontend/html/compute/orchestrator.html | 831 ++++++++ src/zenserver/frontend/html/pages/page.js | 36 + src/zenserver/frontend/html/zen.css | 27 + src/zenserver/hub/hubservice.cpp | 2 +- src/zenserver/hub/zenhubserver.cpp | 7 + src/zenserver/hub/zenhubserver.h | 6 +- src/zenserver/storage/zenstorageserver.cpp | 17 +- src/zenserver/storage/zenstorageserver.h | 4 +- src/zenserver/trace/tracerecorder.cpp | 565 +++++ src/zenserver/trace/tracerecorder.h | 46 + src/zenserver/xmake.lua | 19 + src/zentest-appstub/zentest-appstub.cpp | 11 + src/zenutil/include/zenutil/consoletui.h | 1 + src/zenutil/include/zenutil/zenserverprocess.h | 1 + src/zenutil/zenserverprocess.cpp | 6 + 126 files changed, 26956 insertions(+), 5596 deletions(-) create mode 100644 src/zencompute/CLAUDE.md delete mode 100644 
src/zencompute/actionrecorder.cpp delete mode 100644 src/zencompute/actionrecorder.h create mode 100644 src/zencompute/cloudmetadata.cpp create mode 100644 src/zencompute/computeservice.cpp delete mode 100644 src/zencompute/functionrunner.cpp delete mode 100644 src/zencompute/functionrunner.h delete mode 100644 src/zencompute/functionservice.cpp create mode 100644 src/zencompute/httpcomputeservice.cpp delete mode 100644 src/zencompute/httpfunctionservice.cpp create mode 100644 src/zencompute/include/zencompute/cloudmetadata.h create mode 100644 src/zencompute/include/zencompute/computeservice.h delete mode 100644 src/zencompute/include/zencompute/functionservice.h create mode 100644 src/zencompute/include/zencompute/httpcomputeservice.h delete mode 100644 src/zencompute/include/zencompute/httpfunctionservice.h create mode 100644 src/zencompute/include/zencompute/mockimds.h create mode 100644 src/zencompute/include/zencompute/orchestratorservice.h delete mode 100644 src/zencompute/localrunner.cpp delete mode 100644 src/zencompute/localrunner.h create mode 100644 src/zencompute/orchestratorservice.cpp create mode 100644 src/zencompute/recording/actionrecorder.cpp create mode 100644 src/zencompute/recording/actionrecorder.h create mode 100644 src/zencompute/recording/recordingreader.cpp delete mode 100644 src/zencompute/recordingreader.cpp delete mode 100644 src/zencompute/remotehttprunner.cpp delete mode 100644 src/zencompute/remotehttprunner.h create mode 100644 src/zencompute/runners/deferreddeleter.cpp create mode 100644 src/zencompute/runners/deferreddeleter.h create mode 100644 src/zencompute/runners/functionrunner.cpp create mode 100644 src/zencompute/runners/functionrunner.h create mode 100644 src/zencompute/runners/linuxrunner.cpp create mode 100644 src/zencompute/runners/linuxrunner.h create mode 100644 src/zencompute/runners/localrunner.cpp create mode 100644 src/zencompute/runners/localrunner.h create mode 100644 src/zencompute/runners/macrunner.cpp create 
mode 100644 src/zencompute/runners/macrunner.h create mode 100644 src/zencompute/runners/remotehttprunner.cpp create mode 100644 src/zencompute/runners/remotehttprunner.h create mode 100644 src/zencompute/runners/windowsrunner.cpp create mode 100644 src/zencompute/runners/windowsrunner.h create mode 100644 src/zencompute/runners/winerunner.cpp create mode 100644 src/zencompute/runners/winerunner.h create mode 100644 src/zencompute/testing/mockimds.cpp create mode 100644 src/zencompute/timeline/workertimeline.cpp create mode 100644 src/zencompute/timeline/workertimeline.h create mode 100644 src/zenhorde/hordeagent.cpp create mode 100644 src/zenhorde/hordeagent.h create mode 100644 src/zenhorde/hordeagentmessage.cpp create mode 100644 src/zenhorde/hordeagentmessage.h create mode 100644 src/zenhorde/hordebundle.cpp create mode 100644 src/zenhorde/hordebundle.h create mode 100644 src/zenhorde/hordeclient.cpp create mode 100644 src/zenhorde/hordecomputebuffer.cpp create mode 100644 src/zenhorde/hordecomputebuffer.h create mode 100644 src/zenhorde/hordecomputechannel.cpp create mode 100644 src/zenhorde/hordecomputechannel.h create mode 100644 src/zenhorde/hordecomputesocket.cpp create mode 100644 src/zenhorde/hordecomputesocket.h create mode 100644 src/zenhorde/hordeconfig.cpp create mode 100644 src/zenhorde/hordeprovisioner.cpp create mode 100644 src/zenhorde/hordetransport.cpp create mode 100644 src/zenhorde/hordetransport.h create mode 100644 src/zenhorde/hordetransportaes.cpp create mode 100644 src/zenhorde/hordetransportaes.h create mode 100644 src/zenhorde/include/zenhorde/hordeclient.h create mode 100644 src/zenhorde/include/zenhorde/hordeconfig.h create mode 100644 src/zenhorde/include/zenhorde/hordeprovisioner.h create mode 100644 src/zenhorde/include/zenhorde/zenhorde.h create mode 100644 src/zenhorde/xmake.lua create mode 100644 src/zennomad/include/zennomad/nomadclient.h create mode 100644 src/zennomad/include/zennomad/nomadconfig.h create mode 100644 
src/zennomad/include/zennomad/nomadprocess.h create mode 100644 src/zennomad/include/zennomad/nomadprovisioner.h create mode 100644 src/zennomad/include/zennomad/zennomad.h create mode 100644 src/zennomad/nomadclient.cpp create mode 100644 src/zennomad/nomadconfig.cpp create mode 100644 src/zennomad/nomadprocess.cpp create mode 100644 src/zennomad/nomadprovisioner.cpp create mode 100644 src/zennomad/xmake.lua create mode 100644 src/zenserver-test/compute-tests.cpp delete mode 100644 src/zenserver-test/function-tests.cpp create mode 100644 src/zenserver-test/logging-tests.cpp create mode 100644 src/zenserver-test/nomad-tests.cpp delete mode 100644 src/zenserver/compute/computeservice.cpp delete mode 100644 src/zenserver/compute/computeservice.h create mode 100644 src/zenserver/frontend/html/404.html delete mode 100644 src/zenserver/frontend/html/compute.html create mode 100644 src/zenserver/frontend/html/compute/banner.js create mode 100644 src/zenserver/frontend/html/compute/compute.html create mode 100644 src/zenserver/frontend/html/compute/hub.html create mode 100644 src/zenserver/frontend/html/compute/index.html create mode 100644 src/zenserver/frontend/html/compute/nav.js create mode 100644 src/zenserver/frontend/html/compute/orchestrator.html create mode 100644 src/zenserver/trace/tracerecorder.cpp create mode 100644 src/zenserver/trace/tracerecorder.h (limited to 'src') diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 407f42ee3..42c7119e7 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -2,7 +2,7 @@ #include "exec_cmd.h" -#include +#include #include #include #include @@ -14,9 +14,13 @@ #include #include #include +#include #include #include +#include #include +#include +#include #include #include @@ -47,12 +51,17 @@ ExecCommand::ExecCommand() m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), ""); m_Options.add_option("", "", "limit", "Recording replay limit", 
cxxopts::value(m_Limit), ""); m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), ""); + m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), ""); m_Options.add_option("", "", "mode", "Select execution mode (http,inproc,dump,direct,beacon,buildlog)", cxxopts::value(m_Mode)->default_value("http"), ""); + m_Options + .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), ""); + m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), ""); + m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), ""); m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), ""); m_Options.parse_positional("mode"); } @@ -236,16 +245,16 @@ ExecCommand::InProcessExecute() ZEN_ASSERT(m_ChunkResolver); ChunkResolver& Resolver = *m_ChunkResolver; - zen::compute::FunctionServiceSession FunctionSession(Resolver); + zen::compute::ComputeServiceSession ComputeSession(Resolver); std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - FunctionSession.AddLocalRunner(Resolver, TempPath); + ComputeSession.AddLocalRunner(Resolver, TempPath); - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } int -ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession) +ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession) { struct JobTracker { @@ -281,6 +290,117 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess JobTracker PendingJobs; + struct ActionSummaryEntry + { + int32_t Lsn = 0; + int RecordingIndex = 0; + IoHash ActionId; + std::string FunctionName; + int InputAttachments = 0; + uint64_t InputBytes = 0; + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + float 
WallSeconds = 0.0f; + float CpuSeconds = 0.0f; + uint64_t SubmittedTicks = 0; + uint64_t StartedTicks = 0; + std::string ExecutionLocation; + }; + + std::mutex SummaryLock; + std::unordered_map SummaryEntries; + + ComputeSession.WaitUntilReady(); + + // Register as a client with the orchestrator (best-effort) + + std::string OrchestratorClientId; + + if (!m_OrchestratorUrl.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + + CbObjectWriter Ann; + Ann << "session_id"sv << GetSessionId(); + Ann << "hostname"sv << std::string_view(GetMachineName()); + + CbObjectWriter Meta; + Meta << "source"sv + << "zen-exec"sv; + Ann << "metadata"sv << Meta.Save(); + + auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); + if (Resp.IsSuccess()) + { + OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString()); + ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId); + } + else + { + ZEN_WARN("failed to register with orchestrator (status {})", static_cast(Resp.StatusCode)); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + } + } + + Stopwatch OrchestratorHeartbeatTimer; + + auto SendOrchestratorHeartbeat = [&] { + if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) + { + return; + } + OrchestratorHeartbeatTimer.Reset(); + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId)); + } + catch (...) + { + } + }; + + auto ClientCleanup = MakeGuard([&] { + if (!OrchestratorClientId.empty()) + { + try + { + HttpClient OrchestratorClient(m_OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId)); + } + catch (...) 
+ { + } + } + }); + + // Create a queue to group all actions from this exec session + + CbObjectWriter Metadata; + Metadata << "source"sv + << "zen-exec"sv; + + auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save()); + const int QueueId = QueueResult.QueueId; + if (!QueueId) + { + ZEN_ERROR("failed to create compute queue"); + return 1; + } + + auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); }); + + if (!m_OutputPath.empty()) + { + zen::CreateDirectories(m_OutputPath); + } + std::atomic IsDraining{0}; auto DrainCompletedJobs = [&] { @@ -292,7 +412,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); CbObjectWriter Cbo; - FunctionSession.GetCompleted(Cbo); + ComputeSession.GetQueueCompleted(QueueId, Cbo); if (CbObject Completed = Cbo.Save()) { @@ -301,10 +421,89 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess int32_t CompleteLsn = It.AsInt32(); CbPackage ResultPackage; - HttpResponseCode Response = FunctionSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); if (Response == HttpResponseCode::OK) { + if (!m_OutputPath.empty() && ResultPackage) + { + int OutputAttachments = 0; + uint64_t OutputBytes = 0; + + if (!m_Binary) + { + // Write the root object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed attachments + auto Attachments = ResultPackage.GetAttachments(); + + if (!Attachments.empty()) + { + std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn); + 
zen::CreateDirectories(AttDir); + + for (const CbAttachment& Att : Attachments) + { + ++OutputAttachments; + + IoHash AttHash = Att.GetHash(); + + if (Att.IsCompressedBinary()) + { + SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); + OutputBytes += Decompressed.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + else + { + SharedBuffer Binary = Att.AsBinary(); + OutputBytes += Binary.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); + } + } + } + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", + m_OutputPath.string(), + CompleteLsn, + OutputAttachments); + } + } + else + { + CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); + + for (const CbAttachment& Att : ResultPackage.GetAttachments()) + { + ++OutputAttachments; + OutputBytes += Att.AsBinary().GetSize(); + } + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn); + } + } + + std::lock_guard Lock(SummaryLock); + if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end()) + { + It2->second.OutputAttachments = OutputAttachments; + It2->second.OutputBytes = OutputBytes; + } + } + PendingJobs.Remove(CompleteLsn); ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, PendingJobs.GetSize()); @@ -321,7 +520,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess { CbPackage WorkerDesc = Kv.second; - FunctionSession.RegisterWorker(WorkerDesc); + ComputeSession.RegisterWorker(WorkerDesc); } // Then submit work items @@ -367,10 +566,14 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess TargetParallelism = 1; } + std::atomic 
RecordingIndex{0}; + m_RecordingReader->IterateActions( [&](CbObject ActionObject, const IoHash& ActionId) { // Enqueue job + const int CurrentRecordingIndex = RecordingIndex++; + Stopwatch SubmitTimer; const int Priority = 0; @@ -404,8 +607,29 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess ObjStr); } - if (zen::compute::FunctionServiceSession::EnqueueResult EnqueueResult = - FunctionSession.EnqueueAction(ActionObject, Priority)) + if (m_DumpActions) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ++AttachmentCount; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr); + } + + if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = + ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority)) { const int32_t LsnField = EnqueueResult.Lsn; @@ -421,6 +645,96 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess RemainingWorkItems); } + if (!m_OutputPath.empty()) + { + ActionSummaryEntry Entry; + Entry.Lsn = LsnField; + Entry.RecordingIndex = CurrentRecordingIndex; + Entry.ActionId = ActionId; + Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); + + if (!m_Binary) + { + // Write action object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ActionObject, YamlStr); + + std::string_view Yaml = YamlStr; + zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); + + // Write decompressed input attachments + std::filesystem::path AttDir = m_OutputPath / 
fmt::format("{}.action.attachments", LsnField); + bool AttDirCreated = false; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + SharedBuffer Decompressed = Compressed.Decompress(); + + Entry.InputBytes += Decompressed.GetSize(); + + if (!AttDirCreated) + { + zen::CreateDirectories(AttDir); + AttDirCreated = true; + } + + zen::WriteFile(AttDir / AttachCid.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + }); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", + m_OutputPath.string(), + LsnField, + Entry.InputAttachments); + } + } + else + { + // Build a CbPackage from the action and write as .pkg + CbPackage ActionPackage; + ActionPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; + + if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + + Entry.InputBytes += ChunkData.GetSize(); + ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); + } + }); + + CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); + zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized)); + + if (!m_QuietLogging) + { + ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField); + } + } + + std::lock_guard Lock(SummaryLock); + SummaryEntries.emplace(LsnField, std::move(Entry)); + } + PendingJobs.Insert(LsnField); } else 
@@ -450,6 +764,7 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess // Check for completed work DrainCompletedJobs(); + SendOrchestratorHeartbeat(); }, TargetParallelism); @@ -461,6 +776,394 @@ ExecCommand::ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSess zen::Sleep(500); DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + } + + // Merge timing data from queue history into summary entries + + if (!SummaryEntries.empty()) + { + // RunnerAction::State indices (can't include functionrunner.h from here) + constexpr int kStateNew = 0; + constexpr int kStatePending = 1; + constexpr int kStateRunning = 3; + constexpr int kStateCompleted = 4; // first terminal state + constexpr int kStateCount = 8; + + for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0)) + { + std::lock_guard Lock(SummaryLock); + if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end()) + { + // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) + uint64_t EndTick = 0; + for (int S = kStateCompleted; S < kStateCount; ++S) + { + if (HistEntry.Timestamps[S] != 0) + { + EndTick = HistEntry.Timestamps[S]; + break; + } + } + uint64_t StartTick = HistEntry.Timestamps[kStateNew]; + if (EndTick > StartTick) + { + It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); + } + It->second.CpuSeconds = HistEntry.CpuSeconds; + It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; + It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; + It->second.ExecutionLocation = HistEntry.ExecutionLocation; + } + } + } + + // Write summary file if output path is set + + if (!m_OutputPath.empty() && !SummaryEntries.empty()) + { + std::vector Sorted; + Sorted.reserve(SummaryEntries.size()); + for (auto& [_, Entry] : SummaryEntries) + { + Sorted.push_back(std::move(Entry)); + } + + std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const 
ActionSummaryEntry& B) { + return A.RecordingIndex < B.RecordingIndex; + }); + + auto FormatTimestamp = [](uint64_t Ticks) -> std::string { + if (Ticks == 0) + { + return "-"; + } + return DateTime(Ticks).ToString("%H:%M:%S.%s"); + }; + + ExtendableStringBuilder<4096> Summary; + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", + "LSN", + "Index", + "ActionId", + "Function", + "InAtt", + "InBytes", + "OutAtt", + "OutBytes", + "Wall(s)", + "CPU(s)", + "Submitted", + "Started", + "Location")); + Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "")); + + for (const ActionSummaryEntry& Entry : Sorted) + { + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", + Entry.Lsn, + Entry.RecordingIndex, + Entry.ActionId, + Entry.FunctionName, + Entry.InputAttachments, + NiceBytes(Entry.InputBytes), + Entry.OutputAttachments, + NiceBytes(Entry.OutputBytes), + Entry.WallSeconds, + Entry.CpuSeconds, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + Entry.ExecutionLocation)); + } + + std::filesystem::path SummaryPath = m_OutputPath / "summary.txt"; + std::string_view SummaryStr = Summary; + zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); + + ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); + + if (!m_Binary) + { + auto EscapeHtml = [](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '&': + Out += "&"; + break; + case '<': + Out += "<"; + break; + case '>': + Out += ">"; + break; + case '"': + Out += """; + break; + case '\'': + Out += "'"; + break; + default: + Out += C; + } + } + return Out; + }; + + auto EscapeJson = 
[](std::string_view Input) -> std::string { + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) + { + switch (C) + { + case '"': + Out += "\\\""; + break; + case '\\': + Out += "\\\\"; + break; + case '\n': + Out += "\\n"; + break; + case '\r': + Out += "\\r"; + break; + case '\t': + Out += "\\t"; + break; + default: + if (static_cast(C) < 0x20) + { + Out += fmt::format("\\u{:04x}", static_cast(static_cast(C))); + } + else + { + Out += C; + } + } + } + return Out; + }; + + ExtendableStringBuilder<8192> Html; + + Html.Append(std::string_view(R"( +Exec Summary + +

Exec Summary

+ +
+ + + + + + + + + + + + + + + + + +
LSN Index Action ID Function In Attachments In Bytes Out Attachments Out Bytes Wall(s) CPU(s) Submitted Started Location
+ +)JS")); + + std::filesystem::path HtmlPath = m_OutputPath / "summary.html"; + std::string_view HtmlStr = Html; + zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + + ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + } } if (FailedWorkCounter) @@ -491,10 +1194,10 @@ ExecCommand::HttpExecute() std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - FunctionSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } int @@ -504,11 +1207,21 @@ ExecCommand::BeaconExecute() ChunkResolver& Resolver = *m_ChunkResolver; std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); - // FunctionSession.AddRemoteRunner(Resolver, TempPath, "http://10.99.9.246:8558"); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + + if (!m_OrchestratorUrl.empty()) + { + ZEN_CONSOLE_INFO("using orchestrator at {}", m_OrchestratorUrl); + ComputeSession.SetOrchestratorEndpoint(m_OrchestratorUrl); + ComputeSession.SetOrchestratorBasePath(TempPath); + } + else + { + ZEN_CONSOLE_INFO("note: using hard-coded local worker path"); + ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); + } - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } ////////////////////////////////////////////////////////////////////////// @@ -635,10 +1348,10 @@ ExecCommand::BuildActionsLog() std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - zen::compute::FunctionServiceSession FunctionSession(Resolver); - 
FunctionSession.StartRecording(Resolver, m_RecordingLogPath); + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.StartRecording(Resolver, m_RecordingLogPath); - return ExecUsingSession(FunctionSession); + return ExecUsingSession(ComputeSession); } void diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index 43d092144..6311354c0 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -23,7 +23,7 @@ class ChunkResolver; #if ZEN_WITH_COMPUTE_SERVICES namespace zen::compute { -class FunctionServiceSession; +class ComputeServiceSession; } namespace zen { @@ -49,6 +49,7 @@ public: private: cxxopts::Options m_Options{Name, Description}; std::string m_HostName; + std::string m_OrchestratorUrl; std::filesystem::path m_BeaconPath; std::filesystem::path m_RecordingPath; std::filesystem::path m_RecordingLogPath; @@ -57,6 +58,8 @@ private: int m_Limit = 0; bool m_Quiet = false; std::string m_Mode{"http"}; + std::filesystem::path m_OutputPath; + bool m_Binary = false; struct FunctionDefinition { @@ -74,13 +77,14 @@ private: std::vector m_FunctionList; bool m_VerboseLogging = false; bool m_QuietLogging = false; + bool m_DumpActions = false; zen::ChunkResolver* m_ChunkResolver = nullptr; zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); - int ExecUsingSession(zen::compute::FunctionServiceSession& FunctionSession); + int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession); // Execution modes diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md new file mode 100644 index 000000000..f5188123f --- /dev/null +++ b/src/zencompute/CLAUDE.md @@ -0,0 +1,232 @@ +# zencompute Module + +Lambda-style compute function service. Accepts execution requests from HTTP clients, schedules them across local and remote runners, and tracks results. 
+ +## Directory Structure + +``` +src/zencompute/ +├── include/zencompute/ # Public headers +│ ├── computeservice.h # ComputeServiceSession public API +│ ├── httpcomputeservice.h # HTTP service wrapper +│ ├── orchestratorservice.h # Worker registry and orchestration +│ ├── httporchestrator.h # HTTP orchestrator with WebSocket push +│ ├── recordingreader.h # Recording/replay reader API +│ ├── cloudmetadata.h # Cloud provider detection (AWS/Azure/GCP) +│ └── mockimds.h # Test helper for cloud metadata +├── runners/ # Execution backends +│ ├── functionrunner.h/.cpp # Abstract base + BaseRunnerGroup/RunnerGroup +│ ├── localrunner.h/.cpp # LocalProcessRunner (sandbox, monitoring, CPU sampling) +│ ├── windowsrunner.h/.cpp # Windows AppContainer sandboxing + CreateProcessW +│ ├── linuxrunner.h/.cpp # Linux user/mount/network namespace isolation +│ ├── macrunner.h/.cpp # macOS Seatbelt sandboxing +│ ├── winerunner.h/.cpp # Wine runner for Windows executables on Linux +│ ├── remotehttprunner.h/.cpp # Remote HTTP submission to other zenserver instances +│ └── deferreddeleter.h/.cpp # Background deletion of sandbox directories +├── recording/ # Recording/replay subsystem +│ ├── actionrecorder.h/.cpp # Write actions+attachments to disk +│ └── recordingreader.cpp # Read recordings back +├── timeline/ +│ └── workertimeline.h/.cpp # Per-worker action lifecycle event tracking +├── testing/ +│ └── mockimds.cpp # Mock IMDS for cloud metadata tests +├── computeservice.cpp # ComputeServiceSession::Impl (~1700 lines) +├── httpcomputeservice.cpp # HTTP route registration and handlers (~900 lines) +├── httporchestrator.cpp # Orchestrator HTTP API + WebSocket push +├── orchestratorservice.cpp # Worker registry, health probing +└── cloudmetadata.cpp # IMDS probing, termination monitoring +``` + +## Key Classes + +### `ComputeServiceSession` (computeservice.h) +Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). 
Owns: +- Two `RunnerGroup`s: `m_LocalRunnerGroup`, `m_RemoteRunnerGroup` +- Scheduler thread that drains `m_UpdatedActions` and drives state transitions +- Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` +- Queue map: `m_Queues` (QueueEntry objects) +- Action history ring: `m_ActionHistory` (bounded deque, default 1000) + +**Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be entered directly from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are cancelled on a best-effort basis. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. + +### `RunnerAction` (runners/functionrunner.h) +Shared ref-counted struct representing one action through its lifecycle. + +**Key fields:** +- `ActionLsn` — globally unique sequence number +- `QueueId` — 0 for implicit/unqueued actions +- `Worker` — descriptor + content hash +- `ActionObj` — CbObject with the action spec +- `CpuUsagePercent` / `CpuSeconds` — atomics updated by monitor thread +- `RetryCount` — atomic int tracking how many times the action has been rescheduled +- `Timestamps[State::_Count]` — timestamp of each state transition + +**State machine (forward-only under normal flow, atomic):** +``` +New → Pending → Submitting → Running → Completed + → Failed + → Abandoned + → Cancelled +``` +`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +### `LocalProcessRunner` (runners/localrunner.h) +Base for all local execution. 
Platform runners subclass this and override: +- `SubmitAction()` — fork/exec the worker process +- `SweepRunningActions()` — poll for process exit (waitpid / WaitForSingleObject) +- `CancelRunningActions()` — signal all processes during shutdown +- `SampleProcessCpu(RunningAction&)` — read platform CPU usage (no-op default) + +**Infrastructure owned by LocalProcessRunner:** +- Monitor thread — calls `SweepRunningActions()` then `SampleRunningProcessCpu()` in a loop +- `m_RunningMap` — `RwLock`-guarded map of `Lsn → RunningAction` +- `DeferredDirectoryDeleter` — sandbox directories are queued for async deletion +- `PrepareActionSubmission()` — shared preamble (capacity check, sandbox creation, worker manifesting, input decompression) +- `ProcessCompletedActions()` — shared post-processing (gather outputs, set state, enqueue deletion) + +**CPU sampling:** `SampleRunningProcessCpu()` iterates `m_RunningMap` under shared lock, calls `SampleProcessCpu()` per entry, throttled to every 5 seconds per action. Platform implementations: +- Linux: `/proc/{pid}/stat` utime+stime jiffies ÷ `_SC_CLK_TCK` +- Windows: `GetProcessTimes()` in 100ns intervals ÷ 10,000,000 +- macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 + +### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) +Abstract base for runners. `RunnerGroup` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. + +### `HttpComputeService` (include/zencompute/httpcomputeservice.h) +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. + +## Action Lifecycle (End to End) + +1. **HTTP POST** → `HttpComputeService` ingests attachments, calls `EnqueueAction()` +2. 
**Enqueue** → creates `RunnerAction` (New → Pending), calls `PostUpdate()` +3. **PostUpdate** → appends to `m_UpdatedActions`, signals scheduler thread event +4. **Scheduler thread** → drains `m_UpdatedActions`, drives pending actions to runners +5. **Runner `SubmitAction()`** → Pending → Submitting (on runner's worker pool thread) +6. **Process launch** → Submitting → Running, added to `m_RunningMap` +7. **Monitor thread `SweepRunningActions()`** → detects exit, gathers outputs +8. **`ProcessCompletedActions()`** → Running → Completed/Failed/Abandoned, `PostUpdate()` +9. **Scheduler thread `HandleActionUpdates()`** — for Failed/Abandoned actions, checks retry limit; if retries remain, calls `ResetActionStateToPending()` which loops back to step 3. Otherwise moves to `m_ResultsMap`, records history, notifies queue. +10. **Client `GET /jobs/{lsn}`** → returns result from `m_ResultsMap`, schedules retirement + +### Action Rescheduling + +Actions that fail or are abandoned can be automatically retried or manually rescheduled via the API. + +**Automatic retry (scheduler path):** In `HandleActionUpdates()`, when a Failed or Abandoned state is detected, the scheduler checks `RetryCount < GetMaxRetriesForQueue(QueueId)`. If retries remain, the action is removed from active maps and `ResetActionStateToPending()` is called, which re-enters it into the scheduler pipeline. The action keeps its original LSN so clients can continue polling with the same identifier. + +**Manual retry (API path):** `POST /compute/jobs/{lsn}` calls `RescheduleAction()`, which finds the action in `m_ResultsMap`, validates state (must be Failed or Abandoned), checks the retry limit, reverses queue counters (moving the LSN from `FinishedLsns` back to `ActiveLsns`), removes from results, and calls `ResetActionStateToPending()`. Returns 200 with `{lsn, retry_count}` on success, 409 Conflict with `{error}` on failure. 
+ +**Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. + +**Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. + +## Queue System + +Queues group actions from a single client session. A `QueueEntry` (internal) tracks: +- `State` — `std::atomic<QueueState>` lifecycle state (Active → Draining → Cancelled) +- `ActiveCount` — pending + running actions (atomic) +- `CompletedCount / FailedCount / AbandonedCount / CancelledCount` (atomics) +- `ActiveLsns` — for cancellation lookup (under `m_Lock`) +- `FinishedLsns` — LSNs are moved here when their actions finish +- `IdleSince` — used for 15-minute automatic expiry +- `Config` — CbObject set at creation; supports `max_retries` (int) to override the default retry limit + +**Queue state machine (`QueueState` enum):** +``` +Active → Draining → Cancelled + \ ↑ + ─────────────────────/ +``` +- **Active** — accepts new work, schedules pending work, finishes running work (initial state) +- **Draining** — rejects new work, finishes existing work (one-way via CAS from Active; cannot override Cancelled) +- **Cancelled** — rejects new work, actively cancels in-flight work (reachable from Active or Draining) + +Key operations: +- `CreateQueue(Tag)` → returns `QueueId` +- `EnqueueActionToQueue(QueueId, ...)` → action's `QueueId` field is set at creation +- `CancelQueue(QueueId)` → marks all active LSNs for cancellation +- `DrainQueue(QueueId)` → stops accepting new submissions; existing work finishes naturally (irreversible) +- `GetQueueCompleted(QueueId)` → CbWriter output of finished results +- Queue references in HTTP routes accept either a decimal ID or an Oid token (24-hex), resolved by `ResolveQueueRef()` + +## HTTP API + +All routes registered in `HttpComputeService` constructor. Prefix is configured externally (typically `/compute`). 
+ +### Global endpoints +| Method | Path | Description | +|--------|------|-------------| +| POST | `abandon` | Transition session to Abandoned state (409 if invalid) | +| GET | `jobs/history` | Action history (last N, with timestamps per state) | +| GET | `jobs/running` | In-flight actions with CPU metrics | +| GET | `jobs/completed` | Actions with results available | +| GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{worker}` | Submit action for specific worker | +| POST | `jobs` | Submit action (worker resolved from descriptor) | +| GET | `workers` | List worker IDs | +| GET | `workers/all` | All workers with full descriptors | +| GET/POST | `workers/{worker}` | Get/register worker | + +### Queue-scoped endpoints +Queue ref is capture(1) in all `queues/{queueref}/...` routes. + +| Method | Path | Description | +|--------|------|-------------| +| GET | `queues` | List queue IDs | +| POST | `queues` | Create queue | +| GET/DELETE | `queues/{queueref}` | Status / delete | +| POST | `queues/{queueref}/drain` | Drain queue (irreversible; rejects new submissions) | +| GET | `queues/{queueref}/completed` | Queue's completed results | +| GET | `queues/{queueref}/history` | Queue's action history | +| GET | `queues/{queueref}/running` | Queue's running actions | +| POST | `queues/{queueref}/jobs` | Submit to queue | +| GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | + +Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. + +## Concurrency Model + +**Locking discipline:** When multiple locks must be held simultaneously, always acquire in this order to prevent deadlocks: +1. `m_ResultsLock` +2. `m_RunningLock` (comment in localrunner.h: "must be taken *after* m_ResultsLock") +3. 
`m_PendingLock` +4. `m_QueueLock` + +**Atomic fields** for counters and simple state: queue counts, `CpuUsagePercent`, `CpuSeconds`, `RetryCount`, `RunnerAction::m_ActionState`. + +**Update decoupling:** Runners call `PostUpdate(RunnerAction*)` rather than directly mutating service state. The scheduler thread batches and deduplicates updates. + +**Thread ownership:** +- Scheduler thread — drives state transitions, owns `m_PendingActions` +- Monitor thread (per runner) — polls process completion, owns `m_RunningMap` via shared lock +- Worker pool threads — async submission, brief `SubmitAction()` calls +- HTTP threads — read-only access to results, queue status + +## Sandbox Layout + +Each action gets a unique numbered directory under `m_SandboxPath`: +``` +scratch/{counter}/ + worker/ ← worker binaries (or bind-mounted on Linux) + inputs/ ← decompressed action inputs + outputs/ ← written by worker process +``` + +On Linux with sandboxing enabled, the process runs in a pivot-rooted namespace with `/usr`, `/lib`, `/etc`, `/worker` bind-mounted read-only and a tmpfs `/dev`. + +## Adding a New HTTP Endpoint + +1. Register the route in the `HttpComputeService` constructor in `httpcomputeservice.cpp` +2. If the handler is shared between top-level and a `queues/{queueref}/...` variant, extract it as a private helper method declared in `httpcomputeservice.h` +3. Queue-scoped routes validate the queue ref with `ResolveQueueRef(HttpReq, Req.GetCapture(1))` which writes an error response and returns 0 on failure +4. Use `CbObjectWriter` for response bodies; emit via `HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save())` +5. Conditional fields (e.g., optional CPU metrics): emit inside `if (value > 0.0f)` / `if (value >= 0.0f)` guards to omit absent values rather than emitting sentinel values + +## Adding a New Runner Platform + +1. Subclass `LocalProcessRunner`, add `h`/`cpp` files in `runners/` +2. 
Override `SubmitAction()`, `SweepRunningActions()`, `CancelRunningActions()`, and optionally `CancelAction(int)` and `SampleProcessCpu(RunningAction&)` +3. `SampleProcessCpu()` must update both `Running.Action->CpuSeconds` (unconditionally from the absolute OS value) and `Running.Action->CpuUsagePercent` (delta-based, only after second sample) +4. `ProcessHandle` convention: store pid as `reinterpret_cast(static_cast(pid))` for consistency with the base class +5. Register in `ComputeServiceSession::AddLocalRunner()` in `computeservice.cpp` diff --git a/src/zencompute/actionrecorder.cpp b/src/zencompute/actionrecorder.cpp deleted file mode 100644 index 04c4b5141..000000000 --- a/src/zencompute/actionrecorder.cpp +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "actionrecorder.h" - -#include "functionrunner.h" - -#include -#include -#include -#include -#include -#include - -#if ZEN_PLATFORM_WINDOWS -# include -# define ZEN_CONCRT_AVAILABLE 1 -#else -# define ZEN_CONCRT_AVAILABLE 0 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -namespace zen::compute { - -using namespace std::literals; - -////////////////////////////////////////////////////////////////////////// - -RecordingFileWriter::RecordingFileWriter() -{ -} - -RecordingFileWriter::~RecordingFileWriter() -{ - Close(); -} - -void -RecordingFileWriter::Open(std::filesystem::path FilePath) -{ - using namespace std::literals; - - m_File.Open(FilePath, BasicFile::Mode::kTruncate); - m_File.Write("----DDC2----DATA", 16, 0); - m_FileOffset = 16; - - std::filesystem::path TocPath = FilePath.replace_extension(".ztoc"); - m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate); - - m_TocWriter << "version"sv << 1; - m_TocWriter.BeginArray("toc"sv); -} - -void -RecordingFileWriter::Close() -{ - m_TocWriter.EndArray(); - CbObject Toc = m_TocWriter.Save(); - - std::error_code Ec; - m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec); -} - -void -RecordingFileWriter::AppendObject(const 
CbObject& Object, const IoHash& ObjectHash) -{ - RwLock::ExclusiveLockScope _(m_FileLock); - - MemoryView ObjectView = Object.GetBuffer().GetView(); - - std::error_code Ec; - m_File.Write(ObjectView, m_FileOffset, Ec); - - if (Ec) - { - throw std::system_error(Ec, "failed writing to archive"); - } - - m_TocWriter.BeginArray(); - m_TocWriter.AddHash(ObjectHash); - m_TocWriter.AddInteger(m_FileOffset); - m_TocWriter.AddInteger(gsl::narrow(ObjectView.GetSize())); - m_TocWriter.EndArray(); - - m_FileOffset += ObjectView.GetSize(); -} - -////////////////////////////////////////////////////////////////////////// - -ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath) -: m_ChunkResolver(InChunkResolver) -, m_RecordingLogDir(RecordingLogPath) -{ - std::error_code Ec; - CreateDirectories(m_RecordingLogDir, Ec); - - if (Ec) - { - ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message()); - } - - CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec); - - if (Ec) - { - ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message()); - } - - m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat"); - m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat"); - - CidStoreConfiguration CidConfig; - CidConfig.RootDirectory = m_RecordingLogDir / "cid"; - CidConfig.HugeValueThreshold = 128 * 1024 * 1024; - - m_CidStore.Initialize(CidConfig); -} - -ActionRecorder::~ActionRecorder() -{ - Shutdown(); -} - -void -ActionRecorder::Shutdown() -{ - m_CidStore.Flush(); -} - -void -ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage) -{ - const IoHash WorkerId = WorkerPackage.GetObjectHash(); - - m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId); - - std::unordered_set AddedChunks; - uint64_t AddedBytes = 0; - - // First add all attachments from the worker package itself - - for (const CbAttachment& Attachment : WorkerPackage.GetAttachments()) - { - 
CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - IoBuffer Data = Buffer.GetCompressed().Flatten().AsIoBuffer(); - - const IoHash ChunkHash = Buffer.DecodeRawHash(); - - CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly); - - AddedChunks.insert(ChunkHash); - - if (Result.New) - { - AddedBytes += Data.GetSize(); - } - } - - // Not all attachments will be present in the worker package, so we need to add - // all referenced chunks to ensure that the recording is self-contained and not - // referencing data in the main CID store - - CbObject WorkerDescriptor = WorkerPackage.GetObject(); - - WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) { - const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - - if (!AddedChunks.contains(AttachmentCid)) - { - IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid); - - if (AttachmentData) - { - CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly); - - if (Result.New) - { - AddedBytes += AttachmentData.GetSize(); - } - } - else - { - ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId); - } - - AddedChunks.insert(AttachmentCid); - } - }); - - ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes); -} - -bool -ActionRecorder::RecordAction(Ref Action) -{ - bool AllGood = true; - - Action->ActionObj.IterateAttachments([&](CbFieldView Field) { - IoHash AttachData = Field.AsHash(); - IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData); - - if (ChunkData) - { - if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary) - { - IoHash DecompressedHash; - uint64_t RawSize = 0; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize); - - OodleCompressor Compressor; - 
OodleCompressionLevel CompressionLevel; - uint64_t BlockSize = 0; - if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize)) - { - if (Compressor == OodleCompressor::NotSet) - { - CompositeBuffer Decompressed = Compressed.DecompressToComposite(); - CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed), - OodleCompressor::Mermaid, - OodleCompressionLevel::Fast, - BlockSize); - - ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer(); - } - } - } - - const uint64_t ChunkSize = ChunkData.GetSize(); - - m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly); - ++m_ChunkCounter; - m_ChunkBytesCounter.fetch_add(ChunkSize); - } - else - { - AllGood = false; - - ZEN_WARN("could not resolve chunk {}", AttachData); - } - }); - - if (AllGood) - { - m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId); - ++m_ActionsCounter; - - return true; - } - else - { - return false; - } -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/actionrecorder.h b/src/zencompute/actionrecorder.h deleted file mode 100644 index 9cc2b44a2..000000000 --- a/src/zencompute/actionrecorder.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace zen { -class CbObject; -class CbPackage; -struct IoHash; -} // namespace zen - -#if ZEN_WITH_COMPUTE_SERVICES - -namespace zen::compute { - -////////////////////////////////////////////////////////////////////////// - -struct RecordingFileWriter -{ - RecordingFileWriter(RecordingFileWriter&&) = delete; - RecordingFileWriter& operator=(RecordingFileWriter&&) = delete; - - RwLock m_FileLock; - BasicFile m_File; - uint64_t m_FileOffset = 0; - CbObjectWriter m_TocWriter; - BasicFile m_TocFile; - - RecordingFileWriter(); - ~RecordingFileWriter(); - - void Open(std::filesystem::path FilePath); - void Close(); - void AppendObject(const CbObject& Object, const IoHash& ObjectHash); -}; - -////////////////////////////////////////////////////////////////////////// - -/** - * Recording "runner" implementation - * - * This class writes out all actions and their attachments to a recording directory - * in a format that can be read back by the RecordingReader. - * - * The contents of the recording directory will be self-contained, with all referenced - * attachments stored in the recording directory itself, so that the recording can be - * moved or shared without needing to maintain references to the main CID store. 
- * - */ - -class ActionRecorder -{ -public: - ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath); - ~ActionRecorder(); - - ActionRecorder(const ActionRecorder&) = delete; - ActionRecorder& operator=(const ActionRecorder&) = delete; - - void Shutdown(); - void RegisterWorker(const CbPackage& WorkerPackage); - bool RecordAction(Ref Action); - -private: - ChunkResolver& m_ChunkResolver; - std::filesystem::path m_RecordingLogDir; - - RecordingFileWriter m_WorkersFile; - RecordingFileWriter m_ActionsFile; - GcManager m_Gc; - CidStore m_CidStore{m_Gc}; - std::atomic m_ChunkCounter{0}; - std::atomic m_ChunkBytesCounter{0}; - std::atomic m_ActionsCounter{0}; -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp new file mode 100644 index 000000000..b3b3210d9 --- /dev/null +++ b/src/zencompute/cloudmetadata.cpp @@ -0,0 +1,1010 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +// All major cloud providers expose instance metadata at this link-local address. +// It is only routable from within a cloud VM; on bare-metal the TCP connect will +// fail, which is how we distinguish cloud from non-cloud environments. +static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; + +// Short connect timeout so that detection on non-cloud machines is fast. The IMDS +// is a local service on the hypervisor so 200ms is generous for actual cloud VMs. 
+static constexpr auto kImdsTimeout = std::chrono::milliseconds{200};
+
+// Returns the display name of a cloud provider; any value other than
+// AWS/Azure/GCP maps to "None".
+std::string_view
+ToString(CloudProvider Provider)
+{
+    switch (Provider)
+    {
+        case CloudProvider::AWS:
+            return "AWS";
+        case CloudProvider::Azure:
+            return "Azure";
+        case CloudProvider::GCP:
+            return "GCP";
+        default:
+            return "None";
+    }
+}
+
+// Convenience constructor: probes the default link-local IMDS endpoint.
+CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint))
+{
+}
+
+// Creates the data directory (used for the negative-detection sentinel files),
+// runs provider detection synchronously, and starts the termination monitor
+// thread only when a provider was detected.
+CloudMetadata::CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint)
+: m_Log(logging::Get("cloud"))
+, m_DataDir(std::move(DataDir))
+, m_ImdsEndpoint(std::move(ImdsEndpoint))
+{
+    ZEN_TRACE_CPU("CloudMetadata::CloudMetadata");
+
+    // Best effort: a failure here is not reported; WriteSentinelFile() logs a
+    // warning later if the directory turns out to be unusable.
+    std::error_code Ec;
+    std::filesystem::create_directories(m_DataDir, Ec);
+
+    DetectProvider();
+
+    if (m_Info.Provider != CloudProvider::None)
+    {
+        StartTerminationMonitor();
+    }
+}
+
+// Stops the termination monitor: clears the run flag, wakes the thread out of
+// its interruptible sleep, then joins it if it was ever started.
+CloudMetadata::~CloudMetadata()
+{
+    ZEN_TRACE_CPU("CloudMetadata::~CloudMetadata");
+    m_MonitorEnabled = false;
+    m_MonitorEvent.Set();
+    if (m_MonitorThread.joinable())
+    {
+        m_MonitorThread.join();
+    }
+}
+
+// Returns the detected provider, read under the shared info lock.
+CloudProvider
+CloudMetadata::GetProvider() const
+{
+    return m_InfoLock.WithSharedLock([&] { return m_Info.Provider; });
+}
+
+// Returns a copy of the detected instance information, read under the shared info lock.
+CloudInstanceInfo
+CloudMetadata::GetInstanceInfo() const
+{
+    return m_InfoLock.WithSharedLock([&] { return m_Info; });
+}
+
+// Returns whether a termination/maintenance signal has been observed.
+bool
+CloudMetadata::IsTerminationPending() const
+{
+    return m_TerminationPending.load(std::memory_order_relaxed);
+}
+
+// Returns a copy of the recorded termination reason, read under the shared reason lock.
+std::string
+CloudMetadata::GetTerminationReason() const
+{
+    return m_ReasonLock.WithSharedLock([&] { return m_TerminationReason; });
+}
+
+// Writes a "cloud" object describing the instance to Writer. Nothing is
+// written when no provider was detected.
+void
+CloudMetadata::Describe(CbWriter& Writer) const
+{
+    ZEN_TRACE_CPU("CloudMetadata::Describe");
+    CloudInstanceInfo Info = GetInstanceInfo();
+
+    if (Info.Provider == CloudProvider::None)
+    {
+        return;
+    }
+
+    // Read the flag once so that the emitted "termination_pending" value and
+    // the presence of "termination_reason" always agree, even if a poll flips
+    // the flag while the object is being written.
+    const bool TerminationPending = IsTerminationPending();
+
+    Writer.BeginObject("cloud");
+    Writer << "provider" << ToString(Info.Provider);
+    Writer << "instance_id" << Info.InstanceId;
+    Writer << "availability_zone" << Info.AvailabilityZone;
+    Writer << "is_spot" << Info.IsSpot;
+    Writer << "is_autoscaling" << Info.IsAutoscaling;
+    Writer << "termination_pending" << TerminationPending;
+
+    if (TerminationPending)
+    {
+        Writer << "termination_reason" << GetTerminationReason();
+    }
+
+    Writer.EndObject();
+}
+
+// Probes the providers in a fixed order (AWS, Azure, GCP) and stops at the
+// first one that identifies itself.
+void
+CloudMetadata::DetectProvider()
+{
+    ZEN_TRACE_CPU("CloudMetadata::DetectProvider");
+
+    if (TryDetectAWS())
+    {
+        return;
+    }
+
+    if (TryDetectAzure())
+    {
+        return;
+    }
+
+    if (TryDetectGCP())
+    {
+        return;
+    }
+
+    ZEN_DEBUG("no cloud provider detected");
+}
+
+// AWS detection uses IMDSv2 which requires a session token obtained via PUT before
+// any GET requests are allowed. This is more secure than IMDSv1 (which allowed
+// unauthenticated GETs) and is the default on modern EC2 instances. The token has
+// a 300-second TTL and is reused for termination polling.
+//
+// Returns true and fills m_Info when the token request succeeds. A failed
+// token request or any exception writes the ".isNotAWS" sentinel so that later
+// startups skip the probe.
+bool
+CloudMetadata::TryDetectAWS()
+{
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectAWS");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotAWS";
+
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping AWS detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing AWS IMDS");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        // IMDSv2: acquire session token. The TTL header is mandatory; we request
+        // 300s which is sufficient for the detection phase. The token is also
+        // stored in m_AwsToken for reuse by the termination polling thread.
+        HttpClient::KeyValueMap TokenHeaders(std::pair{"X-aws-ec2-metadata-token-ttl-seconds", "300"});
+        HttpClient::Response    TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders);
+
+        if (!TokenResponse.IsSuccess())
+        {
+            ZEN_DEBUG("AWS IMDS token request failed ({}), not on AWS", static_cast<int>(TokenResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        m_AwsToken = std::string(TokenResponse.AsText());
+
+        HttpClient::KeyValueMap AuthHeaders(std::pair{"X-aws-ec2-metadata-token", m_AwsToken});
+
+        // The individual metadata fields are optional: a failed GET leaves the
+        // corresponding m_Info field at its previous value.
+        HttpClient::Response IdResponse = ImdsClient.Get("/latest/meta-data/instance-id", AuthHeaders);
+        if (IdResponse.IsSuccess())
+        {
+            m_Info.InstanceId = std::string(IdResponse.AsText());
+        }
+
+        HttpClient::Response AzResponse = ImdsClient.Get("/latest/meta-data/placement/availability-zone", AuthHeaders);
+        if (AzResponse.IsSuccess())
+        {
+            m_Info.AvailabilityZone = std::string(AzResponse.AsText());
+        }
+
+        // "spot" vs "on-demand" — determines whether the instance can be
+        // reclaimed by AWS with a 2-minute warning
+        HttpClient::Response LifecycleResponse = ImdsClient.Get("/latest/meta-data/instance-life-cycle", AuthHeaders);
+        if (LifecycleResponse.IsSuccess())
+        {
+            m_Info.IsSpot = (LifecycleResponse.AsText() == "spot");
+        }
+
+        // This endpoint only exists on instances managed by an Auto Scaling
+        // Group. A successful response (regardless of value) means autoscaling.
+        HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders);
+        if (AutoscaleResponse.IsSuccess())
+        {
+            m_Info.IsAutoscaling = true;
+        }
+
+        m_Info.Provider = CloudProvider::AWS;
+
+        ZEN_INFO("detected AWS instance: id={}, az={}, spot={}, autoscaling={}",
+                 m_Info.InstanceId,
+                 m_Info.AvailabilityZone,
+                 m_Info.IsSpot,
+                 m_Info.IsAutoscaling);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_DEBUG("AWS IMDS probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// Azure IMDS returns a single JSON document for the entire instance metadata,
+// unlike AWS and GCP which use separate plain-text endpoints per field. The
+// "Metadata: true" header is required; requests without it are rejected.
+// The api-version parameter is mandatory and pins the response schema.
+//
+// Returns true and fills m_Info when the instance document is fetched and
+// parses as JSON. A failed request, invalid JSON, or any exception writes the
+// ".isNotAzure" sentinel so that later startups skip the probe.
+bool
+CloudMetadata::TryDetectAzure()
+{
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectAzure");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotAzure";
+
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping Azure detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing Azure IMDS");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        HttpClient::KeyValueMap MetadataHeaders({
+            std::pair{"Metadata", "true"},
+        });
+
+        HttpClient::Response InstanceResponse = ImdsClient.Get("/metadata/instance?api-version=2021-02-01", MetadataHeaders);
+
+        if (!InstanceResponse.IsSuccess())
+        {
+            ZEN_DEBUG("Azure IMDS request failed ({}), not on Azure", static_cast<int>(InstanceResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        std::string        JsonError;
+        const json11::Json Json = json11::Json::parse(std::string(InstanceResponse.AsText()), JsonError);
+
+        if (!JsonError.empty())
+        {
+            ZEN_DEBUG("Azure IMDS returned invalid JSON: {}", JsonError);
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        const json11::Json& Compute = Json["compute"];
+
+        m_Info.InstanceId       = Compute["vmId"].string_value();
+        m_Info.AvailabilityZone = Compute["location"].string_value();
+
+        // Azure spot VMs have priority "Spot"; regular VMs have "Regular"
+        std::string Priority = Compute["priority"].string_value();
+        m_Info.IsSpot        = (Priority == "Spot");
+
+        // Check if part of a VMSS (Virtual Machine Scale Set) — indicates autoscaling
+        std::string VmssName = Compute["vmScaleSetName"].string_value();
+        m_Info.IsAutoscaling = !VmssName.empty();
+
+        m_Info.Provider = CloudProvider::Azure;
+
+        ZEN_INFO("detected Azure instance: id={}, location={}, spot={}, vmss={}",
+                 m_Info.InstanceId,
+                 m_Info.AvailabilityZone,
+                 m_Info.IsSpot,
+                 m_Info.IsAutoscaling);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_DEBUG("Azure IMDS probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// GCP requires the "Metadata-Flavor: Google" header on all IMDS requests.
+// Unlike AWS, there is no session token; the header itself is the auth mechanism
+// (it prevents SSRF attacks since browsers won't send custom headers to the
+// metadata endpoint). Each metadata field is fetched from a separate URL.
+bool
+CloudMetadata::TryDetectGCP()
+{
+    // Returns true and fills m_Info when the instance-id request succeeds. A
+    // failed instance-id request or any exception writes the ".isNotGCP"
+    // sentinel so that later startups skip the probe.
+    ZEN_TRACE_CPU("CloudMetadata::TryDetectGCP");
+
+    std::filesystem::path SentinelPath = m_DataDir / ".isNotGCP";
+
+    if (HasSentinelFile(SentinelPath))
+    {
+        ZEN_DEBUG("skipping GCP detection - negative cache hit");
+        return false;
+    }
+
+    ZEN_DEBUG("probing GCP metadata service");
+
+    try
+    {
+        HttpClient ImdsClient(m_ImdsEndpoint,
+                              {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{1000}});
+
+        HttpClient::KeyValueMap MetadataHeaders(std::pair{"Metadata-Flavor", "Google"});
+
+        // Fetch instance ID
+        HttpClient::Response IdResponse = ImdsClient.Get("/computeMetadata/v1/instance/id", MetadataHeaders);
+
+        if (!IdResponse.IsSuccess())
+        {
+            ZEN_DEBUG("GCP metadata request failed ({}), not on GCP", static_cast<int>(IdResponse.StatusCode));
+            WriteSentinelFile(SentinelPath);
+            return false;
+        }
+
+        m_Info.InstanceId = std::string(IdResponse.AsText());
+
+        // GCP returns the fully-qualified zone path "projects/<project-number>/zones/<zone>".
+        // Strip the prefix to get just the zone name (e.g. "us-central1-a").
+        HttpClient::Response ZoneResponse = ImdsClient.Get("/computeMetadata/v1/instance/zone", MetadataHeaders);
+        if (ZoneResponse.IsSuccess())
+        {
+            std::string_view Zone = ZoneResponse.AsText();
+            if (auto Pos = Zone.rfind('/'); Pos != std::string_view::npos)
+            {
+                Zone = Zone.substr(Pos + 1);
+            }
+            m_Info.AvailabilityZone = std::string(Zone);
+        }
+
+        // Check for preemptible/spot (scheduling/preemptible returns "TRUE" or "FALSE")
+        HttpClient::Response PreemptibleResponse = ImdsClient.Get("/computeMetadata/v1/instance/scheduling/preemptible", MetadataHeaders);
+        if (PreemptibleResponse.IsSuccess())
+        {
+            m_Info.IsSpot = (PreemptibleResponse.AsText() == "TRUE");
+        }
+
+        // Check for maintenance event. A maintenance event that is already
+        // pending at detection time is recorded immediately rather than
+        // waiting for the first monitor poll.
+        HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
+        if (MaintenanceResponse.IsSuccess())
+        {
+            std::string_view Event = MaintenanceResponse.AsText();
+            if (!Event.empty() && Event != "NONE")
+            {
+                m_TerminationPending = true;
+                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
+            }
+        }
+
+        m_Info.Provider = CloudProvider::GCP;
+
+        ZEN_INFO("detected GCP instance: id={}, az={}, spot={}", m_Info.InstanceId, m_Info.AvailabilityZone, m_Info.IsSpot);
+
+        return true;
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_DEBUG("GCP metadata probe failed: {}", Ex.what());
+        WriteSentinelFile(SentinelPath);
+        return false;
+    }
+}
+
+// Sentinel files are empty marker files whose mere existence signals that a
+// previous detection attempt for a given provider failed. This avoids paying
+// the connect-timeout cost on every startup for providers that are known to
+// be absent. The files persist across process restarts; delete them manually
+// (or remove the DataDir) to force re-detection.
+void
+CloudMetadata::WriteSentinelFile(const std::filesystem::path& Path)
+{
+    // Creates (or truncates) the marker file. Failures are logged and
+    // swallowed: the sentinel is only an optimisation for later startups.
+    try
+    {
+        BasicFile File;
+        File.Open(Path, BasicFile::Mode::kTruncate);
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_WARN("failed to write sentinel file '{}': {}", Path.string(), Ex.what());
+    }
+}
+
+// Returns true when the marker file for a provider exists.
+bool
+CloudMetadata::HasSentinelFile(const std::filesystem::path& Path) const
+{
+    return zen::IsFile(Path);
+}
+
+// Removes all three provider sentinel files so that the next detection probes
+// every provider again. Removal errors (including "file not found") are ignored.
+void
+CloudMetadata::ClearSentinelFiles()
+{
+    std::error_code Ec;
+    std::filesystem::remove(m_DataDir / ".isNotAWS", Ec);
+    std::filesystem::remove(m_DataDir / ".isNotAzure", Ec);
+    std::filesystem::remove(m_DataDir / ".isNotGCP", Ec);
+}
+
+// Launches the background thread that polls the provider for termination signals.
+void
+CloudMetadata::StartTerminationMonitor()
+{
+    ZEN_INFO("starting cloud termination monitor for {} instance {}", ToString(m_Info.Provider), m_Info.InstanceId);
+
+    m_MonitorThread = std::thread{&CloudMetadata::TerminationMonitorThread, this};
+}
+
+// Thread body of the termination monitor.
+void
+CloudMetadata::TerminationMonitorThread()
+{
+    SetCurrentThreadName("cloud_term_mon");
+
+    // Poll every 5 seconds. The Event is used as an interruptible sleep so
+    // that the destructor can wake us up immediately for a clean shutdown.
+    while (m_MonitorEnabled)
+    {
+        m_MonitorEvent.Wait(5000);
+        m_MonitorEvent.Reset();
+
+        // Re-check after waking: the wake-up may have been the shutdown signal.
+        if (!m_MonitorEnabled)
+        {
+            return;
+        }
+
+        PollTermination();
+    }
+}
+
+// Runs one provider-specific termination poll. Exceptions from the poll are
+// logged and swallowed so that they never escape the monitor thread.
+void
+CloudMetadata::PollTermination()
+{
+    try
+    {
+        CloudProvider Provider = m_InfoLock.WithSharedLock([&] { return m_Info.Provider; });
+
+        if (Provider == CloudProvider::AWS)
+        {
+            PollAWSTermination();
+        }
+        else if (Provider == CloudProvider::Azure)
+        {
+            PollAzureTermination();
+        }
+        else if (Provider == CloudProvider::GCP)
+        {
+            PollGCPTermination();
+        }
+    }
+    catch (const std::exception& Ex)
+    {
+        ZEN_DEBUG("termination poll error: {}", Ex.what());
+    }
+}
+
+// AWS termination signals:
+// - /spot/instance-action: returns 200 with a JSON body ~2 minutes before
+//   a spot instance is reclaimed. Returns 404 when no action is pending.
+// - /autoscaling/target-lifecycle-state: returns the ASG lifecycle state.
+//   "InService" is normal; anything else (e.g. "Terminated:Wait") means
+//   the instance is being cycled out.
+void
+CloudMetadata::PollAWSTermination()
+{
+    ZEN_TRACE_CPU("CloudMetadata::PollAWSTermination");
+
+    HttpClient ImdsClient(m_ImdsEndpoint,
+                          {.LogCategory = "cloud-aws", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+    // The token acquired during detection was requested with a 300-second TTL,
+    // but polling continues for the lifetime of the process. Once that token
+    // expires IMDSv2 rejects the GETs below, which would silently disable
+    // termination detection. Request a fresh token for every poll and fall
+    // back to the detection-time token only if the refresh fails. The token is
+    // kept in a local so that m_AwsToken is only ever read here.
+    std::string SessionToken = m_AwsToken;
+    {
+        HttpClient::KeyValueMap TokenHeaders(std::pair{"X-aws-ec2-metadata-token-ttl-seconds", "300"});
+        HttpClient::Response    TokenResponse = ImdsClient.Put("/latest/api/token", IoBuffer{}, TokenHeaders);
+        if (TokenResponse.IsSuccess())
+        {
+            SessionToken = std::string(TokenResponse.AsText());
+        }
+    }
+
+    HttpClient::KeyValueMap AuthHeaders(std::pair{"X-aws-ec2-metadata-token", SessionToken});
+
+    HttpClient::Response SpotResponse = ImdsClient.Get("/latest/meta-data/spot/instance-action", AuthHeaders);
+    if (SpotResponse.IsSuccess())
+    {
+        // exchange() makes sure the reason is recorded and logged only by the
+        // first poll that observes a termination signal.
+        if (!m_TerminationPending.exchange(true))
+        {
+            m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS spot interruption: {}", SpotResponse.AsText()); });
+            ZEN_WARN("AWS spot interruption detected: {}", SpotResponse.AsText());
+        }
+        return;
+    }
+
+    HttpClient::Response AutoscaleResponse = ImdsClient.Get("/latest/meta-data/autoscaling/target-lifecycle-state", AuthHeaders);
+    if (AutoscaleResponse.IsSuccess())
+    {
+        std::string_view State = AutoscaleResponse.AsText();
+        if (State.find("InService") == std::string_view::npos)
+        {
+            if (!m_TerminationPending.exchange(true))
+            {
+                m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("AWS autoscaling lifecycle: {}", State); });
+                ZEN_WARN("AWS autoscaling termination detected: {}", State);
+            }
+        }
+    }
+}
+
+// Azure Scheduled Events API returns a JSON array of upcoming platform events.
+// We care about "Preempt" (spot eviction), "Terminate", and "Reboot" events.
+// Other event types like "Freeze" (live migration) are non-destructive and
+// ignored. The Events array is empty when nothing is pending.
+void
+CloudMetadata::PollAzureTermination()
+{
+    // Fetches the scheduled-events document and records the first
+    // Preempt/Terminate/Reboot entry as a pending termination. A failed
+    // request or an unparsable body is treated as "nothing pending".
+    ZEN_TRACE_CPU("CloudMetadata::PollAzureTermination");
+
+    HttpClient ImdsClient(m_ImdsEndpoint,
+                          {.LogCategory = "cloud-azure", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+    HttpClient::KeyValueMap MetadataHeaders({
+        std::pair{"Metadata", "true"},
+    });
+
+    HttpClient::Response ScheduledResponse = ImdsClient.Get("/metadata/scheduledevents?api-version=2020-07-01", MetadataHeaders);
+
+    if (!ScheduledResponse.IsSuccess())
+    {
+        return;
+    }
+
+    std::string        ParseError;
+    const json11::Json Document = json11::Json::parse(std::string(ScheduledResponse.AsText()), ParseError);
+
+    if (!ParseError.empty())
+    {
+        return;
+    }
+
+    const json11::Json::array& Pending = Document["Events"].array_items();
+    for (const auto& Entry : Pending)
+    {
+        std::string Kind = Entry["EventType"].string_value();
+
+        // Skip everything that does not take the VM away from us.
+        const bool IsDestructive = (Kind == "Preempt") || (Kind == "Terminate") || (Kind == "Reboot");
+        if (!IsDestructive)
+        {
+            continue;
+        }
+
+        // Only the first observer of a termination records and logs the reason.
+        const bool AlreadyPending = m_TerminationPending.exchange(true);
+        if (!AlreadyPending)
+        {
+            std::string Status = Entry["EventStatus"].string_value();
+            m_ReasonLock.WithExclusiveLock(
+                [&] { m_TerminationReason = fmt::format("Azure scheduled event: {} ({})", Kind, Status); });
+            ZEN_WARN("Azure termination event detected: {} ({})", Kind, Status);
+        }
+
+        // The first destructive event settles the poll; later entries are not inspected.
+        return;
+    }
+}
+
+// GCP maintenance-event returns "NONE" when nothing is pending, and a
+// descriptive string like "TERMINATE_ON_HOST_MAINTENANCE" when the VM is
+// about to be live-migrated or terminated. Preemptible/spot VMs get a
+// 30-second warning before termination.
+void
+CloudMetadata::PollGCPTermination()
+{
+	ZEN_TRACE_CPU("CloudMetadata::PollGCPTermination");
+
+	// Query the maintenance-event path on the configured IMDS endpoint,
+	// sending a "Metadata-Flavor: Google" header with the request
+	HttpClient ImdsClient(m_ImdsEndpoint,
+						  {.LogCategory = "cloud-gcp", .ConnectTimeout = kImdsTimeout, .Timeout = std::chrono::milliseconds{2000}});
+
+	HttpClient::KeyValueMap MetadataHeaders(std::pair{"Metadata-Flavor", "Google"});
+
+	HttpClient::Response MaintenanceResponse = ImdsClient.Get("/computeMetadata/v1/instance/maintenance-event", MetadataHeaders);
+	if (MaintenanceResponse.IsSuccess())
+	{
+		// An empty body or the literal "NONE" is treated as "nothing pending"
+		std::string_view Event = MaintenanceResponse.AsText();
+		if (!Event.empty() && Event != "NONE")
+		{
+			// Only the poll that flips the flag from false to true records
+			// the reason and logs the warning
+			if (!m_TerminationPending.exchange(true))
+			{
+				m_ReasonLock.WithExclusiveLock([&] { m_TerminationReason = fmt::format("GCP maintenance event: {}", Event); });
+				ZEN_WARN("GCP maintenance event detected: {}", Event);
+			}
+		}
+	}
+}
+
+}  // namespace zen::compute
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+# include
+
+# include
+# include
+# include
+# include
+
+# include
+# include
+
+namespace zen::compute {
+
+// ---------------------------------------------------------------------------
+// Test helper — spins up a local ASIO HTTP server hosting a MockImdsService
+// ---------------------------------------------------------------------------
+
+struct TestImdsServer
+{
+	// Mock service registered with the server in Start(); the tests configure
+	// its fields directly through this member
+	MockImdsService Mock;
+
+	// Creates the temporary directory and an "asio" HTTP server, registers Mock
+	// with it and runs the server on a background thread. The test is aborted
+	// via REQUIRE if Initialize() returns -1.
+	void Start()
+	{
+		m_TmpDir.emplace();
+		m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"});
+		m_Port = m_Server->Initialize(7575, m_TmpDir->Path() / "http");
+		REQUIRE(m_Port != -1);
+		m_Server->RegisterService(Mock);
+		m_ServerThread = std::thread([this]() { m_Server->Run(false); });
+	}
+
+	// Loopback URL built from the port value returned by Initialize()
+	std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); }
+
+	// "cloud" subdirectory of the temporary directory; only valid after Start()
+	// has populated m_TmpDir
+	std::filesystem::path DataDir() const { return m_TmpDir->Path() / "cloud"; }
+
+	// Builds the object under test from DataDir() and Endpoint()
+	std::unique_ptr CreateCloud() { return std::make_unique(DataDir(), Endpoint()); }
+
+	// Teardown order: request exit, join the server thread, then close the
+	// server. Each step is skipped if Start() was never called.
+	~TestImdsServer()
+	{
+		if (m_Server)
+		{
+			m_Server->RequestExit();
+		}
+		if (m_ServerThread.joinable())
+		{
+			m_ServerThread.join();
+		}
+		if (m_Server)
+		{
+			m_Server->Close();
+		}
+	}
+
+private:
+	std::optional m_TmpDir;
+	Ref m_Server;
+	std::thread m_ServerThread;
+	int m_Port = -1;  // -1 until Start() stores the value returned by Initialize()
+};
+
+// ---------------------------------------------------------------------------
+// AWS
+// ---------------------------------------------------------------------------
+
+// Each SUBCASE configures the mock first and only then calls Imds.Start()
+TEST_CASE("cloudmetadata.aws")
+{
+	TestImdsServer Imds;
+	Imds.Mock.ActiveProvider = CloudProvider::AWS;
+
+	SUBCASE("detection basics")
+	{
+		Imds.Mock.Aws.InstanceId = "i-abc123";
+		Imds.Mock.Aws.AvailabilityZone = "us-west-2b";
+		Imds.Mock.Aws.LifeCycle = "on-demand";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(Cloud->GetProvider() == CloudProvider::AWS);
+
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.InstanceId == "i-abc123");
+		CHECK(Info.AvailabilityZone == "us-west-2b");
+		CHECK(Info.IsSpot == false);
+		CHECK(Info.IsAutoscaling == false);
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+
+	SUBCASE("spot instance")
+	{
+		Imds.Mock.Aws.LifeCycle = "spot";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.IsSpot == true);
+	}
+
+	SUBCASE("autoscaling instance")
+	{
+		Imds.Mock.Aws.AutoscalingState = "InService";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.IsAutoscaling == true);
+	}
+
+	SUBCASE("spot termination")
+	{
+		Imds.Mock.Aws.LifeCycle = "spot";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CHECK(Cloud->IsTerminationPending() == false);
+
+		// Simulate a spot interruption notice appearing
+		Imds.Mock.Aws.SpotAction = R"({"action":"terminate","time":"2025-01-01T00:00:00Z"})";
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("spot interruption") != std::string::npos);
+	}
+
+	SUBCASE("autoscaling termination")
+	{
+		Imds.Mock.Aws.AutoscalingState = "InService";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CHECK(Cloud->IsTerminationPending() == false);
+
+		// Simulate ASG cycling the instance out
+		Imds.Mock.Aws.AutoscalingState = "Terminated:Wait";
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("autoscaling") != std::string::npos);
+	}
+
+	SUBCASE("no termination when InService")
+	{
+		Imds.Mock.Aws.AutoscalingState = "InService";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Azure
+// ---------------------------------------------------------------------------
+
+// The termination subcases change the mock's scheduled-event fields after the
+// object under test has been created and then poll explicitly
+TEST_CASE("cloudmetadata.azure")
+{
+	TestImdsServer Imds;
+	Imds.Mock.ActiveProvider = CloudProvider::Azure;
+
+	SUBCASE("detection basics")
+	{
+		Imds.Mock.Azure.VmId = "vm-test-1234";
+		Imds.Mock.Azure.Location = "westeurope";
+		Imds.Mock.Azure.Priority = "Regular";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(Cloud->GetProvider() == CloudProvider::Azure);
+
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.InstanceId == "vm-test-1234");
+		CHECK(Info.AvailabilityZone == "westeurope");
+		CHECK(Info.IsSpot == false);
+		CHECK(Info.IsAutoscaling == false);
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+
+	SUBCASE("spot instance")
+	{
+		Imds.Mock.Azure.Priority = "Spot";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.IsSpot == true);
+	}
+
+	SUBCASE("vmss instance")
+	{
+		Imds.Mock.Azure.VmScaleSetName = "my-vmss";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.IsAutoscaling == true);
+	}
+
+	SUBCASE("preempt termination")
+	{
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CHECK(Cloud->IsTerminationPending() == false);
+
+		Imds.Mock.Azure.ScheduledEventType = "Preempt";
+		Imds.Mock.Azure.ScheduledEventStatus = "Scheduled";
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("Preempt") != std::string::npos);
+	}
+
+	SUBCASE("terminate event")
+	{
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CHECK(Cloud->IsTerminationPending() == false);
+
+		Imds.Mock.Azure.ScheduledEventType = "Terminate";
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("Terminate") != std::string::npos);
+	}
+
+	SUBCASE("no termination when events empty")
+	{
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+}
+
+// ---------------------------------------------------------------------------
+// GCP
+// ---------------------------------------------------------------------------
+
+// Covers both a maintenance event that is already active when the object under
+// test is created and one that first appears on a later poll
+TEST_CASE("cloudmetadata.gcp")
+{
+	TestImdsServer Imds;
+	Imds.Mock.ActiveProvider = CloudProvider::GCP;
+
+	SUBCASE("detection basics")
+	{
+		Imds.Mock.Gcp.InstanceId = "9876543210";
+		Imds.Mock.Gcp.Zone = "projects/123/zones/europe-west1-b";
+		Imds.Mock.Gcp.Preemptible = "FALSE";
+		Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(Cloud->GetProvider() == CloudProvider::GCP);
+
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.InstanceId == "9876543210");
+		CHECK(Info.AvailabilityZone == "europe-west1-b");  // zone prefix stripped
+		CHECK(Info.IsSpot == false);
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+
+	SUBCASE("preemptible instance")
+	{
+		Imds.Mock.Gcp.Preemptible = "TRUE";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+		CHECK(Info.IsSpot == true);
+	}
+	SUBCASE("maintenance event during detection")
+	{
+		Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+
+		// GCP sets termination pending immediately during detection if a
+		// maintenance event is active
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
+	}
+
+	SUBCASE("maintenance event during polling")
+	{
+		Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		CHECK(Cloud->IsTerminationPending() == false);
+
+		Imds.Mock.Gcp.MaintenanceEvent = "TERMINATE_ON_HOST_MAINTENANCE";
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == true);
+		CHECK(Cloud->GetTerminationReason().find("maintenance") != std::string::npos);
+	}
+
+	SUBCASE("no termination when NONE")
+	{
+		Imds.Mock.Gcp.MaintenanceEvent = "NONE";
+		Imds.Start();
+
+		auto Cloud = Imds.CreateCloud();
+		Cloud->PollTermination();
+
+		CHECK(Cloud->IsTerminationPending() == false);
+	}
+}
+
+// ---------------------------------------------------------------------------
+// No provider
+// ---------------------------------------------------------------------------
+
+// With the mock's active provider set to None, the object under test is
+// expected to report no provider and empty/false instance info
+TEST_CASE("cloudmetadata.no_provider")
+{
+	TestImdsServer Imds;
+	Imds.Mock.ActiveProvider = CloudProvider::None;
+	Imds.Start();
+
+	auto Cloud = Imds.CreateCloud();
+
+	CHECK(Cloud->GetProvider() == CloudProvider::None);
+
+	CloudInstanceInfo Info = Cloud->GetInstanceInfo();
+	CHECK(Info.InstanceId.empty());
+	CHECK(Info.AvailabilityZone.empty());
+	CHECK(Info.IsSpot == false);
+	CHECK(Info.IsAutoscaling == false);
+	CHECK(Cloud->IsTerminationPending() == false);
+}
+
+// ---------------------------------------------------------------------------
+// Sentinel file management
+// ---------------------------------------------------------------------------
+
+// Checks for the ".isNotAWS", ".isNotAzure" and ".isNotGCP" files under
+// Imds.DataDir(). The server is started once, before the subcases run.
+TEST_CASE("cloudmetadata.sentinel_files")
+{
+	TestImdsServer Imds;
+	Imds.Mock.ActiveProvider = CloudProvider::None;
+	Imds.Start();
+
+	auto DataDir = Imds.DataDir();
+
+	SUBCASE("sentinels are written on failed detection")
+	{
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(Cloud->GetProvider() == CloudProvider::None);
+		CHECK(zen::IsFile(DataDir / ".isNotAWS"));
+		CHECK(zen::IsFile(DataDir / ".isNotAzure"));
+		CHECK(zen::IsFile(DataDir / ".isNotGCP"));
+	}
+
+	SUBCASE("ClearSentinelFiles removes sentinels")
+	{
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(zen::IsFile(DataDir / ".isNotAWS"));
+		CHECK(zen::IsFile(DataDir / ".isNotAzure"));
+		CHECK(zen::IsFile(DataDir / ".isNotGCP"));
+
+		Cloud->ClearSentinelFiles();
+
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
+	}
+
+	SUBCASE("only failed providers get sentinels")
+	{
+		// Switch to AWS — Azure and GCP never probed, so no sentinels for them
+		Imds.Mock.ActiveProvider = CloudProvider::AWS;
+
+		auto Cloud = Imds.CreateCloud();
+
+		CHECK(Cloud->GetProvider() == CloudProvider::AWS);
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotAWS"));
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotAzure"));
+		CHECK_FALSE(zen::IsFile(DataDir / ".isNotGCP"));
+	}
+}
+
+// Intentionally empty; presumably referenced from elsewhere so the linker keeps
+// this translation unit's tests — TODO confirm against the caller
+void
+cloudmetadata_forcelink()
+{
+}
+
+}  // namespace zen::compute
+
+#endif // ZEN_WITH_TESTS
diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp
new file mode 100644
index 000000000..838d741b6
--- /dev/null
+++ b/src/zencompute/computeservice.cpp
@@ -0,0 +1,2236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+ +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" +# include "recording/actionrecorder.h" +# include "runners/localrunner.h" +# include "runners/remotehttprunner.h" +# if ZEN_PLATFORM_LINUX +# include "runners/linuxrunner.h" +# elif ZEN_PLATFORM_WINDOWS +# include "runners/windowsrunner.h" +# elif ZEN_PLATFORM_MAC +# include "runners/macrunner.h" +# endif + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +using namespace std::literals; + +namespace zen { + +const char* +ToString(compute::ComputeServiceSession::SessionState State) +{ + using enum compute::ComputeServiceSession::SessionState; + switch (State) + { + case Created: + return "Created"; + case Ready: + return "Ready"; + case Draining: + return "Draining"; + case Paused: + return "Paused"; + case Abandoned: + return "Abandoned"; + case Sunset: + return "Sunset"; + } + return "Unknown"; +} + +const char* +ToString(compute::ComputeServiceSession::QueueState State) +{ + using enum compute::ComputeServiceSession::QueueState; + switch (State) + { + case Active: + return "active"; + case Draining: + return "draining"; + case Cancelled: + return "cancelled"; + } + return "unknown"; +} + +} // namespace zen + +namespace zen::compute { + +using SessionState = ComputeServiceSession::SessionState; + +static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast(RunnerAction::State::_Count)); + +////////////////////////////////////////////////////////////////////////// + +struct ComputeServiceSession::Impl +{ + ComputeServiceSession* m_ComputeServiceSession; + ChunkResolver& m_ChunkResolver; + LoggerRef m_Log{logging::Get("compute")}; + + 
Impl(ComputeServiceSession* InComputeServiceSession, ChunkResolver& InChunkResolver) + : m_ComputeServiceSession(InComputeServiceSession) + , m_ChunkResolver(InChunkResolver) + , m_LocalSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + , m_RemoteSubmitPool(GetLargeWorkerPool(EWorkloadType::Burst)) + { + // Create a non-expiring, non-deletable implicit queue for legacy endpoints + auto Result = CreateQueue("implicit"sv, {}, {}); + m_ImplicitQueueId = Result.QueueId; + m_QueueLock.WithSharedLock([&] { m_Queues[m_ImplicitQueueId]->Implicit = true; }); + + m_SchedulingThread = std::thread{&Impl::SchedulerThreadFunction, this}; + } + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + bool RequestStateTransition(SessionState NewState); + void AbandonAllActions(); + + LoggerRef Log() { return m_Log; } + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + std::string m_OrchestratorEndpoint; + std::filesystem::path m_OrchestratorBasePath; + Stopwatch m_OrchestratorQueryTimer; + std::unordered_set m_KnownWorkerUris; + + void UpdateCoordinatorState(); + + // Worker registration and discovery + + struct FunctionDefinition + { + std::string FunctionName; + Guid FunctionVersion; + Guid BuildSystemVersion; + IoHash WorkerId; + }; + + void RegisterWorker(CbPackage Worker); + WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + + // Action scheduling and tracking + + std::atomic m_SessionState{SessionState::Created}; + std::atomic m_ActionsCounter = 0; // sequence number + metrics::Meter m_ArrivalRate; + + RwLock m_PendingLock; + std::map> m_PendingActions; + + RwLock m_RunningLock; + std::unordered_map> m_RunningMap; + + RwLock m_ResultsLock; + std::unordered_map> m_ResultsMap; + metrics::Meter m_ResultRate; + std::atomic m_RetiredCount{0}; + + EnqueueResult EnqueueAction(int QueueId, CbObject ActionObject, int Priority); + EnqueueResult 
EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority); + + void GetCompleted(CbWriter& Cbo); + + HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + std::thread m_SchedulingThread; + std::atomic m_SchedulingThreadEnabled{true}; + Event m_SchedulingThreadEvent; + + void SchedulerThreadFunction(); + void SchedulePendingActions(); + + // Workers + + RwLock m_WorkerLock; + std::unordered_map m_WorkerMap; + std::vector m_FunctionList; + std::vector GetKnownWorkerIds(); + void SyncWorkersToRunner(FunctionRunner& Runner); + + // Runners + + DeferredDirectoryDeleter m_DeferredDeleter; + WorkerThreadPool& m_LocalSubmitPool; + WorkerThreadPool& m_RemoteSubmitPool; + RunnerGroup m_LocalRunnerGroup; + RunnerGroup m_RemoteRunnerGroup; + + void ShutdownRunners(); + + // Recording + + void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); + void StopRecording(); + + std::unique_ptr m_Recorder; + + // History tracking + + RwLock m_ActionHistoryLock; + std::deque m_ActionHistory; + size_t m_HistoryLimit = 1000; + + // Queue tracking + + using QueueState = ComputeServiceSession::QueueState; + + struct QueueEntry : RefCounted + { + int QueueId; + bool Implicit{false}; + std::atomic State{QueueState::Active}; + std::atomic ActiveCount{0}; // pending + running + std::atomic CompletedCount{0}; // successfully completed + std::atomic FailedCount{0}; // failed + std::atomic AbandonedCount{0}; // abandoned + std::atomic CancelledCount{0}; // cancelled + std::atomic IdleSince{0}; // hifreq tick when queue became idle; 0 = has active work + + RwLock m_Lock; + std::unordered_set ActiveLsns; // for cancellation lookup + std::unordered_set FinishedLsns; // completed/failed/cancelled LSNs + + std::string Tag; + CbObject Metadata; + CbObject Config; + }; + + int 
m_ImplicitQueueId{0}; + std::atomic m_QueueCounter{0}; + RwLock m_QueueLock; + std::unordered_map> m_Queues; + + Ref FindQueue(int QueueId) + { + Ref Queue; + m_QueueLock.WithSharedLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + Queue = It->second; + } + }); + return Queue; + } + + ComputeServiceSession::CreateQueueResult CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config); + std::vector GetQueueIds(); + ComputeServiceSession::QueueStatus GetQueueStatus(int QueueId); + CbObject GetQueueMetadata(int QueueId); + CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DeleteQueue(int QueueId); + void DrainQueue(int QueueId); + ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState); + void ExpireCompletedQueues(); + + Stopwatch m_QueueExpiryTimer; + + std::vector GetRunningActions(); + std::vector GetActionHistory(int Limit); + std::vector GetQueueHistory(int QueueId, int Limit); + + // Action submission + + [[nodiscard]] size_t QueryCapacity(); + + [[nodiscard]] SubmitResult SubmitAction(Ref Action); + [[nodiscard]] std::vector SubmitActions(const std::vector>& Actions); + [[nodiscard]] size_t GetSubmittedActionCount(); + + // Updates + + RwLock m_UpdatedActionsLock; + std::vector> m_UpdatedActions; + + void HandleActionUpdates(); + void PostUpdate(RunnerAction* Action); + + static constexpr int kDefaultMaxRetries = 3; + int GetMaxRetriesForQueue(int QueueId); + + ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + + ActionCounts GetActionCounts() + { + ActionCounts Counts; + Counts.Pending = (int)m_PendingLock.WithSharedLock([&] { return 
m_PendingActions.size(); }); + Counts.Running = (int)m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + Counts.Completed = (int)m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }) + (int)m_RetiredCount.load(); + Counts.ActiveQueues = (int)m_QueueLock.WithSharedLock([&] { + size_t Count = 0; + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + ++Count; + } + } + return Count; + }); + return Counts; + } + + void EmitStats(CbObjectWriter& Cbo) + { + Cbo << "session_state"sv << ToString(m_SessionState.load(std::memory_order_relaxed)); + m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); + m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); + m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); + Cbo << "actions_submitted"sv << GetSubmittedActionCount(); + EmitSnapshot("actions_arrival"sv, m_ArrivalRate, Cbo); + EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); + } +}; + +bool +ComputeServiceSession::Impl::IsHealthy() +{ + return m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned; +} + +bool +ComputeServiceSession::Impl::RequestStateTransition(SessionState NewState) +{ + SessionState Current = m_SessionState.load(std::memory_order_relaxed); + + for (;;) + { + if (Current == NewState) + { + return true; + } + + // Validate the transition + bool Valid = false; + + switch (Current) + { + case SessionState::Created: + Valid = (NewState == SessionState::Ready); + break; + case SessionState::Ready: + Valid = (NewState == SessionState::Draining); + break; + case SessionState::Draining: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Paused); + break; + case SessionState::Paused: + Valid = (NewState == SessionState::Ready || NewState == SessionState::Sunset); + break; + case SessionState::Abandoned: + Valid = (NewState == SessionState::Sunset); + break; + case 
SessionState::Sunset: + Valid = false; + break; + } + + // Allow jumping directly to Abandoned or Sunset from any non-terminal state + if (NewState == SessionState::Abandoned && Current < SessionState::Abandoned) + { + Valid = true; + } + if (NewState == SessionState::Sunset && Current != SessionState::Sunset) + { + Valid = true; + } + + if (!Valid) + { + ZEN_WARN("invalid session state transition {} -> {}", ToString(Current), ToString(NewState)); + return false; + } + + if (m_SessionState.compare_exchange_strong(Current, NewState, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: {} -> {}", ToString(Current), ToString(NewState)); + + if (NewState == SessionState::Abandoned) + { + AbandonAllActions(); + } + + return true; + } + + // CAS failed, Current was updated — retry with the new value + } +} + +void +ComputeServiceSession::Impl::AbandonAllActions() +{ + // Collect all pending actions and mark them as Abandoned + std::vector> PendingToAbandon; + + m_PendingLock.WithSharedLock([&] { + PendingToAbandon.reserve(m_PendingActions.size()); + for (auto& [Lsn, Action] : m_PendingActions) + { + PendingToAbandon.push_back(Action); + } + }); + + for (auto& Action : PendingToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + } + + // Collect all running actions and mark them as Abandoned, then + // best-effort cancel via the local runner group + std::vector> RunningToAbandon; + + m_RunningLock.WithSharedLock([&] { + RunningToAbandon.reserve(m_RunningMap.size()); + for (auto& [Lsn, Action] : m_RunningMap) + { + RunningToAbandon.push_back(Action); + } + }); + + for (auto& Action : RunningToAbandon) + { + Action->SetActionState(RunnerAction::State::Abandoned); + m_LocalRunnerGroup.CancelAction(Action->ActionLsn); + } + + ZEN_INFO("abandoned all actions: {} pending, {} running", PendingToAbandon.size(), RunningToAbandon.size()); +} + +void +ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + 
m_OrchestratorEndpoint = Endpoint; +} + +void +ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_OrchestratorBasePath = std::move(BasePath); +} + +void +ComputeServiceSession::Impl::UpdateCoordinatorState() +{ + ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); + if (m_OrchestratorEndpoint.empty()) + { + return; + } + + // Poll faster when we have no discovered workers yet so remote runners come online quickly + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } + + m_OrchestratorQueryTimer.Reset(); + + try + { + HttpClient Client(m_OrchestratorEndpoint); + + HttpClient::Response Response = Client.Get("/orch/agents"); + + if (!Response.IsSuccess()) + { + ZEN_WARN("orchestrator query failed with status {}", static_cast(Response.StatusCode)); + return; + } + + CbObject WorkerList = Response.AsObject(); + + std::unordered_set ValidWorkerUris; + + for (auto& Item : WorkerList["workers"sv]) + { + CbObjectView Worker = Item.AsObjectView(); + + uint64_t Dt = Worker["dt"sv].AsUInt64(); + bool Reachable = Worker["reachable"sv].AsBool(); + std::string_view Uri = Worker["uri"sv].AsString(); + + // Skip stale workers (not seen in over 30 seconds) + if (Dt > 30000) + { + continue; + } + + // Skip workers that are not confirmed reachable + if (!Reachable) + { + continue; + } + + std::string UriStr{Uri}; + ValidWorkerUris.insert(UriStr); + + // Skip workers we already know about + if (m_KnownWorkerUris.contains(UriStr)) + { + continue; + } + + ZEN_INFO("discovered new worker at {}", UriStr); + + m_KnownWorkerUris.insert(UriStr); + + auto* NewRunner = new RemoteHttpRunner(m_ChunkResolver, m_OrchestratorBasePath, UriStr, m_RemoteSubmitPool); + SyncWorkersToRunner(*NewRunner); + m_RemoteRunnerGroup.AddRunner(NewRunner); + } + + // Remove workers that are no longer valid (stale or unreachable) + for (auto It = 
m_KnownWorkerUris.begin(); It != m_KnownWorkerUris.end();) + { + if (!ValidWorkerUris.contains(*It)) + { + const std::string& ExpiredUri = *It; + ZEN_INFO("removing expired worker at {}", ExpiredUri); + + m_RemoteRunnerGroup.RemoveRunnerIf([&](const RemoteHttpRunner& Runner) { return Runner.GetHostName() == ExpiredUri; }); + + It = m_KnownWorkerUris.erase(It); + } + else + { + ++It; + } + } + } + catch (const HttpClientError& Ex) + { + ZEN_WARN("orchestrator query error: {}", Ex.what()); + } + catch (const std::exception& Ex) + { + ZEN_WARN("orchestrator query unexpected error: {}", Ex.what()); + } +} + +void +ComputeServiceSession::Impl::WaitUntilReady() +{ + if (m_RemoteRunnerGroup.GetRunnerCount() || !m_OrchestratorEndpoint.empty()) + { + ZEN_INFO("waiting for remote runners..."); + + constexpr int MaxWaitSeconds = 120; + + for (int Elapsed = 0; Elapsed < MaxWaitSeconds; Elapsed++) + { + if (!m_SchedulingThreadEnabled.load(std::memory_order_relaxed)) + { + ZEN_WARN("shutdown requested while waiting for remote runners"); + return; + } + + const size_t Capacity = m_RemoteRunnerGroup.QueryCapacity(); + + if (Capacity > 0) + { + ZEN_INFO("found {} remote runners (capacity: {})", m_RemoteRunnerGroup.GetRunnerCount(), Capacity); + break; + } + + zen::Sleep(1000); + } + } + else + { + ZEN_ASSERT(m_LocalRunnerGroup.GetRunnerCount(), "no runners available"); + } + + RequestStateTransition(SessionState::Ready); +} + +void +ComputeServiceSession::Impl::Shutdown() +{ + RequestStateTransition(SessionState::Sunset); + + m_SchedulingThreadEnabled = false; + m_SchedulingThreadEvent.Set(); + if (m_SchedulingThread.joinable()) + { + m_SchedulingThread.join(); + } + + ShutdownRunners(); + + m_DeferredDeleter.Shutdown(); +} + +void +ComputeServiceSession::Impl::ShutdownRunners() +{ + m_LocalRunnerGroup.Shutdown(); + m_RemoteRunnerGroup.Shutdown(); +} + +void +ComputeServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) +{ + 
ZEN_INFO("starting recording to '{}'", RecordingPath); + + m_Recorder = std::make_unique(InCidStore, RecordingPath); + + ZEN_INFO("started recording to '{}'", RecordingPath); +} + +void +ComputeServiceSession::Impl::StopRecording() +{ + ZEN_INFO("stopping recording"); + + m_Recorder = nullptr; + + ZEN_INFO("stopped recording"); +} + +std::vector +ComputeServiceSession::Impl::GetRunningActions() +{ + std::vector Result; + m_RunningLock.WithSharedLock([&] { + Result.reserve(m_RunningMap.size()); + for (const auto& [Lsn, Action] : m_RunningMap) + { + Result.push_back({.Lsn = Lsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .CpuUsagePercent = Action->CpuUsagePercent.load(std::memory_order_relaxed), + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed)}); + } + }); + return Result; +} + +std::vector +ComputeServiceSession::Impl::GetActionHistory(int Limit) +{ + RwLock::SharedLockScope _(m_ActionHistoryLock); + + if (Limit > 0 && static_cast(Limit) < m_ActionHistory.size()) + { + return std::vector(m_ActionHistory.end() - Limit, m_ActionHistory.end()); + } + + return std::vector(m_ActionHistory.begin(), m_ActionHistory.end()); +} + +std::vector +ComputeServiceSession::Impl::GetQueueHistory(int QueueId, int Limit) +{ + // Resolve the queue and snapshot its finished LSN set + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + std::unordered_set FinishedLsns; + + Queue->m_Lock.WithSharedLock([&] { FinishedLsns = Queue->FinishedLsns; }); + + // Filter the global history to entries belonging to this queue. + // m_ActionHistory is ordered oldest-first, so the filtered result keeps the same ordering. 
+ std::vector Result; + + m_ActionHistoryLock.WithSharedLock([&] { + for (const auto& Entry : m_ActionHistory) + { + if (FinishedLsns.contains(Entry.Lsn)) + { + Result.push_back(Entry); + } + } + }); + + if (Limit > 0 && static_cast(Limit) < Result.size()) + { + Result.erase(Result.begin(), Result.end() - Limit); + } + + return Result; +} + +void +ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) +{ + ZEN_TRACE_CPU("ComputeServiceSession::RegisterWorker"); + RwLock::ExclusiveLockScope _(m_WorkerLock); + + const IoHash& WorkerId = Worker.GetObject().GetHash(); + + if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) + { + // Note that since the convention currently is that WorkerId is equal to the hash + // of the worker descriptor there is no chance that we get a second write with a + // different descriptor. Thus we only need to call this the first time, when the + // worker is added + + m_LocalRunnerGroup.RegisterWorker(Worker); + m_RemoteRunnerGroup.RegisterWorker(Worker); + + if (m_Recorder) + { + m_Recorder->RegisterWorker(Worker); + } + + CbObject WorkerObj = Worker.GetObject(); + + // Populate worker database + + const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerObj["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } + } +} + +void +ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) +{ + ZEN_TRACE_CPU("SyncWorkersToRunner"); + + std::vector Workers; + + { + RwLock::SharedLockScope _(m_WorkerLock); + Workers.reserve(m_WorkerMap.size()); + for (const auto& [Id, Pkg] : m_WorkerMap) + { + Workers.push_back(Pkg); 
+ } + } + + for (const CbPackage& Worker : Workers) + { + Runner.RegisterWorker(Worker); + } +} + +WorkerDesc +ComputeServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) +{ + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + const CbPackage& Desc = It->second; + return {Desc, WorkerId}; + } + + return {}; +} + +std::vector +ComputeServiceSession::Impl::GetKnownWorkerIds() +{ + std::vector WorkerIds; + + m_WorkerLock.WithSharedLock([&] { + WorkerIds.reserve(m_WorkerMap.size()); + for (const auto& [WorkerId, _] : m_WorkerMap) + { + WorkerIds.push_back(WorkerId); + } + }); + + return WorkerIds; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueAction(int QueueId, CbObject ActionObject, int Priority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueAction"); + + // Resolve function to worker + + IoHash WorkerId{IoHash::Zero}; + CbPackage WorkerPackage; + + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + m_WorkerLock.WithSharedLock([&] { + for (const FunctionDefinition& FuncDef : m_FunctionList) + { + if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == FunctionVersion && + FuncDef.BuildSystemVersion == BuildSystemVersion) + { + WorkerId = FuncDef.WorkerId; + + break; + } + } + + if (WorkerId != IoHash::Zero) + { + if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) + { + WorkerPackage = It->second; + } + } + }); + + if (WorkerId == IoHash::Zero) + { + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker matches the action specification"; + + return {0, Writer.Save()}; + } + + if (WorkerPackage) + { + return 
EnqueueResolvedAction(QueueId, WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); + } + + CbObjectWriter Writer; + + Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; + Writer << "error" + << "no worker found despite match"; + + return {0, Writer.Save()}; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + ZEN_TRACE_CPU("ComputeServiceSession::EnqueueResolvedAction"); + + if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) + { + CbObjectWriter Writer; + Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); + return {0, Writer.Save()}; + } + + const int ActionLsn = ++m_ActionsCounter; + + m_ArrivalRate.Mark(); + + Ref Pending{new RunnerAction(m_ComputeServiceSession)}; + + Pending->ActionLsn = ActionLsn; + Pending->QueueId = QueueId; + Pending->Worker = Worker; + Pending->ActionId = ActionObj.GetHash(); + Pending->ActionObj = ActionObj; + Pending->Priority = RequestPriority; + + // For now simply put action into pending state, so we can do batch scheduling + + ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); + + Pending->SetActionState(RunnerAction::State::Pending); + + if (m_Recorder) + { + m_Recorder->RecordAction(Pending); + } + + CbObjectWriter Writer; + Writer << "lsn" << Pending->ActionLsn; + Writer << "worker" << Pending->Worker.WorkerId; + Writer << "action" << Pending->ActionId; + + return {Pending->ActionLsn, Writer.Save()}; +} + +SubmitResult +ComputeServiceSession::Impl::SubmitAction(Ref Action) +{ + // Loosely round-robin scheduling of actions across runners. 
+ // + // It's not entirely clear what this means given that submits + // can come in across multiple threads, but it's probably better + // than always starting with the first runner. + // + // Longer term we should track the state of the individual + // runners and make decisions accordingly. + + SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); + if (Result.IsAccepted) + { + return Result; + } + + return m_RemoteRunnerGroup.SubmitAction(Action); +} + +size_t +ComputeServiceSession::Impl::GetSubmittedActionCount() +{ + return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); +} + +HttpResponseCode +ComputeServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) + { + return HttpResponseCode::Accepted; + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) + { + return HttpResponseCode::Accepted; + } + } + + return HttpResponseCode::NotFound; +} + +HttpResponseCode +ComputeServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + // This lock is held for the duration of the function since we need to + // be sure that the action doesn't change state while we are checking the + // 
different data structures + + RwLock::ExclusiveLockScope _(m_ResultsLock); + + for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) + { + if (It->second->ActionId == ActionId) + { + OutResultPackage = std::move(It->second->GetResult()); + + m_ResultsMap.erase(It); + + return HttpResponseCode::OK; + } + } + + { + RwLock::SharedLockScope __(m_PendingLock); + + for (const auto& [K, Pending] : m_PendingActions) + { + if (Pending->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must + // always be taken after m_ResultsLock if both are needed + + { + RwLock::SharedLockScope __(m_RunningLock); + + for (const auto& [K, v] : m_RunningMap) + { + if (v->ActionId == ActionId) + { + return HttpResponseCode::Accepted; + } + } + } + + return HttpResponseCode::NotFound; +} + +void +ComputeServiceSession::Impl::RetireActionResult(int ActionLsn) +{ + m_DeferredDeleter.MarkReady(ActionLsn); +} + +void +ComputeServiceSession::Impl::GetCompleted(CbWriter& Cbo) +{ + Cbo.BeginArray("completed"); + + m_ResultsLock.WithSharedLock([&] { + for (auto& [Lsn, Action] : m_ResultsMap) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Lsn; + Cbo << "state"sv << RunnerAction::ToString(Action->ActionState()); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); +} + +////////////////////////////////////////////////////////////////////////// +// Queue management + +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::Impl::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + const int QueueId = ++m_QueueCounter; + + Ref Queue{new QueueEntry()}; + Queue->QueueId = QueueId; + Queue->Tag = Tag; + Queue->Metadata = std::move(Metadata); + Queue->Config = std::move(Config); + Queue->IdleSince = GetHifreqTimerValue(); + + m_QueueLock.WithExclusiveLock([&] { m_Queues[QueueId] = Queue; }); + + ZEN_DEBUG("created queue {}", QueueId); + + return {.QueueId = 
QueueId}; +} + +std::vector +ComputeServiceSession::Impl::GetQueueIds() +{ + std::vector Ids; + + m_QueueLock.WithSharedLock([&] { + Ids.reserve(m_Queues.size()); + for (const auto& [Id, Queue] : m_Queues) + { + if (!Queue->Implicit) + { + Ids.push_back(Id); + } + } + }); + + return Ids; +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::Impl::GetQueueStatus(int QueueId) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + const int Active = Queue->ActiveCount.load(std::memory_order_relaxed); + const int Completed = Queue->CompletedCount.load(std::memory_order_relaxed); + const int Failed = Queue->FailedCount.load(std::memory_order_relaxed); + const int AbandonedN = Queue->AbandonedCount.load(std::memory_order_relaxed); + const int CancelledN = Queue->CancelledCount.load(std::memory_order_relaxed); + const QueueState QState = Queue->State.load(); + + return { + .IsValid = true, + .QueueId = QueueId, + .ActiveCount = Active, + .CompletedCount = Completed, + .FailedCount = Failed, + .AbandonedCount = AbandonedN, + .CancelledCount = CancelledN, + .State = QState, + .IsComplete = (Active == 0), + }; +} + +CbObject +ComputeServiceSession::Impl::GetQueueMetadata(int QueueId) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Metadata; +} + +CbObject +ComputeServiceSession::Impl::GetQueueConfig(int QueueId) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + return {}; + } + + return Queue->Config; +} + +void +ComputeServiceSession::Impl::CancelQueue(int QueueId) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + Queue->State.store(QueueState::Cancelled); + + // Collect active LSNs snapshot for cancellation + std::vector LsnsToCancel; + + Queue->m_Lock.WithSharedLock([&] { LsnsToCancel.assign(Queue->ActiveLsns.begin(), Queue->ActiveLsns.end()); }); + + // Identify which LSNs are still pending (not yet dispatched to a runner) + std::vector> 
PendingActionsToCancel; + std::vector RunningLsnsToCancel; + + m_PendingLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (auto It = m_PendingActions.find(Lsn); It != m_PendingActions.end()) + { + PendingActionsToCancel.push_back(It->second); + } + } + }); + + m_RunningLock.WithSharedLock([&] { + for (int Lsn : LsnsToCancel) + { + if (m_RunningMap.find(Lsn) != m_RunningMap.end()) + { + RunningLsnsToCancel.push_back(Lsn); + } + } + }); + + // Cancel pending actions by marking them as Cancelled; they will flow through + // HandleActionUpdates and eventually be removed from the pending map. + for (auto& Action : PendingActionsToCancel) + { + Action->SetActionState(RunnerAction::State::Cancelled); + } + + // Best-effort cancellation of running actions via the local runner group. + // Also set the action state to Cancelled directly so a subsequent Failed + // transition from the runner is blocked (Cancelled > Failed in the enum). + for (int Lsn : RunningLsnsToCancel) + { + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(Lsn); It != m_RunningMap.end()) + { + It->second->SetActionState(RunnerAction::State::Cancelled); + } + }); + m_LocalRunnerGroup.CancelAction(Lsn); + } + + m_RemoteRunnerGroup.CancelRemoteQueue(QueueId); + + ZEN_INFO("cancelled queue {}: {} pending, {} running actions cancelled", + QueueId, + PendingActionsToCancel.size(), + RunningLsnsToCancel.size()); + + // Wake up the scheduler to process the cancelled actions + m_SchedulingThreadEvent.Set(); +} + +void +ComputeServiceSession::Impl::DeleteQueue(int QueueId) +{ + // Never delete the implicit queue + { + Ref Queue = FindQueue(QueueId); + if (Queue && Queue->Implicit) + { + return; + } + } + + // Cancel any active work first + CancelQueue(QueueId); + + m_QueueLock.WithExclusiveLock([&] { + if (auto It = m_Queues.find(QueueId); It != m_Queues.end()) + { + m_Queues.erase(It); + } + }); +} + +void +ComputeServiceSession::Impl::DrainQueue(int QueueId) +{ + Ref Queue = 
FindQueue(QueueId); + + if (!Queue || Queue->Implicit) + { + return; + } + + QueueState Expected = QueueState::Active; + Queue->State.compare_exchange_strong(Expected, QueueState::Draining); + ZEN_INFO("draining queue {}", QueueId); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority); + + if (Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) +{ + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue not found"sv; + return {0, Writer.Save()}; + } + + QueueState QState = Queue->State.load(); + if (QState == QueueState::Cancelled) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is cancelled"sv; + return {0, Writer.Save()}; + } + + if (QState == QueueState::Draining) + { + CbObjectWriter Writer; + Writer << "error"sv + << "queue is draining"sv; + return {0, Writer.Save()}; + } + + EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); + + if 
(Result.Lsn != 0) + { + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); + } + + return Result; +} + +void +ComputeServiceSession::Impl::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + Ref Queue = FindQueue(QueueId); + + Cbo.BeginArray("completed"); + + if (Queue) + { + Queue->m_Lock.WithSharedLock([&] { + m_ResultsLock.WithSharedLock([&] { + for (int Lsn : Queue->FinishedLsns) + { + if (m_ResultsMap.contains(Lsn)) + { + Cbo << Lsn; + } + } + }); + }); + } + + Cbo.EndArray(); +} + +void +ComputeServiceSession::Impl::NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState) +{ + if (QueueId == 0) + { + return; + } + + Ref Queue = FindQueue(QueueId); + + if (!Queue) + { + return; + } + + Queue->m_Lock.WithExclusiveLock([&] { + Queue->ActiveLsns.erase(Lsn); + Queue->FinishedLsns.insert(Lsn); + }); + + const int PreviousActive = Queue->ActiveCount.fetch_sub(1, std::memory_order_relaxed); + if (PreviousActive == 1) + { + Queue->IdleSince.store(GetHifreqTimerValue(), std::memory_order_relaxed); + } + + switch (ActionState) + { + case RunnerAction::State::Completed: + Queue->CompletedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Abandoned: + Queue->AbandonedCount.fetch_add(1, std::memory_order_relaxed); + break; + case RunnerAction::State::Cancelled: + Queue->CancelledCount.fetch_add(1, std::memory_order_relaxed); + break; + default: + Queue->FailedCount.fetch_add(1, std::memory_order_relaxed); + break; + } +} + +void +ComputeServiceSession::Impl::ExpireCompletedQueues() +{ + static constexpr uint64_t ExpiryTimeMs = 15 * 60 * 1000; + + std::vector ExpiredQueueIds; + + m_QueueLock.WithSharedLock([&] { + for (const auto& [Id, Queue] : m_Queues) + { + if (Queue->Implicit) + { + continue; + } + const uint64_t Idle = Queue->IdleSince.load(std::memory_order_relaxed); + 
if (Idle != 0 && Queue->ActiveCount.load(std::memory_order_relaxed) == 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(GetHifreqTimerValue() - Idle); + if (ElapsedMs >= ExpiryTimeMs) + { + ExpiredQueueIds.push_back(Id); + } + } + } + }); + + for (int QueueId : ExpiredQueueIds) + { + ZEN_INFO("expiring idle queue {}", QueueId); + DeleteQueue(QueueId); + } +} + +void +ComputeServiceSession::Impl::SchedulePendingActions() +{ + ZEN_TRACE_CPU("ComputeServiceSession::SchedulePendingActions"); + int ScheduledCount = 0; + size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); + + static Stopwatch DumpRunningTimer; + + auto _ = MakeGuard([&] { + ZEN_INFO("scheduled {} pending actions. {} running ({} retired), {} still pending, {} results", + ScheduledCount, + RunningCount, + m_RetiredCount.load(), + PendingCount, + ResultCount); + + if (DumpRunningTimer.GetElapsedTimeMs() > 30000) + { + DumpRunningTimer.Reset(); + + std::set RunningList; + m_RunningLock.WithSharedLock([&] { + for (auto& [K, V] : m_RunningMap) + { + RunningList.insert(K); + } + }); + + ExtendableStringBuilder<1024> RunningString; + for (int i : RunningList) + { + if (RunningString.Size()) + { + RunningString << ", "; + } + + RunningString.Append(IntNum(i)); + } + + ZEN_INFO("running: {}", RunningString); + } + }); + + size_t Capacity = QueryCapacity(); + + if (!Capacity) + { + _.Dismiss(); + + return; + } + + std::vector> ActionsToSchedule; + + // Pull actions to schedule from the pending queue, we will + // try to submit these to the runner outside of the lock. Note + // that because of how the state transitions work it's not + // actually the case that all of these actions will still be + // pending by the time we try to submit them, but that's fine. 
+ // + // Also note that the m_PendingActions list is not maintained + // here, that's done periodically in SchedulePendingActions() + + m_PendingLock.WithExclusiveLock([&] { + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Paused) + { + return; + } + + if (m_PendingActions.empty()) + { + return; + } + + for (auto& [Lsn, Pending] : m_PendingActions) + { + switch (Pending->ActionState()) + { + case RunnerAction::State::Pending: + ActionsToSchedule.push_back(Pending); + break; + + case RunnerAction::State::Submitting: + break; // already claimed by async submission + + case RunnerAction::State::Running: + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + case RunnerAction::State::Abandoned: + case RunnerAction::State::Cancelled: + break; + + default: + case RunnerAction::State::New: + ZEN_WARN("unexpected state {} for pending action {}", static_cast(Pending->ActionState()), Pending->ActionLsn); + break; + } + } + + // Sort by priority descending, then by LSN ascending (FIFO within same priority) + std::sort(ActionsToSchedule.begin(), ActionsToSchedule.end(), [](const Ref& A, const Ref& B) { + if (A->Priority != B->Priority) + { + return A->Priority > B->Priority; + } + return A->ActionLsn < B->ActionLsn; + }); + + if (ActionsToSchedule.size() > Capacity) + { + ActionsToSchedule.resize(Capacity); + } + + PendingCount = m_PendingActions.size(); + }); + + if (ActionsToSchedule.empty()) + { + _.Dismiss(); + return; + } + + ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); + + Stopwatch SubmitTimer; + std::vector SubmitResults = SubmitActions(ActionsToSchedule); + + int NotAcceptedCount = 0; + int ScheduledActionCount = 0; + + for (const SubmitResult& SubResult : SubmitResults) + { + if (SubResult.IsAccepted) + { + ++ScheduledActionCount; + } + else + { + ++NotAcceptedCount; + } + } + + ZEN_INFO("scheduled {} pending actions in {} ({} rejected)", + ScheduledActionCount, + 
NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + NotAcceptedCount); + + ScheduledCount += ScheduledActionCount; + PendingCount -= ScheduledActionCount; +} + +void +ComputeServiceSession::Impl::SchedulerThreadFunction() +{ + SetCurrentThreadName("Function_Scheduler"); + + auto _ = MakeGuard([&] { ZEN_INFO("scheduler thread exiting"); }); + + do + { + int TimeoutMs = 500; + + auto PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + + if (PendingCount) + { + TimeoutMs = 100; + } + + const bool WasSignaled = m_SchedulingThreadEvent.Wait(TimeoutMs); + + if (m_SchedulingThreadEnabled == false) + { + return; + } + + if (WasSignaled) + { + m_SchedulingThreadEvent.Reset(); + } + + ZEN_DEBUG("compute scheduler TICK (Pending: {} was {}, Running: {}, Results: {}) timeout: {}", + m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }), + PendingCount, + m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }), + m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }), + TimeoutMs); + + HandleActionUpdates(); + + // Auto-transition Draining → Paused when all work is done + if (m_SessionState.load(std::memory_order_relaxed) == SessionState::Draining) + { + size_t Pending = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); + size_t Running = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); + + if (Pending == 0 && Running == 0) + { + SessionState Expected = SessionState::Draining; + if (m_SessionState.compare_exchange_strong(Expected, SessionState::Paused, std::memory_order_acq_rel)) + { + ZEN_INFO("session state: Draining -> Paused (all work completed)"); + } + } + } + + UpdateCoordinatorState(); + SchedulePendingActions(); + + static constexpr uint64_t QueueExpirySweepIntervalMs = 30000; + if (m_QueueExpiryTimer.GetElapsedTimeMs() >= QueueExpirySweepIntervalMs) + { + m_QueueExpiryTimer.Reset(); + ExpireCompletedQueues(); + } + } while (m_SchedulingThreadEnabled); +} + 
+void +ComputeServiceSession::Impl::PostUpdate(RunnerAction* Action) +{ + m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); + m_SchedulingThreadEvent.Set(); +} + +int +ComputeServiceSession::Impl::GetMaxRetriesForQueue(int QueueId) +{ + if (QueueId == 0) + { + return kDefaultMaxRetries; + } + + CbObject Config = GetQueueConfig(QueueId); + + if (Config) + { + int Value = Config["max_retries"].AsInt32(0); + + if (Value > 0) + { + return Value; + } + } + + return kDefaultMaxRetries; +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) +{ + Ref Action; + RunnerAction::State State; + RescheduleResult ValidationError; + bool Removed = false; + + // Find, validate, and remove atomically under a single lock scope to prevent + // concurrent RescheduleAction calls from double-removing the same action. + m_ResultsLock.WithExclusiveLock([&] { + auto It = m_ResultsMap.find(ActionLsn); + if (It == m_ResultsMap.end()) + { + ValidationError = {.Success = false, .Error = "Action not found in results"}; + return; + } + + Action = It->second; + State = Action->ActionState(); + + if (State != RunnerAction::State::Failed && State != RunnerAction::State::Abandoned) + { + ValidationError = {.Success = false, .Error = "Action is not in a failed or abandoned state"}; + return; + } + + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + if (Action->RetryCount.load(std::memory_order_relaxed) >= MaxRetries) + { + ValidationError = {.Success = false, .Error = "Retry limit reached"}; + return; + } + + m_ResultsMap.erase(It); + Removed = true; + }); + + if (!Removed) + { + return ValidationError; + } + + if (Action->QueueId != 0) + { + Ref Queue = FindQueue(Action->QueueId); + + if (Queue) + { + Queue->m_Lock.WithExclusiveLock([&] { + Queue->FinishedLsns.erase(ActionLsn); + Queue->ActiveLsns.insert(ActionLsn); + }); + + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + 
Queue->IdleSince.store(0, std::memory_order_relaxed); + + if (State == RunnerAction::State::Failed) + { + Queue->FailedCount.fetch_sub(1, std::memory_order_relaxed); + } + else + { + Queue->AbandonedCount.fetch_sub(1, std::memory_order_relaxed); + } + } + } + + // Reset action state — this calls PostUpdate() internally + Action->ResetActionStateToPending(); + + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + ZEN_INFO("action {} ({}) manually rescheduled (retry {})", Action->ActionId, ActionLsn, NewRetryCount); + + return {.Success = true, .RetryCount = NewRetryCount}; +} + +void +ComputeServiceSession::Impl::HandleActionUpdates() +{ + ZEN_TRACE_CPU("ComputeServiceSession::HandleActionUpdates"); + + // Drain the update queue atomically + std::vector> UpdatedActions; + m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); + + std::unordered_set SeenLsn; + + // Process each action's latest state, deduplicating by LSN. + // + // This is safe because state transitions are monotonically increasing by enum + // rank (Pending < Submitting < Running < Completed/Failed/Cancelled), so + // SetActionState rejects any transition to a lower-ranked state. By the time + // we read ActionState() here, it reflects the highest state reached — making + // the first occurrence per LSN authoritative and duplicates redundant. 
+ for (Ref& Action : UpdatedActions) + { + const int ActionLsn = Action->ActionLsn; + + if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) + { + switch (Action->ActionState()) + { + // Newly enqueued — add to pending map for scheduling + case RunnerAction::State::Pending: + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + break; + + // Async submission in progress — remains in pending map + case RunnerAction::State::Submitting: + break; + + // Dispatched to a runner — move from pending to running + case RunnerAction::State::Running: + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + m_RunningMap.insert({ActionLsn, Action}); + m_PendingActions.erase(ActionLsn); + }); + }); + ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); + break; + + // Terminal states — move to results, record history, notify queue + case RunnerAction::State::Completed: + case RunnerAction::State::Failed: + case RunnerAction::State::Abandoned: + case RunnerAction::State::Cancelled: + { + auto TerminalState = Action->ActionState(); + + // Automatic retry for Failed/Abandoned actions with retries remaining. + // Skip retries when the session itself is abandoned — those actions + // were intentionally abandoned and should not be rescheduled. 
+ if ((TerminalState == RunnerAction::State::Failed || TerminalState == RunnerAction::State::Abandoned) && + m_SessionState.load(std::memory_order_relaxed) < SessionState::Abandoned) + { + int MaxRetries = GetMaxRetriesForQueue(Action->QueueId); + + if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) + { + // Remove from whichever active map the action is in before resetting + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + // Reset triggers PostUpdate() which re-enters the action as Pending + Action->ResetActionStateToPending(); + int NewRetryCount = Action->RetryCount.load(std::memory_order_relaxed); + + ZEN_INFO("action {} ({}) auto-rescheduled (retry {}/{})", + Action->ActionId, + ActionLsn, + NewRetryCount, + MaxRetries); + break; + } + } + + // Remove from whichever active map the action is in + m_RunningLock.WithExclusiveLock([&] { + m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); + + m_ResultsLock.WithExclusiveLock([&] { + m_ResultsMap[ActionLsn] = Action; + + // Append to bounded action history ring + m_ActionHistoryLock.WithExclusiveLock([&] { + ActionHistoryEntry Entry{.Lsn = ActionLsn, + .QueueId = Action->QueueId, + .ActionId = Action->ActionId, + .WorkerId = Action->Worker.WorkerId, + .ActionDescriptor = Action->ActionObj, + .ExecutionLocation = std::move(Action->ExecutionLocation), + .Succeeded = TerminalState == RunnerAction::State::Completed, + .CpuSeconds = Action->CpuSeconds.load(std::memory_order_relaxed), + .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; + + std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), 
std::begin(Entry.Timestamps)); + + m_ActionHistory.push_back(std::move(Entry)); + + if (m_ActionHistory.size() > m_HistoryLimit) + { + m_ActionHistory.pop_front(); + } + }); + }); + m_RetiredCount.fetch_add(1); + m_ResultRate.Mark(1); + ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", + Action->ActionId, + ActionLsn, + TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); + break; + } + } + } + } +} + +size_t +ComputeServiceSession::Impl::QueryCapacity() +{ + return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); +} + +std::vector +ComputeServiceSession::Impl::SubmitActions(const std::vector>& Actions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::SubmitActions"); + std::vector Results(Actions.size()); + + // First try submitting the batch to local runners in parallel + + std::vector LocalResults = m_LocalRunnerGroup.SubmitActions(Actions); + std::vector RemoteIndices; + std::vector> RemoteActions; + + for (size_t i = 0; i < Actions.size(); ++i) + { + if (LocalResults[i].IsAccepted) + { + Results[i] = std::move(LocalResults[i]); + } + else + { + RemoteIndices.push_back(i); + RemoteActions.push_back(Actions[i]); + } + } + + // Submit remaining actions to remote runners in parallel + if (!RemoteActions.empty()) + { + std::vector RemoteResults = m_RemoteRunnerGroup.SubmitActions(RemoteActions); + + for (size_t j = 0; j < RemoteIndices.size(); ++j) + { + Results[RemoteIndices[j]] = std::move(RemoteResults[j]); + } + } + + return Results; +} + +////////////////////////////////////////////////////////////////////////// + +ComputeServiceSession::ComputeServiceSession(ChunkResolver& InChunkResolver) +{ + m_Impl = std::make_unique(this, InChunkResolver); +} + +ComputeServiceSession::~ComputeServiceSession() +{ + Shutdown(); +} + +bool +ComputeServiceSession::IsHealthy() +{ + return m_Impl->IsHealthy(); +} + +void +ComputeServiceSession::WaitUntilReady() 
+{ + m_Impl->WaitUntilReady(); +} + +void +ComputeServiceSession::Shutdown() +{ + m_Impl->Shutdown(); +} + +ComputeServiceSession::SessionState +ComputeServiceSession::GetSessionState() const +{ + return m_Impl->m_SessionState.load(std::memory_order_relaxed); +} + +bool +ComputeServiceSession::RequestStateTransition(SessionState NewState) +{ + return m_Impl->RequestStateTransition(NewState); +} + +void +ComputeServiceSession::SetOrchestratorEndpoint(std::string_view Endpoint) +{ + m_Impl->SetOrchestratorEndpoint(Endpoint); +} + +void +ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) +{ + m_Impl->SetOrchestratorBasePath(std::move(BasePath)); +} + +void +ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) +{ + m_Impl->StartRecording(InResolver, RecordingPath); +} + +void +ComputeServiceSession::StopRecording() +{ + m_Impl->StopRecording(); +} + +ComputeServiceSession::ActionCounts +ComputeServiceSession::GetActionCounts() +{ + return m_Impl->GetActionCounts(); +} + +void +ComputeServiceSession::EmitStats(CbObjectWriter& Cbo) +{ + m_Impl->EmitStats(Cbo); +} + +std::vector +ComputeServiceSession::GetKnownWorkerIds() +{ + return m_Impl->GetKnownWorkerIds(); +} + +WorkerDesc +ComputeServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) +{ + return m_Impl->GetWorkerDescriptor(WorkerId); +} + +void +ComputeServiceSession::AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddLocalRunner"); + +# if ZEN_PLATFORM_LINUX + auto* NewRunner = new LinuxProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif ZEN_PLATFORM_WINDOWS + auto* NewRunner = new WindowsProcessRunner(InChunkResolver, + BasePath, + m_Impl->m_DeferredDeleter, + m_Impl->m_LocalSubmitPool, + false, + MaxConcurrentActions); +# elif 
ZEN_PLATFORM_MAC + auto* NewRunner = + new MacProcessRunner(InChunkResolver, BasePath, m_Impl->m_DeferredDeleter, m_Impl->m_LocalSubmitPool, false, MaxConcurrentActions); +# endif + + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_LocalRunnerGroup.AddRunner(NewRunner); +} + +void +ComputeServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) +{ + ZEN_TRACE_CPU("ComputeServiceSession::AddRemoteRunner"); + + auto* NewRunner = new RemoteHttpRunner(InChunkResolver, BasePath, HostName, m_Impl->m_RemoteSubmitPool); + m_Impl->SyncWorkersToRunner(*NewRunner); + m_Impl->m_RemoteRunnerGroup.AddRunner(NewRunner); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueAction(CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(m_Impl->m_ImplicitQueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(m_Impl->m_ImplicitQueueId, Worker, ActionObj, RequestPriority); +} +ComputeServiceSession::CreateQueueResult +ComputeServiceSession::CreateQueue(std::string_view Tag, CbObject Metadata, CbObject Config) +{ + return m_Impl->CreateQueue(Tag, std::move(Metadata), std::move(Config)); +} + +CbObject +ComputeServiceSession::GetQueueMetadata(int QueueId) +{ + return m_Impl->GetQueueMetadata(QueueId); +} + +CbObject +ComputeServiceSession::GetQueueConfig(int QueueId) +{ + return m_Impl->GetQueueConfig(QueueId); +} + +std::vector +ComputeServiceSession::GetQueueIds() +{ + return m_Impl->GetQueueIds(); +} + +ComputeServiceSession::QueueStatus +ComputeServiceSession::GetQueueStatus(int QueueId) +{ + return m_Impl->GetQueueStatus(QueueId); +} + +void +ComputeServiceSession::CancelQueue(int QueueId) +{ + m_Impl->CancelQueue(QueueId); +} + +void +ComputeServiceSession::DrainQueue(int QueueId) +{ + 
m_Impl->DrainQueue(QueueId); +} + +void +ComputeServiceSession::DeleteQueue(int QueueId) +{ + m_Impl->DeleteQueue(QueueId); +} + +void +ComputeServiceSession::GetQueueCompleted(int QueueId, CbWriter& Cbo) +{ + m_Impl->GetQueueCompleted(QueueId, Cbo); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + return m_Impl->EnqueueActionToQueue(QueueId, ActionObject, Priority); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int RequestPriority) +{ + return m_Impl->EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority); +} + +void +ComputeServiceSession::RegisterWorker(CbPackage Worker) +{ + m_Impl->RegisterWorker(Worker); +} + +HttpResponseCode +ComputeServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) +{ + return m_Impl->GetActionResult(ActionLsn, OutResultPackage); +} + +HttpResponseCode +ComputeServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) +{ + return m_Impl->FindActionResult(ActionId, OutResultPackage); +} + +void +ComputeServiceSession::RetireActionResult(int ActionLsn) +{ + m_Impl->RetireActionResult(ActionLsn); +} + +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RescheduleAction(int ActionLsn) +{ + return m_Impl->RescheduleAction(ActionLsn); +} + +std::vector +ComputeServiceSession::GetRunningActions() +{ + return m_Impl->GetRunningActions(); +} + +std::vector +ComputeServiceSession::GetActionHistory(int Limit) +{ + return m_Impl->GetActionHistory(Limit); +} + +std::vector +ComputeServiceSession::GetQueueHistory(int QueueId, int Limit) +{ + return m_Impl->GetQueueHistory(QueueId, Limit); +} + +void +ComputeServiceSession::GetCompleted(CbWriter& Cbo) +{ + m_Impl->GetCompleted(Cbo); +} + +void +ComputeServiceSession::PostUpdate(RunnerAction* Action) +{ + 
m_Impl->PostUpdate(Action); +} + +////////////////////////////////////////////////////////////////////////// + +void +computeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/functionrunner.cpp b/src/zencompute/functionrunner.cpp deleted file mode 100644 index 8e7c12b2b..000000000 --- a/src/zencompute/functionrunner.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "functionrunner.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include - -# include -# include - -namespace zen::compute { - -FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") -{ -} - -FunctionRunner::~FunctionRunner() = default; - -size_t -FunctionRunner::QueryCapacity() -{ - return 1; -} - -std::vector -FunctionRunner::SubmitActions(const std::vector>& Actions) -{ - std::vector Results; - Results.reserve(Actions.size()); - - for (const Ref& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -void -FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) -{ - if (m_DumpActions) - { - std::string UniqueId = fmt::format("{}.ddb", ActionLsn); - std::filesystem::path Path = m_ActionsPath / UniqueId; - - zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); - } -} - -////////////////////////////////////////////////////////////////////////// - -RunnerAction::RunnerAction(FunctionServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) -{ - this->Timestamps[static_cast(State::New)] = DateTime::Now().GetTicks(); -} - -RunnerAction::~RunnerAction() -{ -} - -void -RunnerAction::SetActionState(State NewState) -{ - ZEN_ASSERT(NewState < State::_Count); - this->Timestamps[static_cast(NewState)] = DateTime::Now().GetTicks(); - - do - { - if (State CurrentState = m_ActionState.load(); CurrentState == NewState) - { - // No state change - return; - } - else - { 
- if (NewState <= CurrentState) - { - // Cannot transition to an earlier or same state - return; - } - - if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) - { - // Successful state change - - m_OwnerSession->PostUpdate(this); - - return; - } - } - } while (true); -} - -void -RunnerAction::SetResult(CbPackage&& Result) -{ - m_Result = std::move(Result); -} - -CbPackage& -RunnerAction::GetResult() -{ - ZEN_ASSERT(IsCompleted()); - return m_Result; -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file diff --git a/src/zencompute/functionrunner.h b/src/zencompute/functionrunner.h deleted file mode 100644 index 6fd0d84cc..000000000 --- a/src/zencompute/functionrunner.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include - -namespace zen::compute { - -struct SubmitResult -{ - bool IsAccepted = false; - std::string Reason; -}; - -/** Base interface for classes implementing a remote execution "runner" - */ -class FunctionRunner : public RefCounted -{ - FunctionRunner(FunctionRunner&&) = delete; - FunctionRunner& operator=(FunctionRunner&&) = delete; - -public: - FunctionRunner(std::filesystem::path BasePath); - virtual ~FunctionRunner() = 0; - - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; - - [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) = 0; - [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; - [[nodiscard]] virtual bool IsHealthy() = 0; - [[nodiscard]] virtual size_t QueryCapacity(); - [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions); - -protected: - std::filesystem::path m_ActionsPath; - bool m_DumpActions = false; - void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); -}; - -template -struct RunnerGroup -{ - void AddRunner(RunnerType* Runner) - { - 
m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); - } - size_t QueryCapacity() - { - size_t TotalCapacity = 0; - m_RunnersLock.WithSharedLock([&] { - for (const auto& Runner : m_Runners) - { - TotalCapacity += Runner->QueryCapacity(); - } - }); - return TotalCapacity; - } - - SubmitResult SubmitAction(Ref Action) - { - RwLock::SharedLockScope _(m_RunnersLock); - - const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); - int Index = InitialIndex; - const int RunnerCount = gsl::narrow(m_Runners.size()); - - if (RunnerCount == 0) - { - return {.IsAccepted = false, .Reason = "No runners available"}; - } - - do - { - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - - auto& Runner = m_Runners[Index++]; - - SubmitResult Result = Runner->SubmitAction(Action); - - if (Result.IsAccepted == true) - { - m_NextSubmitIndex = Index % RunnerCount; - - return Result; - } - - while (Index >= RunnerCount) - { - Index -= RunnerCount; - } - } while (Index != InitialIndex); - - return {.IsAccepted = false}; - } - - size_t GetSubmittedActionCount() - { - RwLock::SharedLockScope _(m_RunnersLock); - - size_t TotalCount = 0; - - for (const auto& Runner : m_Runners) - { - TotalCount += Runner->GetSubmittedActionCount(); - } - - return TotalCount; - } - - void RegisterWorker(CbPackage Worker) - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->RegisterWorker(Worker); - } - } - - void Shutdown() - { - RwLock::SharedLockScope _(m_RunnersLock); - - for (auto& Runner : m_Runners) - { - Runner->Shutdown(); - } - } - -private: - RwLock m_RunnersLock; - std::vector> m_Runners; - std::atomic m_NextSubmitIndex{0}; -}; - -/** - * This represents an action going through different stages of scheduling and execution. 
- */ -struct RunnerAction : public RefCounted -{ - explicit RunnerAction(FunctionServiceSession* OwnerSession); - ~RunnerAction(); - - int ActionLsn = 0; - WorkerDesc Worker; - IoHash ActionId; - CbObject ActionObj; - int Priority = 0; - - enum class State - { - New, - Pending, - Running, - Completed, - Failed, - _Count - }; - - static const char* ToString(State _) - { - switch (_) - { - case State::New: - return "New"; - case State::Pending: - return "Pending"; - case State::Running: - return "Running"; - case State::Completed: - return "Completed"; - case State::Failed: - return "Failed"; - default: - return "Unknown"; - } - } - - uint64_t Timestamps[static_cast(State::_Count)] = {}; - - State ActionState() const { return m_ActionState; } - void SetActionState(State NewState); - - bool IsSuccess() const { return ActionState() == State::Completed; } - bool IsCompleted() const { return ActionState() == State::Completed || ActionState() == State::Failed; } - - void SetResult(CbPackage&& Result); - CbPackage& GetResult(); - -private: - std::atomic m_ActionState = State::New; - FunctionServiceSession* m_OwnerSession = nullptr; - CbPackage m_Result; -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file diff --git a/src/zencompute/functionservice.cpp b/src/zencompute/functionservice.cpp deleted file mode 100644 index 0698449e9..000000000 --- a/src/zencompute/functionservice.cpp +++ /dev/null @@ -1,957 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include "zencompute/functionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" -# include "actionrecorder.h" -# include "localrunner.h" -# include "remotehttprunner.h" - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include -# include -# include -# include -# include - -ZEN_THIRD_PARTY_INCLUDES_START -# include -ZEN_THIRD_PARTY_INCLUDES_END - -using namespace std::literals; - -namespace zen::compute { - -////////////////////////////////////////////////////////////////////////// - -struct FunctionServiceSession::Impl -{ - FunctionServiceSession* m_FunctionServiceSession; - ChunkResolver& m_ChunkResolver; - LoggerRef m_Log{logging::Get("apply")}; - - Impl(FunctionServiceSession* InFunctionServiceSession, ChunkResolver& InChunkResolver) - : m_FunctionServiceSession(InFunctionServiceSession) - , m_ChunkResolver(InChunkResolver) - { - m_SchedulingThread = std::thread{&Impl::MonitorThreadFunction, this}; - } - - void Shutdown(); - bool IsHealthy(); - - LoggerRef Log() { return m_Log; } - - std::atomic_bool m_AcceptActions = true; - - struct FunctionDefinition - { - std::string FunctionName; - Guid FunctionVersion; - Guid BuildSystemVersion; - IoHash WorkerId; - }; - - void EmitStats(CbObjectWriter& Cbo) - { - m_WorkerLock.WithSharedLock([&] { Cbo << "worker_count"sv << m_WorkerMap.size(); }); - m_ResultsLock.WithSharedLock([&] { Cbo << "actions_complete"sv << m_ResultsMap.size(); }); - m_PendingLock.WithSharedLock([&] { Cbo << "actions_pending"sv << m_PendingActions.size(); }); - Cbo << "actions_submitted"sv << GetSubmittedActionCount(); - EmitSnapshot("actions_retired"sv, m_ResultRate, Cbo); - } - - void RegisterWorker(CbPackage Worker); - WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - - std::atomic m_ActionsCounter = 0; // sequence number - - RwLock m_PendingLock; - std::map> m_PendingActions; - - RwLock m_RunningLock; - 
std::unordered_map> m_RunningMap; - - RwLock m_ResultsLock; - std::unordered_map> m_ResultsMap; - metrics::Meter m_ResultRate; - std::atomic m_RetiredCount{0}; - - HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - std::atomic m_ShutdownRequested{false}; - - std::thread m_SchedulingThread; - std::atomic m_SchedulingThreadEnabled{true}; - Event m_SchedulingThreadEvent; - - void MonitorThreadFunction(); - void SchedulePendingActions(); - - // Workers - - RwLock m_WorkerLock; - std::unordered_map m_WorkerMap; - std::vector m_FunctionList; - std::vector GetKnownWorkerIds(); - - // Runners - - RunnerGroup m_LocalRunnerGroup; - RunnerGroup m_RemoteRunnerGroup; - - EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority); - - void GetCompleted(CbWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath); - void StopRecording(); - - std::unique_ptr m_Recorder; - - // History tracking - - RwLock m_ActionHistoryLock; - std::deque m_ActionHistory; - size_t m_HistoryLimit = 1000; - - std::vector GetActionHistory(int Limit); - - // - - [[nodiscard]] size_t QueryCapacity(); - - [[nodiscard]] SubmitResult SubmitAction(Ref Action); - [[nodiscard]] std::vector SubmitActions(const std::vector>& Actions); - [[nodiscard]] size_t GetSubmittedActionCount(); - - // Updates - - RwLock m_UpdatedActionsLock; - std::vector> m_UpdatedActions; - - void HandleActionUpdates(); - void PostUpdate(RunnerAction* Action); - - void ShutdownRunners(); -}; - -bool -FunctionServiceSession::Impl::IsHealthy() -{ - return true; -} - -void -FunctionServiceSession::Impl::Shutdown() -{ - m_AcceptActions = false; - m_ShutdownRequested = true; - - m_SchedulingThreadEnabled = false; - m_SchedulingThreadEvent.Set(); - if 
(m_SchedulingThread.joinable()) - { - m_SchedulingThread.join(); - } - - ShutdownRunners(); -} - -void -FunctionServiceSession::Impl::ShutdownRunners() -{ - m_LocalRunnerGroup.Shutdown(); - m_RemoteRunnerGroup.Shutdown(); -} - -void -FunctionServiceSession::Impl::StartRecording(ChunkResolver& InCidStore, const std::filesystem::path& RecordingPath) -{ - ZEN_INFO("starting recording to '{}'", RecordingPath); - - m_Recorder = std::make_unique(InCidStore, RecordingPath); - - ZEN_INFO("started recording to '{}'", RecordingPath); -} - -void -FunctionServiceSession::Impl::StopRecording() -{ - ZEN_INFO("stopping recording"); - - m_Recorder = nullptr; - - ZEN_INFO("stopped recording"); -} - -std::vector -FunctionServiceSession::Impl::GetActionHistory(int Limit) -{ - RwLock::SharedLockScope _(m_ActionHistoryLock); - - if (Limit > 0 && static_cast(Limit) < m_ActionHistory.size()) - { - return std::vector(m_ActionHistory.end() - Limit, m_ActionHistory.end()); - } - - return std::vector(m_ActionHistory.begin(), m_ActionHistory.end()); -} - -void -FunctionServiceSession::Impl::RegisterWorker(CbPackage Worker) -{ - RwLock::ExclusiveLockScope _(m_WorkerLock); - - const IoHash& WorkerId = Worker.GetObject().GetHash(); - - if (m_WorkerMap.insert_or_assign(WorkerId, Worker).second) - { - // Note that since the convention currently is that WorkerId is equal to the hash - // of the worker descriptor there is no chance that we get a second write with a - // different descriptor. 
Thus we only need to call this the first time, when the - // worker is added - - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); - - if (m_Recorder) - { - m_Recorder->RegisterWorker(Worker); - } - - CbObject WorkerObj = Worker.GetObject(); - - // Populate worker database - - const Guid WorkerBuildSystemVersion = WorkerObj["buildsystem_version"sv].AsUuid(); - - for (auto& Item : WorkerObj["functions"sv]) - { - CbObjectView Function = Item.AsObjectView(); - - std::string_view FunctionName = Function["name"sv].AsString(); - const Guid FunctionVersion = Function["version"sv].AsUuid(); - - m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, - .FunctionVersion = FunctionVersion, - .BuildSystemVersion = WorkerBuildSystemVersion, - .WorkerId = WorkerId}); - } - } -} - -WorkerDesc -FunctionServiceSession::Impl::GetWorkerDescriptor(const IoHash& WorkerId) -{ - RwLock::SharedLockScope _(m_WorkerLock); - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - const CbPackage& Desc = It->second; - return {Desc, WorkerId}; - } - - return {}; -} - -std::vector -FunctionServiceSession::Impl::GetKnownWorkerIds() -{ - std::vector WorkerIds; - WorkerIds.reserve(m_WorkerMap.size()); - - m_WorkerLock.WithSharedLock([&] { - for (const auto& [WorkerId, _] : m_WorkerMap) - { - WorkerIds.push_back(WorkerId); - } - }); - - return WorkerIds; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueAction(CbObject ActionObject, int Priority) -{ - // Resolve function to worker - - IoHash WorkerId{IoHash::Zero}; - - std::string_view FunctionName = ActionObject["Function"sv].AsString(); - const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); - const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); - - for (const FunctionDefinition& FuncDef : m_FunctionList) - { - if (FuncDef.FunctionName == FunctionName && FuncDef.FunctionVersion == 
FunctionVersion && - FuncDef.BuildSystemVersion == BuildSystemVersion) - { - WorkerId = FuncDef.WorkerId; - - break; - } - } - - if (WorkerId == IoHash::Zero) - { - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker matches the action specification"; - - return {0, Writer.Save()}; - } - - if (auto It = m_WorkerMap.find(WorkerId); It != m_WorkerMap.end()) - { - CbPackage WorkerPackage = It->second; - - return EnqueueResolvedAction(WorkerDesc{WorkerPackage, WorkerId}, ActionObject, Priority); - } - - CbObjectWriter Writer; - - Writer << "Function"sv << FunctionName << "FunctionVersion"sv << FunctionVersion << "BuildSystemVersion" << BuildSystemVersion; - Writer << "error" - << "no worker found despite match"; - - return {0, Writer.Save()}; -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::Impl::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - const int ActionLsn = ++m_ActionsCounter; - - Ref Pending{new RunnerAction(m_FunctionServiceSession)}; - - Pending->ActionLsn = ActionLsn; - Pending->Worker = Worker; - Pending->ActionId = ActionObj.GetHash(); - Pending->ActionObj = ActionObj; - Pending->Priority = RequestPriority; - - SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - ZEN_DEBUG("direct schedule LSN {}", Pending->ActionLsn); - } - else - { - ZEN_DEBUG("action {} ({}) PENDING", Pending->ActionId, Pending->ActionLsn); - - Pending->SetActionState(RunnerAction::State::Pending); - } - - if (m_Recorder) - { - m_Recorder->RecordAction(Pending); - } - - CbObjectWriter Writer; - Writer << "lsn" << Pending->ActionLsn; - Writer << "worker" << Pending->Worker.WorkerId; - Writer << "action" << Pending->ActionId; - - return {Pending->ActionLsn, Writer.Save()}; -} - -SubmitResult 
-FunctionServiceSession::Impl::SubmitAction(Ref Action) -{ - // Loosely round-robin scheduling of actions across runners. - // - // It's not entirely clear what this means given that submits - // can come in across multiple threads, but it's probably better - // than always starting with the first runner. - // - // Longer term we should track the state of the individual - // runners and make decisions accordingly. - - SubmitResult Result = m_LocalRunnerGroup.SubmitAction(Action); - if (Result.IsAccepted) - { - return Result; - } - - return m_RemoteRunnerGroup.SubmitAction(Action); -} - -size_t -FunctionServiceSession::Impl::GetSubmittedActionCount() -{ - return m_LocalRunnerGroup.GetSubmittedActionCount() + m_RemoteRunnerGroup.GetSubmittedActionCount(); -} - -HttpResponseCode -FunctionServiceSession::Impl::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - // This lock is held for the duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - if (auto It = m_ResultsMap.find(ActionLsn); It != m_ResultsMap.end()) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - return HttpResponseCode::OK; - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - if (auto FindIt = m_PendingActions.find(ActionLsn); FindIt != m_PendingActions.end()) - { - return HttpResponseCode::Accepted; - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - if (m_RunningMap.find(ActionLsn) != m_RunningMap.end()) - { - return HttpResponseCode::Accepted; - } - } - - return HttpResponseCode::NotFound; -} - -HttpResponseCode -FunctionServiceSession::Impl::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - // This lock is held for the 
duration of the function since we need to - // be sure that the action doesn't change state while we are checking the - // different data structures - - RwLock::ExclusiveLockScope _(m_ResultsLock); - - for (auto It = begin(m_ResultsMap), End = end(m_ResultsMap); It != End; ++It) - { - if (It->second->ActionId == ActionId) - { - OutResultPackage = std::move(It->second->GetResult()); - - m_ResultsMap.erase(It); - - return HttpResponseCode::OK; - } - } - - { - RwLock::SharedLockScope __(m_PendingLock); - - for (const auto& [K, Pending] : m_PendingActions) - { - if (Pending->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - // Lock order is important here to avoid deadlocks, RwLock m_RunningLock must - // always be taken after m_ResultsLock if both are needed - - { - RwLock::SharedLockScope __(m_RunningLock); - - for (const auto& [K, v] : m_RunningMap) - { - if (v->ActionId == ActionId) - { - return HttpResponseCode::Accepted; - } - } - } - - return HttpResponseCode::NotFound; -} - -void -FunctionServiceSession::Impl::GetCompleted(CbWriter& Cbo) -{ - Cbo.BeginArray("completed"); - - m_ResultsLock.WithSharedLock([&] { - for (auto& Kv : m_ResultsMap) - { - Cbo << Kv.first; - } - }); - - Cbo.EndArray(); -} - -# define ZEN_BATCH_SCHEDULER 1 - -void -FunctionServiceSession::Impl::SchedulePendingActions() -{ - int ScheduledCount = 0; - size_t RunningCount = m_RunningLock.WithSharedLock([&] { return m_RunningMap.size(); }); - size_t PendingCount = m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); }); - size_t ResultCount = m_ResultsLock.WithSharedLock([&] { return m_ResultsMap.size(); }); - - static Stopwatch DumpRunningTimer; - - auto _ = MakeGuard([&] { - ZEN_INFO("scheduled {} pending actions. 
{} running ({} retired), {} still pending, {} results", - ScheduledCount, - RunningCount, - m_RetiredCount.load(), - PendingCount, - ResultCount); - - if (DumpRunningTimer.GetElapsedTimeMs() > 30000) - { - DumpRunningTimer.Reset(); - - std::set RunningList; - m_RunningLock.WithSharedLock([&] { - for (auto& [K, V] : m_RunningMap) - { - RunningList.insert(K); - } - }); - - ExtendableStringBuilder<1024> RunningString; - for (int i : RunningList) - { - if (RunningString.Size()) - { - RunningString << ", "; - } - - RunningString.Append(IntNum(i)); - } - - ZEN_INFO("running: {}", RunningString); - } - }); - -# if ZEN_BATCH_SCHEDULER - size_t Capacity = QueryCapacity(); - - if (!Capacity) - { - _.Dismiss(); - - return; - } - - std::vector> ActionsToSchedule; - - // Pull actions to schedule from the pending queue, we will try to submit these to the runner outside of the lock - - m_PendingLock.WithExclusiveLock([&] { - if (m_ShutdownRequested) - { - return; - } - - if (m_PendingActions.empty()) - { - return; - } - - size_t NumActionsToSchedule = std::min(Capacity, m_PendingActions.size()); - - auto PendingIt = m_PendingActions.begin(); - const auto PendingEnd = m_PendingActions.end(); - - while (NumActionsToSchedule && PendingIt != PendingEnd) - { - const Ref& Pending = PendingIt->second; - - switch (Pending->ActionState()) - { - case RunnerAction::State::Pending: - ActionsToSchedule.push_back(Pending); - break; - - case RunnerAction::State::Running: - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - break; - - default: - case RunnerAction::State::New: - ZEN_WARN("unexpected state {} for pending action {}", static_cast(Pending->ActionState()), Pending->ActionLsn); - break; - } - - ++PendingIt; - --NumActionsToSchedule; - } - - PendingCount = m_PendingActions.size(); - }); - - if (ActionsToSchedule.empty()) - { - _.Dismiss(); - return; - } - - ZEN_INFO("attempting schedule of {} pending actions", ActionsToSchedule.size()); - - auto SubmitResults = 
SubmitActions(ActionsToSchedule); - - // Move successfully scheduled actions to the running map and remove - // from pending queue. It's actually possible that by the time we get - // to this stage some of the actions may have already completed, so - // they should not always be added to the running map - - eastl::hash_set ScheduledActions; - - for (size_t i = 0; i < ActionsToSchedule.size(); ++i) - { - const Ref& Pending = ActionsToSchedule[i]; - const SubmitResult& SubResult = SubmitResults[i]; - - if (SubResult.IsAccepted) - { - ScheduledActions.insert(Pending->ActionLsn); - } - } - - ScheduledCount += (int)ActionsToSchedule.size(); - -# else - m_PendingLock.WithExclusiveLock([&] { - while (!m_PendingActions.empty()) - { - if (m_ShutdownRequested) - { - return; - } - - // Here it would be good if we could decide to pop immediately to avoid - // holding the lock while creating processes etc - const Ref& Pending = m_PendingActions.begin()->second; - FunctionRunner::SubmitResult SubResult = SubmitAction(Pending); - - if (SubResult.IsAccepted) - { - // Great, the job is being taken care of by the runner - - ZEN_DEBUG("action {} ({}) PENDING -> RUNNING", Pending->ActionId, Pending->ActionLsn); - - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({Pending->ActionLsn, Pending}); - - RunningCount = m_RunningMap.size(); - }); - - m_PendingActions.pop_front(); - - PendingCount = m_PendingActions.size(); - ++ScheduledCount; - } - else - { - // Runner could not accept the job, leave it on the pending queue - - return; - } - } - }); -# endif -} - -void -FunctionServiceSession::Impl::MonitorThreadFunction() -{ - SetCurrentThreadName("FunctionServiceSession_Monitor"); - - auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); - - do - { - int TimeoutMs = 1000; - - if (m_PendingLock.WithSharedLock([&] { return m_PendingActions.size(); })) - { - TimeoutMs = 100; - } - - const bool Timedout = m_SchedulingThreadEvent.Wait(TimeoutMs); - - if 
(m_SchedulingThreadEnabled == false) - { - return; - } - - HandleActionUpdates(); - - // Schedule pending actions - - SchedulePendingActions(); - - if (!Timedout) - { - m_SchedulingThreadEvent.Reset(); - } - } while (m_SchedulingThreadEnabled); -} - -void -FunctionServiceSession::Impl::PostUpdate(RunnerAction* Action) -{ - m_UpdatedActionsLock.WithExclusiveLock([&] { m_UpdatedActions.emplace_back(Action); }); -} - -void -FunctionServiceSession::Impl::HandleActionUpdates() -{ - std::vector> UpdatedActions; - - m_UpdatedActionsLock.WithExclusiveLock([&] { std::swap(UpdatedActions, m_UpdatedActions); }); - - std::unordered_set SeenLsn; - std::unordered_set RunningLsn; - - for (Ref& Action : UpdatedActions) - { - const int ActionLsn = Action->ActionLsn; - - if (auto [It, Inserted] = SeenLsn.insert(ActionLsn); Inserted) - { - switch (Action->ActionState()) - { - case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); - break; - - case RunnerAction::State::Running: - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - m_RunningMap.insert({ActionLsn, Action}); - m_PendingActions.erase(ActionLsn); - }); - }); - ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); - break; - - case RunnerAction::State::Completed: - case RunnerAction::State::Failed: - m_ResultsLock.WithExclusiveLock([&] { - m_ResultsMap[ActionLsn] = Action; - - m_PendingLock.WithExclusiveLock([&] { - m_RunningLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); - - m_ActionHistoryLock.WithExclusiveLock([&] { - ActionHistoryEntry Entry{.Lsn = ActionLsn, - .ActionId = Action->ActionId, - .WorkerId = Action->Worker.WorkerId, - .ActionDescriptor = Action->ActionObj, - .Succeeded = Action->ActionState() == RunnerAction::State::Completed}; - - 
std::copy(std::begin(Action->Timestamps), std::end(Action->Timestamps), std::begin(Entry.Timestamps)); - - m_ActionHistory.push_back(std::move(Entry)); - - if (m_ActionHistory.size() > m_HistoryLimit) - { - m_ActionHistory.pop_front(); - } - }); - }); - m_RetiredCount.fetch_add(1); - m_ResultRate.Mark(1); - ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", - Action->ActionId, - ActionLsn, - Action->ActionState() == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - break; - } - } - } -} - -size_t -FunctionServiceSession::Impl::QueryCapacity() -{ - return m_LocalRunnerGroup.QueryCapacity() + m_RemoteRunnerGroup.QueryCapacity(); -} - -std::vector -FunctionServiceSession::Impl::SubmitActions(const std::vector>& Actions) -{ - std::vector Results; - - for (const Ref& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -////////////////////////////////////////////////////////////////////////// - -FunctionServiceSession::FunctionServiceSession(ChunkResolver& InChunkResolver) -{ - m_Impl = std::make_unique(this, InChunkResolver); -} - -FunctionServiceSession::~FunctionServiceSession() -{ - Shutdown(); -} - -bool -FunctionServiceSession::IsHealthy() -{ - return m_Impl->IsHealthy(); -} - -void -FunctionServiceSession::Shutdown() -{ - m_Impl->Shutdown(); -} - -void -FunctionServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) -{ - m_Impl->StartRecording(InResolver, RecordingPath); -} - -void -FunctionServiceSession::StopRecording() -{ - m_Impl->StopRecording(); -} - -void -FunctionServiceSession::EmitStats(CbObjectWriter& Cbo) -{ - m_Impl->EmitStats(Cbo); -} - -std::vector -FunctionServiceSession::GetKnownWorkerIds() -{ - return m_Impl->GetKnownWorkerIds(); -} - -WorkerDesc -FunctionServiceSession::GetWorkerDescriptor(const IoHash& WorkerId) -{ - return m_Impl->GetWorkerDescriptor(WorkerId); -} - -void -FunctionServiceSession::AddLocalRunner(ChunkResolver& 
InChunkResolver, std::filesystem::path BasePath) -{ - m_Impl->m_LocalRunnerGroup.AddRunner(new LocalProcessRunner(InChunkResolver, BasePath)); -} - -void -FunctionServiceSession::AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName) -{ - m_Impl->m_RemoteRunnerGroup.AddRunner(new RemoteHttpRunner(InChunkResolver, BasePath, HostName)); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueAction(CbObject ActionObject, int Priority) -{ - return m_Impl->EnqueueAction(ActionObject, Priority); -} - -FunctionServiceSession::EnqueueResult -FunctionServiceSession::EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int RequestPriority) -{ - return m_Impl->EnqueueResolvedAction(Worker, ActionObj, RequestPriority); -} - -void -FunctionServiceSession::RegisterWorker(CbPackage Worker) -{ - m_Impl->RegisterWorker(Worker); -} - -HttpResponseCode -FunctionServiceSession::GetActionResult(int ActionLsn, CbPackage& OutResultPackage) -{ - return m_Impl->GetActionResult(ActionLsn, OutResultPackage); -} - -HttpResponseCode -FunctionServiceSession::FindActionResult(const IoHash& ActionId, CbPackage& OutResultPackage) -{ - return m_Impl->FindActionResult(ActionId, OutResultPackage); -} - -std::vector -FunctionServiceSession::GetActionHistory(int Limit) -{ - return m_Impl->GetActionHistory(Limit); -} - -void -FunctionServiceSession::GetCompleted(CbWriter& Cbo) -{ - m_Impl->GetCompleted(Cbo); -} - -void -FunctionServiceSession::PostUpdate(RunnerAction* Action) -{ - m_Impl->PostUpdate(Action); -} - -////////////////////////////////////////////////////////////////////////// - -void -function_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp new file mode 100644 index 000000000..e82a40781 --- /dev/null +++ b/src/zencompute/httpcomputeservice.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic 
Games, Inc. All Rights Reserved. + +#include "zencompute/httpcomputeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "runners/functionrunner.h" + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include + +using namespace std::literals; + +namespace zen::compute { + +constinit AsciiSet g_DecimalSet("0123456789"); +constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); + +auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; +auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; +auto OidMatcher = [](std::string_view Str) { return Str.size() == 24 && AsciiSet::HasOnly(Str, g_HexSet); }; + +////////////////////////////////////////////////////////////////////////// + +struct HttpComputeService::Impl +{ + HttpComputeService* m_Self; + CidStore& m_CidStore; + IHttpStatsService& m_StatsService; + LoggerRef m_Log; + std::filesystem::path m_BaseDir; + HttpRequestRouter m_Router; + ComputeServiceSession m_ComputeService; + SystemMetricsTracker m_MetricsTracker; + + // Metrics + + metrics::OperationTiming m_HttpRequests; + + // Per-remote-queue metadata, shared across all lookup maps below. + + struct RemoteQueueInfo : RefCounted + { + int QueueId = 0; + Oid Token; + std::string IdempotencyKey; // empty if no idempotency key was provided + std::string ClientHostname; // empty if no hostname was provided + }; + + // Remote queue registry — all three maps share the same RemoteQueueInfo objects. + // All maps are guarded by m_RemoteQueueLock. 
+ + RwLock m_RemoteQueueLock; + std::unordered_map, Oid::Hasher> m_RemoteQueuesByToken; // Token → info + std::unordered_map> m_RemoteQueuesByQueueId; // QueueId → info + std::unordered_map> m_RemoteQueuesByTag; // idempotency key → info + + LoggerRef Log() { return m_Log; } + + int ResolveQueueToken(const Oid& Token); + int ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture); + + struct IngestStats + { + int Count = 0; + int NewCount = 0; + uint64_t Bytes = 0; + uint64_t NewBytes = 0; + }; + + IngestStats IngestPackageAttachments(const CbPackage& Package); + bool CheckAttachments(const CbObject& ActionObj, std::vector& NeedList); + void HandleWorkersGet(HttpServerRequest& HttpReq); + void HandleWorkersAllGet(HttpServerRequest& HttpReq); + void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); + void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + + void RegisterRoutes(); + + Impl(HttpComputeService* Self, + CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) + : m_Self(Self) + , m_CidStore(InCidStore) + , m_StatsService(StatsService) + , m_Log(logging::Get("compute")) + , m_BaseDir(BaseDir) + , m_ComputeService(InCidStore) + { + m_ComputeService.AddLocalRunner(InCidStore, m_BaseDir / "local", MaxConcurrentActions); + m_ComputeService.WaitUntilReady(); + m_StatsService.RegisterHandler("compute", *m_Self); + RegisterRoutes(); + } +}; + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::RegisterRoutes() +{ + m_Router.AddMatcher("lsn", DecimalMatcher); + m_Router.AddMatcher("worker", IoHashMatcher); + m_Router.AddMatcher("action", IoHashMatcher); + m_Router.AddMatcher("queue", DecimalMatcher); + m_Router.AddMatcher("oidtoken", OidMatcher); + m_Router.AddMatcher("queueref", [](std::string_view Str) { return DecimalMatcher(Str) || 
OidMatcher(Str); }); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (m_ComputeService.IsHealthy()) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); + } + + return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "abandon", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + + if (Success) + { + CbObjectWriter Cbo; + Cbo << "state"sv + << "Abandoned"sv; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + CbObjectWriter Cbo; + Cbo << "error"sv + << "Cannot transition to Abandoned from current state"sv; + return HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers", + [this](HttpRouterRequest& Req) { HandleWorkersGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { HandleWorkerRequest(Req.ServerRequest(), IoHash::FromHexString(Req.GetCapture(1))); }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + CbObjectWriter Cbo; + m_ComputeService.GetCompleted(Cbo); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + Cbo.BeginObject("metrics"); + Describe(Sm, Cbo); + Cbo.EndObject(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto 
QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetActionHistory(QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Running = m_ComputeService.GetRunningActions(); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = 
ParseInt(Req.GetCapture(1)).value_or(0); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + // Once we've initiated the response we can mark the result + // as retired, allowing the service to free any associated + // resources. Note that there still needs to be a delay + // to allow the transmission to complete, it would be better + // if we could issue this once the response is fully sent... + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. 
The preferred path is the + // one which uses the scheduled action lsn for lookups + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + CbPackage Output; + if (HttpResponseCode ResponseCode = m_ComputeService.FindActionResult(ActionId, /* out */ Output); + ResponseCode != HttpResponseCode::OK) + { + ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) + + if (ResponseCode == HttpResponseCode::NotFound) + { + return HttpReq.WriteResponse(ResponseCode); + } + + return HttpReq.WriteResponse(ResponseCode); + } + + ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. 
This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if 
(ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + + return; + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto QueryParams = HttpReq.GetQueryParams(); + + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + // Resolve worker + + // + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything, enqueue the action for execution + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) + { + ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", + Result.Lsn, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + // Could not resolve? 
+ return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + return; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "workers/all", + [this](HttpRouterRequest& Req) { HandleWorkersAllGet(Req.ServerRequest()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/all", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkersAllGet(HttpReq); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + if (ResolveQueueRef(HttpReq, Req.GetCapture(1)) == 0) + return; + HandleWorkerRequest(HttpReq, IoHash::FromHexString(Req.GetCapture(2))); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "sysinfo", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + CbObjectWriter Cbo; + Describe(Sm, Cbo); + + Cbo << "cpu_usage" << Sm.CpuUsagePercent; + Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; + Cbo << "disk_used" << 100 * 1024; + Cbo << "disk_total" << 100 * 1024 * 1024; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "record/start", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return 
HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StartRecording(m_CidStore, m_BaseDir / "recording"); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "record/stop", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + m_ComputeService.StopRecording(); + + return HttpReq.WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kPost); + + // Local-only queue listing and creation + + m_Router.RegisterRoute( + "queues", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + if (!HttpReq.IsLocalMachineRequest()) + { + return HttpReq.WriteResponse(HttpResponseCode::Forbidden); + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObjectWriter Cbo; + Cbo.BeginArray("queues"sv); + + for (const int QueueId : m_ComputeService.GetQueueIds()) + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + continue; + } + + Cbo.BeginObject(); + WriteQueueDescription(Cbo, QueueId, Status); + Cbo.EndObject(); + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kPost: + { + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + ComputeServiceSession::CreateQueueResult Result = + m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + + CbObjectWriter Cbo; + Cbo << "queue_id"sv << Result.QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + // Queue creation routes — these remain separate since local creates a 
plain queue + // while remote additionally generates an OID token for external access. + + m_Router.RegisterRoute( + "queues/remote", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // Extract optional fields from the request body. + // idempotency_key: when present, we return the existing remote queue token for this + // key rather than creating a new queue, making the endpoint safe to call concurrently. + // hostname: human-readable origin context stored alongside the queue for diagnostics. + // metadata: arbitrary CbObject metadata propagated from the originating queue. + // config: arbitrary CbObject config propagated from the originating queue. + std::string IdempotencyKey; + std::string ClientHostname; + CbObject Metadata; + CbObject Config; + if (const CbObject Body = HttpReq.ReadPayloadObject()) + { + IdempotencyKey = std::string(Body["idempotency_key"sv].AsString()); + ClientHostname = std::string(Body["hostname"sv].AsString()); + Metadata = Body.Find("metadata"sv).AsObject(); + Config = Body.Find("config"sv).AsObject(); + } + + // Stamp the forwarding node's hostname into the metadata so that the + // remote side knows which node originated the queue. 
+ if (!ClientHostname.empty()) + { + CbObjectWriter MetaWriter; + for (auto Field : Metadata) + { + MetaWriter.AddField(Field.GetName(), Field); + } + MetaWriter << "via"sv << ClientHostname; + Metadata = MetaWriter.Save(); + } + + RwLock::ExclusiveLockScope _(m_RemoteQueueLock); + + if (!IdempotencyKey.empty()) + { + if (auto It = m_RemoteQueuesByTag.find(IdempotencyKey); It != m_RemoteQueuesByTag.end()) + { + Ref Existing = It->second; + if (m_ComputeService.GetQueueStatus(Existing->QueueId).IsValid) + { + CbObjectWriter Cbo; + Cbo << "queue_token"sv << Existing->Token.ToString(); + Cbo << "queue_id"sv << Existing->QueueId; + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // Queue has since expired — clean up stale entries and fall through to create a new one + m_RemoteQueuesByToken.erase(Existing->Token); + m_RemoteQueuesByQueueId.erase(Existing->QueueId); + m_RemoteQueuesByTag.erase(It); + } + } + + ComputeServiceSession::CreateQueueResult Result = m_ComputeService.CreateQueue({}, std::move(Metadata), std::move(Config)); + Ref InfoRef(new RemoteQueueInfo()); + InfoRef->QueueId = Result.QueueId; + InfoRef->Token = Oid::NewOid(); + InfoRef->IdempotencyKey = std::move(IdempotencyKey); + InfoRef->ClientHostname = std::move(ClientHostname); + + m_RemoteQueuesByToken[InfoRef->Token] = InfoRef; + m_RemoteQueuesByQueueId[InfoRef->QueueId] = InfoRef; + if (!InfoRef->IdempotencyKey.empty()) + { + m_RemoteQueuesByTag[InfoRef->IdempotencyKey] = InfoRef; + } + + CbObjectWriter Cbo; + Cbo << "queue_token"sv << InfoRef->Token.ToString(); + Cbo << "queue_id"sv << InfoRef->QueueId; + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + // Unified queue routes — {queueref} accepts both local integer IDs and remote OID tokens. + // ResolveQueueRef() handles access control (local-only for integer IDs) and token resolution. 
+ + m_Router.RegisterRoute( + "queues/{queueref}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + case HttpVerb::kDelete: + { + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.CancelQueue(QueueId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "queues/{queueref}/drain", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + m_ComputeService.DrainQueue(QueueId); + + // Return updated queue status + Status = m_ComputeService.GetQueueStatus(QueueId); + + CbObjectWriter Cbo; + WriteQueueDescription(Cbo, QueueId, Status); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/completed", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + 
return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Cbo; + m_ComputeService.GetQueueCompleted(QueueId, Cbo); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + ComputeServiceSession::QueueStatus Status = m_ComputeService.GetQueueStatus(QueueId); + + if (!Status.IsValid) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = HttpReq.GetQueryParams(); + + int QueryLimit = 50; + + if (auto LimitParam = QueryParams.GetValue("limit"); LimitParam.empty() == false) + { + QueryLimit = ParseInt(LimitParam).value_or(50); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("history"); + for (const auto& Entry : m_ComputeService.GetQueueHistory(QueueId, QueryLimit)) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Entry.Lsn; + Cbo << "queueId"sv << Entry.QueueId; + Cbo << "actionId"sv << Entry.ActionId; + Cbo << "workerId"sv << Entry.WorkerId; + Cbo << "succeeded"sv << Entry.Succeeded; + if (Entry.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Entry.CpuSeconds); + } + if (Entry.RetryCount > 0) + { + Cbo << "retry_count"sv << Entry.RetryCount; + } + + for (const auto& Timestamp : Entry.Timestamps) + { + Cbo.AddInteger( + fmt::format("time_{}"sv, RunnerAction::ToString(static_cast(&Timestamp - Entry.Timestamps))), + Timestamp); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/running", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = 
Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + if (QueueId == 0) + { + return; + } + // Filter global running list to this queue + auto AllRunning = m_ComputeService.GetRunningActions(); + std::vector Running; + for (auto& Info : AllRunning) + if (Info.QueueId == QueueId) + Running.push_back(Info); + CbObjectWriter Cbo; + Cbo.BeginArray("running"); + for (const auto& Info : Running) + { + Cbo.BeginObject(); + Cbo << "lsn"sv << Info.Lsn; + Cbo << "queueId"sv << Info.QueueId; + Cbo << "actionId"sv << Info.ActionId; + if (Info.CpuUsagePercent >= 0.0f) + { + Cbo.AddFloat("cpuUsage"sv, Info.CpuUsagePercent); + } + if (Info.CpuSeconds > 0.0f) + { + Cbo.AddFloat("cpuSeconds"sv, Info.CpuSeconds); + } + Cbo.EndObject(); + } + Cbo.EndArray(); + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(2)); + WorkerDesc Worker = m_ComputeService.GetWorkerDescriptor(WorkerId); + + if (!Worker) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + 
Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + ActionObj.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + + if (QueueId == 0) + { + return; + } + + const auto QueryParams = Req.ServerRequest().GetQueryParams(); + int RequestPriority = -1; + + if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) + { + RequestPriority = ParseInt(PriorityParam).value_or(-1); + } + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = 
HttpReq.ReadPayload(); + CbObject ActionObj = LoadCompactBinaryObject(Payload); + + std::vector NeedList; + + if (!CheckAttachments(ActionObj, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + IngestStats Stats = IngestPackageAttachments(Action); + + if (ComputeServiceSession::EnqueueResult Result = + m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) + { + ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. 
{} new ({} attachments)", + QueueId, + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } + } + + default: + break; + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + ZEN_UNUSED(QueueId); + + CbPackage Output; + HttpResponseCode ResponseCode = m_ComputeService.GetActionResult(ActionLsn, Output); + + if (ResponseCode == HttpResponseCode::OK) + { + HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + else + { + HttpReq.WriteResponse(ResponseCode); + } + + m_ComputeService.RetireActionResult(ActionLsn); + } + break; + + case HttpVerb::kPost: + { + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RescheduleAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "lsn"sv << ActionLsn; + Cbo << "retry_count"sv << Result.RetryCount; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); +} + +////////////////////////////////////////////////////////////////////////// + +HttpComputeService::HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions) +: m_Impl(std::make_unique(this, InCidStore, StatsService, BaseDir, MaxConcurrentActions)) +{ +} + 
+HttpComputeService::~HttpComputeService() +{ + m_Impl->m_StatsService.UnregisterHandler("compute", *this); +} + +void +HttpComputeService::Shutdown() +{ + m_Impl->m_ComputeService.Shutdown(); +} + +ComputeServiceSession::ActionCounts +HttpComputeService::GetActionCounts() +{ + return m_Impl->m_ComputeService.GetActionCounts(); +} + +const char* +HttpComputeService::BaseUri() const +{ + return "/compute/"; +} + +void +HttpComputeService::HandleRequest(HttpServerRequest& Request) +{ + ZEN_TRACE_CPU("HttpComputeService::HandleRequest"); + metrics::OperationTiming::Scope $(m_Impl->m_HttpRequests); + + if (m_Impl->m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpComputeService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + m_Impl->m_ComputeService.EmitStats(Cbo); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpComputeService::Impl::WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status) +{ + Cbo << "queue_id"sv << Status.QueueId; + Cbo << "active_count"sv << Status.ActiveCount; + Cbo << "completed_count"sv << Status.CompletedCount; + Cbo << "failed_count"sv << Status.FailedCount; + Cbo << "abandoned_count"sv << Status.AbandonedCount; + Cbo << "cancelled_count"sv << Status.CancelledCount; + Cbo << "state"sv << ToString(Status.State); + Cbo << "cancelled"sv << (Status.State == ComputeServiceSession::QueueState::Cancelled); + Cbo << "draining"sv << (Status.State == ComputeServiceSession::QueueState::Draining); + Cbo << "is_complete"sv << Status.IsComplete; + + if (CbObject Meta = m_ComputeService.GetQueueMetadata(QueueId)) + { + Cbo << "metadata"sv << Meta; + } + + if (CbObject Cfg = m_ComputeService.GetQueueConfig(QueueId)) + { + Cbo << "config"sv << Cfg; + } + + { + RwLock::SharedLockScope $(m_RemoteQueueLock); + if 
(auto It = m_RemoteQueuesByQueueId.find(QueueId); It != m_RemoteQueuesByQueueId.end()) + { + Cbo << "queue_token"sv << It->second->Token.ToString(); + if (!It->second->ClientHostname.empty()) + { + Cbo << "hostname"sv << It->second->ClientHostname; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + +int +HttpComputeService::Impl::ResolveQueueToken(const Oid& Token) +{ + RwLock::SharedLockScope $(m_RemoteQueueLock); + + auto It = m_RemoteQueuesByToken.find(Token); + + if (It != m_RemoteQueuesByToken.end()) + { + return It->second->QueueId; + } + + return 0; +} + +int +HttpComputeService::Impl::ResolveQueueRef(HttpServerRequest& HttpReq, std::string_view Capture) +{ + if (OidMatcher(Capture)) + { + // Remote OID token — accessible from any client + const Oid Token = Oid::FromHexString(Capture); + const int QueueId = ResolveQueueToken(Token); + + if (QueueId == 0) + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + return QueueId; + } + + // Local integer queue ID — restricted to local machine requests + if (!HttpReq.IsLocalMachineRequest()) + { + HttpReq.WriteResponse(HttpResponseCode::Forbidden); + return 0; + } + + return ParseInt(Capture).value_or(0); +} + +HttpComputeService::Impl::IngestStats +HttpComputeService::Impl::IngestPackageAttachments(const CbPackage& Package) +{ + IngestStats Stats; + + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + Stats.Bytes += CompressedSize; + ++Stats.Count; + + const CidStore::InsertResult InsertResult = m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + Stats.NewBytes += CompressedSize; + ++Stats.NewCount; + } + } + + return Stats; +} + 
+bool +HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vector& NeedList) +{ + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + return NeedList.empty(); +} + +void +HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const IoHash& WorkerId : m_ComputeService.GetKnownWorkerIds()) + { + Cbo << WorkerId; + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkersAllGet(HttpServerRequest& HttpReq) +{ + std::vector WorkerIds = m_ComputeService.GetKnownWorkerIds(); + + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + for (const IoHash& WorkerId : WorkerIds) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "descriptor" << m_ComputeService.GetWorkerDescriptor(WorkerId).Descriptor.GetObject(); + Cbo.EndObject(); + } + + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId) +{ + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + if (WorkerDesc Desc = m_ComputeService.GetWorkerDescriptor(WorkerId)) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); + } + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject WorkerSpec = HttpReq.ReadPayloadObject(); + + HashKeySet ChunkSet; + WorkerSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerSpec); + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + 
ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + m_ComputeService.RegisterWorker(WorkerPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + ResponseWriter.AddHash(Hash); + }); + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); + CbObject WorkerSpec = WorkerSpecPackage.GetObject(); + + std::span Attachments = WorkerSpecPackage.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + m_ComputeService.RegisterWorker(WorkerSpecPackage); + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } +} + +////////////////////////////////////////////////////////////////////////// + +void 
+httpcomputeservice_forcelink() +{ +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httpfunctionservice.cpp b/src/zencompute/httpfunctionservice.cpp deleted file mode 100644 index 09a9684a7..000000000 --- a/src/zencompute/httpfunctionservice.cpp +++ /dev/null @@ -1,709 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "zencompute/httpfunctionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include - -using namespace std::literals; - -namespace zen::compute { - -constinit AsciiSet g_DecimalSet("0123456789"); -auto DecimalMatcher = [](std::string_view Str) { return AsciiSet::HasOnly(Str, g_DecimalSet); }; - -constinit AsciiSet g_HexSet("0123456789abcdefABCDEF"); -auto IoHashMatcher = [](std::string_view Str) { return Str.size() == 40 && AsciiSet::HasOnly(Str, g_HexSet); }; - -HttpFunctionService::HttpFunctionService(CidStore& InCidStore, - IHttpStatsService& StatsService, - [[maybe_unused]] const std::filesystem::path& BaseDir) -: m_CidStore(InCidStore) -, m_StatsService(StatsService) -, m_Log(logging::Get("apply")) -, m_BaseDir(BaseDir) -, m_FunctionService(InCidStore) -{ - m_FunctionService.AddLocalRunner(InCidStore, m_BaseDir / "local"); - - m_StatsService.RegisterHandler("apply", *this); - - m_Router.AddMatcher("lsn", DecimalMatcher); - m_Router.AddMatcher("worker", IoHashMatcher); - m_Router.AddMatcher("action", IoHashMatcher); - - m_Router.RegisterRoute( - "ready", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - if (m_FunctionService.IsHealthy()) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "ok"); - } - - return HttpReq.WriteResponse(HttpResponseCode::ServiceUnavailable); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers", - 
[this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"sv); - for (const IoHash& WorkerId : m_FunctionService.GetKnownWorkerIds()) - { - Cbo << WorkerId; - } - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "workers/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - if (WorkerDesc Desc = m_FunctionService.GetWorkerDescriptor(WorkerId)) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor.GetObject()); - } - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - - case HttpVerb::kPost: - { - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - CbObject WorkerSpec = HttpReq.ReadPayloadObject(); - - // Determine which pieces are missing and need to be transmitted - - HashKeySet ChunkSet; - - WorkerSpec.IterateAttachments([&](CbFieldView Field) { - const IoHash Hash = Field.AsHash(); - ChunkSet.AddHashToSet(Hash); - }); - - CbPackage WorkerPackage; - WorkerPackage.SetObject(WorkerSpec); - - m_CidStore.FilterChunks(ChunkSet); - - if (ChunkSet.IsEmpty()) - { - ZEN_DEBUG("worker {}: all attachments already available", WorkerId); - m_FunctionService.RegisterWorker(WorkerPackage); - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - - CbObjectWriter ResponseWriter; - ResponseWriter.BeginArray("need"); - - ChunkSet.IterateHashes([&](const IoHash& Hash) { - ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); - ResponseWriter.AddHash(Hash); - }); - - ResponseWriter.EndArray(); - - ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); - } - break; - - case 
HttpContentType::kCbPackage: - { - CbPackage WorkerSpecPackage = HttpReq.ReadPayloadPackage(); - CbObject WorkerSpec = WorkerSpecPackage.GetObject(); - - std::span Attachments = WorkerSpecPackage.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer Buffer = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - TotalAttachmentBytes += Buffer.GetCompressedSize(); - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += Buffer.GetCompressedSize(); - ++NewAttachmentCount; - } - } - - ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", - WorkerId, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - m_FunctionService.RegisterWorker(WorkerSpecPackage); - - return HttpReq.WriteResponse(HttpResponseCode::NoContent); - } - break; - - default: - break; - } - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs/completed", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - CbObjectWriter Cbo; - m_FunctionService.GetCompleted(Cbo); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - Cbo.BeginObject("metrics"); - Describe(Sm, Cbo); - Cbo.EndObject(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/history", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const auto QueryParams = HttpReq.GetQueryParams(); - - int QueryLimit = 50; - - if (auto LimitParam = 
QueryParams.GetValue("limit"); LimitParam.empty() == false) - { - QueryLimit = ParseInt(LimitParam).value_or(50); - } - - CbObjectWriter Cbo; - Cbo.BeginArray("history"); - for (const auto& Entry : m_FunctionService.GetActionHistory(QueryLimit)) - { - Cbo.BeginObject(); - Cbo << "lsn"sv << Entry.Lsn; - Cbo << "actionId"sv << Entry.ActionId; - Cbo << "workerId"sv << Entry.WorkerId; - Cbo << "succeeded"sv << Entry.Succeeded; - Cbo << "actionDescriptor"sv << Entry.ActionDescriptor; - - for (const auto& Timestamp : Entry.Timestamps) - { - Cbo.AddInteger( - fmt::format("time_{}"sv, RunnerAction::ToString(static_cast(&Timestamp - Entry.Timestamps))), - Timestamp); - } - Cbo.EndObject(); - } - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{lsn}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const int ActionLsn = std::stoi(std::string{Req.GetCapture(1)}); - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - { - CbPackage Output; - HttpResponseCode ResponseCode = m_FunctionService.GetActionResult(ActionLsn, Output); - - if (ResponseCode == HttpResponseCode::OK) - { - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - } - - return HttpReq.WriteResponse(ResponseCode); - } - break; - - case HttpVerb::kPost: - { - // Add support for cancellation, priority changes - } - break; - - default: - break; - } - }, - HttpVerb::kGet | HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. 
The preferred path is the - // one which uses the scheduled action lsn for lookups - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); - - CbPackage Output; - if (HttpResponseCode ResponseCode = m_FunctionService.FindActionResult(ActionId, /* out */ Output); - ResponseCode != HttpResponseCode::OK) - { - ZEN_TRACE("jobs/{}/{}: {}", Req.GetCapture(1), Req.GetCapture(2), ToString(ResponseCode)) - - if (ResponseCode == HttpResponseCode::NotFound) - { - return HttpReq.WriteResponse(ResponseCode); - } - - return HttpReq.WriteResponse(ResponseCode); - } - - ZEN_DEBUG("jobs/{}/{}: OK", Req.GetCapture(1), Req.GetCapture(2)) - - return HttpReq.WriteResponse(HttpResponseCode::OK, Output); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "jobs/{worker}", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); - - WorkerDesc Worker = m_FunctionService.GetWorkerDescriptor(WorkerId); - - if (!Worker) - { - return HttpReq.WriteResponse(HttpResponseCode::NotFound); - } - - const auto QueryParams = Req.ServerRequest().GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt(PriorityParam).value_or(-1); - } - - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if 
(FunctionServiceSession::EnqueueResult Result = - m_FunctionService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "jobs", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - const auto QueryParams = HttpReq.GetQueryParams(); - - int RequestPriority = -1; - - if (auto PriorityParam = QueryParams.GetValue("priority"); PriorityParam.empty() == false) - { - RequestPriority = ParseInt(PriorityParam).value_or(-1); - } - - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? 
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (FunctionServiceSession::EnqueueResult Result = m_FunctionService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? 
- return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "workers/all", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - std::vector WorkerIds = m_FunctionService.GetKnownWorkerIds(); - - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); - - for (const IoHash& WorkerId : WorkerIds) - { - Cbo.BeginObject(); - - Cbo << "id" << WorkerId; - - const auto& Descriptor = m_FunctionService.GetWorkerDescriptor(WorkerId); - - Cbo << "descriptor" << Descriptor.Descriptor.GetObject(); - - Cbo.EndObject(); - } - - Cbo.EndArray(); - - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "sysinfo", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - SystemMetrics Sm = GetSystemMetricsForReporting(); - - CbObjectWriter Cbo; - Describe(Sm, Cbo); - - Cbo << "cpu_usage" << Sm.CpuUsagePercent; - Cbo << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; - Cbo << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 1024; - Cbo << "disk_used" << 100 * 1024; - Cbo << "disk_total" << 100 * 1024 * 1024; - - return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "record/start", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - m_FunctionService.StartRecording(m_CidStore, m_BaseDir / "recording"); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); - - m_Router.RegisterRoute( - "record/stop", - [this](HttpRouterRequest& Req) { - HttpServerRequest& HttpReq = Req.ServerRequest(); - - m_FunctionService.StopRecording(); - - return HttpReq.WriteResponse(HttpResponseCode::OK); - }, - HttpVerb::kPost); -} - -HttpFunctionService::~HttpFunctionService() -{ - m_StatsService.UnregisterHandler("apply", 
*this); -} - -void -HttpFunctionService::Shutdown() -{ - m_FunctionService.Shutdown(); -} - -const char* -HttpFunctionService::BaseUri() const -{ - return "/apply/"; -} - -void -HttpFunctionService::HandleRequest(HttpServerRequest& Request) -{ - metrics::OperationTiming::Scope $(m_HttpRequests); - - if (m_Router.HandleRequest(Request) == false) - { - ZEN_WARN("No route found for {0}", Request.RelativeUri()); - } -} - -void -HttpFunctionService::HandleStatsRequest(HttpServerRequest& Request) -{ - CbObjectWriter Cbo; - m_FunctionService.EmitStats(Cbo); - - Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); -} - -////////////////////////////////////////////////////////////////////////// - -void -httpfunction_forcelink() -{ -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/httporchestrator.cpp b/src/zencompute/httporchestrator.cpp index 39e7e60d7..6cbe01e04 100644 --- a/src/zencompute/httporchestrator.cpp +++ b/src/zencompute/httporchestrator.cpp @@ -2,65 +2,398 @@ #include "zencompute/httporchestrator.h" -#include -#include +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include namespace zen::compute { -HttpOrchestratorService::HttpOrchestratorService() : m_Log(logging::Get("orch")) +// Worker IDs must be 3-64 characters and can only contain letters, numbers, underscores, and dashes +static bool +IsValidWorkerId(std::string_view Id) +{ + if (Id.size() < 3 || Id.size() > 64) + { + return false; + } + for (char c : Id) + { + if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-') + { + continue; + } + return false; + } + return true; +} + +// Shared announce payload parser used by both the HTTP POST route and the +// WebSocket message handler. Returns the worker ID on success (empty on +// validation failure). 
The returned WorkerAnnouncement has string_view +// fields that reference the supplied CbObjectView, so the CbObject must +// outlive the returned announcement. +static std::string_view +ParseWorkerAnnouncement(const CbObjectView& Data, OrchestratorService::WorkerAnnouncement& Ann) { + Ann.Id = Data["id"].AsString(""); + Ann.Uri = Data["uri"].AsString(""); + + if (!IsValidWorkerId(Ann.Id)) + { + return {}; + } + + if (!Ann.Uri.starts_with("http://") && !Ann.Uri.starts_with("https://")) + { + return {}; + } + + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Platform = Data["platform"].AsString(""); + Ann.CpuUsagePercent = Data["cpu_usage"].AsFloat(0.0f); + Ann.MemoryTotalBytes = Data["memory_total"].AsUInt64(0); + Ann.MemoryUsedBytes = Data["memory_used"].AsUInt64(0); + Ann.BytesReceived = Data["bytes_received"].AsUInt64(0); + Ann.BytesSent = Data["bytes_sent"].AsUInt64(0); + Ann.ActionsPending = Data["actions_pending"].AsInt32(0); + Ann.ActionsRunning = Data["actions_running"].AsInt32(0); + Ann.ActionsCompleted = Data["actions_completed"].AsInt32(0); + Ann.ActiveQueues = Data["active_queues"].AsInt32(0); + Ann.Provisioner = Data["provisioner"].AsString(""); + + if (auto Metrics = Data["metrics"].AsObjectView()) + { + Ann.Cpus = Metrics["lp_count"].AsInt32(0); + if (Ann.Cpus <= 0) + { + Ann.Cpus = 1; + } + } + + return Ann.Id; +} + +HttpOrchestratorService::HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_Service(std::make_unique(std::move(DataDir), EnableWorkerWebSocket)) +, m_Hostname(GetMachineName()) +{ + m_Router.AddMatcher("workerid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + m_Router.AddMatcher("clientid", [](std::string_view Segment) { return IsValidWorkerId(Segment); }); + + // dummy endpoint for websocket clients + m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( 
+ "status", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + Cbo << "hostname" << std::string_view(m_Hostname); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + m_Router.RegisterRoute( "provision", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "announce", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); - CbObjectWriter Cbo; - Cbo.BeginArray("workers"); + CbObject Data = HttpReq.ReadPayloadObject(); + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); - m_KnownWorkersLock.WithSharedLock([&] { - for (const auto& [WorkerId, Worker] : m_KnownWorkers) + if (WorkerId.empty()) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid worker announcement: id must be 3-64 alphanumeric/underscore/dash " + "characters and uri must start with http:// or https://"); + } + + m_Service->AnnounceWorker(Ann); + + HttpReq.WriteResponse(HttpResponseCode::OK); + +# if ZEN_WITH_WEBSOCKETS + // Notify push thread that state may have changed + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "agents", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetWorkerList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetProvisioningHistory(Limit)); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "timeline/{workerid}", + 
[this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + std::string_view WorkerId = Req.GetCapture(1); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + auto LimitStr = Params.GetValue("limit"); + + std::optional From; + std::optional To; + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt(FromStr); + if (!Val) { - Cbo.BeginObject(); - Cbo << "uri" << Worker.BaseUri; - Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); - Cbo.EndObject(); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } - }); + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + int Limit = !LimitStr.empty() ? zen::ParseInt(LimitStr).value_or(0) : 0; - Cbo.EndArray(); + CbObject Result = m_Service->GetWorkerTimeline(WorkerId, From, To, Limit); - HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + if (!Result) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); }, - HttpVerb::kPost); + HttpVerb::kGet); m_Router.RegisterRoute( - "announce", + "timeline", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + auto FromStr = Params.GetValue("from"); + auto ToStr = Params.GetValue("to"); + + DateTime From = DateTime(0); + DateTime To = DateTime::Now(); + + if (!FromStr.empty()) + { + auto Val = zen::ParseInt(FromStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + From = DateTime(*Val); + } + + if (!ToStr.empty()) + { + auto Val = zen::ParseInt(ToStr); + if (!Val) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + To = DateTime(*Val); + } + + CbObject Result = m_Service->GetAllTimelines(From, To); + + 
HttpReq.WriteResponse(HttpResponseCode::OK, std::move(Result)); + }, + HttpVerb::kGet); + + // Client tracking endpoints + + m_Router.RegisterRoute( + "clients", [this](HttpRouterRequest& Req) { HttpServerRequest& HttpReq = Req.ServerRequest(); CbObject Data = HttpReq.ReadPayloadObject(); - std::string_view WorkerId = Data["id"].AsString(""); - std::string_view WorkerUri = Data["uri"].AsString(""); + OrchestratorService::ClientAnnouncement Ann; + Ann.SessionId = Data["session_id"].AsObjectId(Oid::Zero); + Ann.Hostname = Data["hostname"].AsString(""); + Ann.Address = HttpReq.GetRemoteAddress(); - if (WorkerId.empty() || WorkerUri.empty()) + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) { - return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + Ann.Metadata = CbObject::Clone(MetadataView); } - m_KnownWorkersLock.WithExclusiveLock([&] { - auto& Worker = m_KnownWorkers[std::string(WorkerId)]; - Worker.BaseUri = WorkerUri; - Worker.LastSeen.Reset(); - }); + std::string ClientId = m_Service->AnnounceClient(Ann); - HttpReq.WriteResponse(HttpResponseCode::OK); + CbObjectWriter ResponseObj; + ResponseObj << "id" << std::string_view(ClientId); + HttpReq.WriteResponse(HttpResponseCode::OK, ResponseObj.Save()); + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif }, HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients/{clientid}/update", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + CbObject MetadataObj; + CbObject Data = HttpReq.ReadPayloadObject(); + if (Data) + { + auto MetadataView = Data["metadata"].AsObjectView(); + if (MetadataView) + { + MetadataObj = CbObject::Clone(MetadataView); + } + } + + if (m_Service->UpdateClient(ClientId, std::move(MetadataObj))) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + 
"clients/{clientid}/complete", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view ClientId = Req.GetCapture(1); + + if (m_Service->CompleteClient(ClientId)) + { + HttpReq.WriteResponse(HttpResponseCode::OK); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + +# if ZEN_WITH_WEBSOCKETS + m_PushEvent.Set(); +# endif + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "clients", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_Service->GetClientList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "clients/history", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + auto Params = HttpReq.GetQueryParams(); + + int Limit = 100; + auto LimitStr = Params.GetValue("limit"); + if (!LimitStr.empty()) + { + Limit = std::atoi(std::string(LimitStr).c_str()); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, m_Service->GetClientHistory(Limit)); + }, + HttpVerb::kGet); + +# if ZEN_WITH_WEBSOCKETS + + // Start the WebSocket push thread + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); +# endif } HttpOrchestratorService::~HttpOrchestratorService() { + Shutdown(); +} + +void +HttpOrchestratorService::Shutdown() +{ +# if ZEN_WITH_WEBSOCKETS + if (!m_PushEnabled.exchange(false)) + { + return; + } + + // Stop the push thread first, before touching connections. This ensures + // the push thread is no longer reading m_WsConnections or calling into + // m_Service when we start tearing things down. + m_PushEvent.Set(); + if (m_PushThread.joinable()) + { + m_PushThread.join(); + } + + // Clean up worker WebSocket connections — collect IDs under lock, then + // notify the service outside the lock to avoid lock-order inversions. 
+ std::vector WorkerIds; + m_WorkerWsLock.WithExclusiveLock([&] { + WorkerIds.reserve(m_WorkerWsMap.size()); + for (const auto& [Conn, Id] : m_WorkerWsMap) + { + WorkerIds.push_back(Id); + } + m_WorkerWsMap.clear(); + }); + for (const auto& Id : WorkerIds) + { + m_Service->SetWorkerWebSocketConnected(Id, false); + } + + // Now that the push thread is gone, release all dashboard connections. + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); +# endif } const char* @@ -78,4 +411,240 @@ HttpOrchestratorService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::OnWebSocketOpen(Ref Connection) +{ + if (!m_PushEnabled.load()) + { + return; + } + + ZEN_INFO("WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Wake push thread to send initial state immediately + m_PushEvent.Set(); +} + +void +HttpOrchestratorService::OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) +{ + // Only handle binary messages from workers when the feature is enabled. 
+ if (!m_Service->IsWorkerWebSocketEnabled() || Msg.Opcode != WebSocketOpcode::kBinary) + { + return; + } + + std::string WorkerId = HandleWorkerWebSocketMessage(Msg); + if (WorkerId.empty()) + { + return; + } + + // Check if this is a new worker WebSocket connection + bool IsNewWorkerWs = false; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It == m_WorkerWsMap.end()) + { + m_WorkerWsMap[&Conn] = WorkerId; + IsNewWorkerWs = true; + } + }); + + if (IsNewWorkerWs) + { + m_Service->SetWorkerWebSocketConnected(WorkerId, true); + } + + m_PushEvent.Set(); +} + +std::string +HttpOrchestratorService::HandleWorkerWebSocketMessage(const WebSocketMessage& Msg) +{ + // Workers send CbObject in native binary format over the WebSocket to + // avoid the lossy CbObject↔JSON round-trip. + CbObject Data = CbObject::MakeView(Msg.Payload.GetData()); + if (!Data) + { + ZEN_WARN("worker WebSocket message is not a valid CbObject"); + return {}; + } + + OrchestratorService::WorkerAnnouncement Ann; + std::string_view WorkerId = ParseWorkerAnnouncement(Data, Ann); + if (WorkerId.empty()) + { + ZEN_WARN("invalid worker announcement via WebSocket"); + return {}; + } + + m_Service->AnnounceWorker(Ann); + return std::string(WorkerId); +} + +void +HttpOrchestratorService::OnWebSocketClose(WebSocketConnection& Conn, + [[maybe_unused]] uint16_t Code, + [[maybe_unused]] std::string_view Reason) +{ + ZEN_INFO("WebSocket client disconnected (code {})", Code); + + // Check if this was a worker WebSocket connection; collect the ID under + // the worker lock, then notify the service outside the lock. 
+ std::string DisconnectedWorkerId; + m_WorkerWsLock.WithExclusiveLock([&] { + auto It = m_WorkerWsMap.find(&Conn); + if (It != m_WorkerWsMap.end()) + { + DisconnectedWorkerId = std::move(It->second); + m_WorkerWsMap.erase(It); + } + }); + + if (!DisconnectedWorkerId.empty()) + { + m_Service->SetWorkerWebSocketConnected(DisconnectedWorkerId, false); + m_PushEvent.Set(); + } + + if (!m_PushEnabled.load()) + { + return; + } + + // Remove from dashboard connections + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} +# endif + +////////////////////////////////////////////////////////////////////////// +// +// Push thread +// + +# if ZEN_WITH_WEBSOCKETS +void +HttpOrchestratorService::PushThreadFunction() +{ + SetCurrentThreadName("orch_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(2000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + // Snapshot current connections + std::vector> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + continue; + } + + // Build combined JSON with worker list, provisioning history, clients, and client history + CbObject WorkerList = m_Service->GetWorkerList(); + CbObject History = m_Service->GetProvisioningHistory(50); + CbObject ClientList = m_Service->GetClientList(); + CbObject ClientHistory = m_Service->GetClientHistory(50); + + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + JsonBuilder.Append(fmt::format("\"hostname\":\"{}\",", m_Hostname)); + + // Emit workers array from worker list + ExtendableStringBuilder<2048> WorkerJson; + WorkerList.ToJson(WorkerJson); + std::string_view WorkerJsonView = WorkerJson.ToView(); + // Strip outer braces: {"workers":[...]} -> "workers":[...] 
+ if (WorkerJsonView.size() >= 2) + { + JsonBuilder.Append(WorkerJsonView.substr(1, WorkerJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit events array from history + ExtendableStringBuilder<2048> HistoryJson; + History.ToJson(HistoryJson); + std::string_view HistoryJsonView = HistoryJson.ToView(); + if (HistoryJsonView.size() >= 2) + { + JsonBuilder.Append(HistoryJsonView.substr(1, HistoryJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit clients array from client list + ExtendableStringBuilder<2048> ClientJson; + ClientList.ToJson(ClientJson); + std::string_view ClientJsonView = ClientJson.ToView(); + if (ClientJsonView.size() >= 2) + { + JsonBuilder.Append(ClientJsonView.substr(1, ClientJsonView.size() - 2)); + } + + JsonBuilder.Append(","); + + // Emit client_events array from client history + ExtendableStringBuilder<2048> ClientHistoryJson; + ClientHistory.ToJson(ClientHistoryJson); + std::string_view ClientHistoryJsonView = ClientHistoryJson.ToView(); + if (ClientHistoryJsonView.size() >= 2) + { + JsonBuilder.Append(ClientHistoryJsonView.substr(1, ClientHistoryJsonView.size() - 2)); + } + + JsonBuilder.Append("}"); + std::string_view Json = JsonBuilder.ToView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } + } +} +# endif + } // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h new file mode 100644 index 000000000..a5bc5a34d --- /dev/null +++ 
b/src/zencompute/include/zencompute/cloudmetadata.h @@ -0,0 +1,151 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +/** Snapshot of detected cloud instance properties. */ +struct CloudInstanceInfo +{ + CloudProvider Provider = CloudProvider::None; + std::string InstanceId; + std::string AvailabilityZone; + bool IsSpot = false; + bool IsAutoscaling = false; +}; + +/** + * Detects whether the process is running on a cloud VM (AWS, Azure, or GCP) + * and monitors for impending termination signals. + * + * Detection works by querying the Instance Metadata Service (IMDS) at the + * well-known link-local address 169.254.169.254, which is only routable from + * within a cloud VM. Each provider is probed in sequence (AWS -> Azure -> GCP); + * the first successful response wins. + * + * To avoid a ~200ms connect timeout penalty on every startup when running on + * bare-metal or non-cloud machines, failed probes write sentinel files + * (e.g. ".isNotAWS") to DataDir. Subsequent startups skip providers that have + * a sentinel present. Delete the sentinel files to force re-detection. + * + * When a provider is detected, a background thread polls for termination + * signals every 5 seconds (spot interruption, autoscaling lifecycle changes, + * scheduled maintenance). The termination state is exposed as an atomic bool + * so the compute server can include it in coordinator announcements and react + * to imminent shutdown. + * + * Thread safety: GetInstanceInfo() and GetTerminationReason() acquire a + * shared RwLock; the background monitor thread acquires the exclusive lock + * only when writing the termination reason (a one-time transition). The + * termination-pending flag itself is a lock-free atomic. 
+ * + * Usage: + * auto Cloud = std::make_unique<CloudMetadata>(DataDir / "cloud"); + * if (Cloud->IsTerminationPending()) { ... } + * Cloud->Describe(AnnounceBody); // writes "cloud" sub-object into CB + */ +class CloudMetadata +{ +public: + /** Synchronously probes cloud providers and starts the termination monitor + * if a provider is detected. Creates DataDir if it does not exist. + */ + explicit CloudMetadata(std::filesystem::path DataDir); + + /** Synchronously probes cloud providers at the given IMDS endpoint. + * Intended for testing — allows redirecting all IMDS queries to a local + * mock HTTP server instead of the real 169.254.169.254 endpoint. + */ + CloudMetadata(std::filesystem::path DataDir, std::string ImdsEndpoint); + + /** Stops the termination monitor thread and joins it. */ + ~CloudMetadata(); + + CloudMetadata(const CloudMetadata&) = delete; + CloudMetadata& operator=(const CloudMetadata&) = delete; + + CloudProvider GetProvider() const; + CloudInstanceInfo GetInstanceInfo() const; + bool IsTerminationPending() const; + std::string GetTerminationReason() const; + + /** Writes a "cloud" sub-object into the compact binary writer if a provider + * was detected. No-op when running on bare metal. + */ + void Describe(CbWriter& Writer) const; + + /** Executes a single termination-poll cycle for the detected provider. + * Public so tests can drive poll cycles synchronously without relying on + * the background thread's 5-second timer. + */ + void PollTermination(); + + /** Removes the negative-cache sentinel files (.isNotAWS, .isNotAzure, + * .isNotGCP) from DataDir so subsequent detection probes are not skipped. + * Primarily intended for tests that need to reset state between sub-cases. + */ + void ClearSentinelFiles(); + +private: + /** Tries each provider in order, stops on first successful detection.
*/ + void DetectProvider(); + bool TryDetectAWS(); + bool TryDetectAzure(); + bool TryDetectGCP(); + + void WriteSentinelFile(const std::filesystem::path& Path); + bool HasSentinelFile(const std::filesystem::path& Path) const; + + void StartTerminationMonitor(); + void TerminationMonitorThread(); + void PollAWSTermination(); + void PollAzureTermination(); + void PollGCPTermination(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + std::filesystem::path m_DataDir; + std::string m_ImdsEndpoint; + + mutable RwLock m_InfoLock; + CloudInstanceInfo m_Info; + + std::atomic m_TerminationPending{false}; + + mutable RwLock m_ReasonLock; + std::string m_TerminationReason; + + // IMDSv2 session token, acquired during AWS detection and reused for + // subsequent termination polling. Has a 300s TTL on the AWS side; if it + // expires mid-run the poll requests will get 401s which we treat as + // non-terminal (the monitor simply retries next cycle). + std::string m_AwsToken; + + std::thread m_MonitorThread; + std::atomic m_MonitorEnabled{true}; + Event m_MonitorEvent; +}; + +void cloudmetadata_forcelink(); // internal + +} // namespace zen::compute diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h new file mode 100644 index 000000000..65ec5f9ee --- /dev/null +++ b/src/zencompute/include/zencompute/computeservice.h @@ -0,0 +1,262 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include + +# include + +namespace zen { +class ChunkResolver; +class CbObjectWriter; +} // namespace zen + +namespace zen::compute { + +class ActionRecorder; +class ComputeServiceSession; +class IActionResultHandler; +class LocalProcessRunner; +class RemoteHttpRunner; +struct RunnerAction; +struct SubmitResult; + +struct WorkerDesc +{ + CbPackage Descriptor; + IoHash WorkerId{IoHash::Zero}; + + inline operator bool() const { return WorkerId != IoHash::Zero; } +}; + +/** + * Lambda style compute function service + * + * The responsibility of this class is to accept function execution requests, and + * schedule them using one or more FunctionRunner instances. It will basically always + * accept requests, queueing them if necessary, and then hand them off to runners + * as they become available. + * + * This is typically fronted by an API service that handles communication with clients. + */ +class ComputeServiceSession final +{ +public: + /** + * Session lifecycle state machine. 
+ * + * Forward transitions: Created -> Ready -> Draining -> Paused -> Sunset + * Backward transitions: Draining -> Ready, Paused -> Ready + * Automatic transition: Draining -> Paused (when pending + running reaches 0) + * Jump transitions: any non-terminal -> Abandoned, any non-terminal -> Sunset + * Terminal states: Abandoned (only Sunset out), Sunset (no transitions out) + * + * | State | Accept new actions | Schedule pending | Finish running | + * |-----------|-------------------|-----------------|----------------| + * | Created | No | No | N/A | + * | Ready | Yes | Yes | Yes | + * | Draining | No | Yes | Yes | + * | Paused | No | No | No | + * | Abandoned | No | No | No (all abandoned) | + * | Sunset | No | No | No | + */ + enum class SessionState + { + Created, // Initial state before WaitUntilReady completes + Ready, // Normal operating state; accepts and schedules work + Draining, // Stops accepting new work; finishes existing; auto-transitions to Paused when empty + Paused, // Idle; no work accepted or scheduled; can resume to Ready + Abandoned, // Spot termination grace period; all actions abandoned; only Sunset out + Sunset // Terminal; triggers full shutdown + }; + + ComputeServiceSession(ChunkResolver& InChunkResolver); + ~ComputeServiceSession(); + + void WaitUntilReady(); + void Shutdown(); + bool IsHealthy(); + + SessionState GetSessionState() const; + + // Request a state transition. Returns false if the transition is invalid. + // Sunset can be reached from any non-Sunset state. 
+ bool RequestStateTransition(SessionState NewState); + + // Orchestration + + void SetOrchestratorEndpoint(std::string_view Endpoint); + void SetOrchestratorBasePath(std::filesystem::path BasePath); + + // Worker registration and discovery + + void RegisterWorker(CbPackage Worker); + [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); + [[nodiscard]] std::vector GetKnownWorkerIds(); + + // Action runners + + void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, int32_t MaxConcurrentActions = 0); + void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); + + // Action submission + + struct EnqueueResult + { + int Lsn; + CbObject ResponseMessage; + + inline operator bool() const { return Lsn != 0; } + }; + + [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); + [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); + + // Queue management + // + // Queues group actions submitted by a single client session. They allow + // cancelling or polling completion of all actions in the group. 
+ + struct CreateQueueResult + { + int QueueId = 0; // 0 if creation failed + }; + + enum class QueueState + { + Active, + Draining, + Cancelled, + }; + + struct QueueStatus + { + bool IsValid = false; + int QueueId = 0; + int ActiveCount = 0; // pending + running (not yet completed) + int CompletedCount = 0; // successfully completed + int FailedCount = 0; // failed + int AbandonedCount = 0; // abandoned + int CancelledCount = 0; // cancelled + QueueState State = QueueState::Active; + bool IsComplete = false; // ActiveCount == 0 + }; + + [[nodiscard]] CreateQueueResult CreateQueue(std::string_view Tag = {}, CbObject Metadata = {}, CbObject Config = {}); + [[nodiscard]] std::vector GetQueueIds(); + [[nodiscard]] QueueStatus GetQueueStatus(int QueueId); + [[nodiscard]] CbObject GetQueueMetadata(int QueueId); + [[nodiscard]] CbObject GetQueueConfig(int QueueId); + void CancelQueue(int QueueId); + void DrainQueue(int QueueId); + void DeleteQueue(int QueueId); + void GetQueueCompleted(int QueueId, CbWriter& Cbo); + + // Queue-scoped action submission. Actions submitted via these methods are + // tracked under the given queue in addition to the global LSN-based tracking. 
+ + [[nodiscard]] EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); + [[nodiscard]] EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + + // Completed action tracking + + [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); + [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); + void RetireActionResult(int ActionLsn); + + // Action rescheduling + + struct RescheduleResult + { + bool Success = false; + std::string Error; + int RetryCount = 0; + }; + + [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + + void GetCompleted(CbWriter&); + + // Running action tracking + + struct RunningActionInfo + { + int Lsn; + int QueueId; + IoHash ActionId; + float CpuUsagePercent; // -1.0 if not yet sampled + float CpuSeconds; // 0.0 if not yet sampled + }; + + [[nodiscard]] std::vector GetRunningActions(); + + // Action history tracking (note that this is separate from completed action tracking, and + // will include actions which have been retired and no longer have their results available) + + struct ActionHistoryEntry + { + int Lsn; + int QueueId = 0; + IoHash ActionId; + IoHash WorkerId; + CbObject ActionDescriptor; + std::string ExecutionLocation; + bool Succeeded; + float CpuSeconds = 0.0f; // total CPU time at completion; 0.0 if not sampled + int RetryCount = 0; // number of times this action was rescheduled + // sized to match RunnerAction::State::_Count but we can't use the enum here + // for dependency reasons, so just use a fixed size array and static assert in + // the implementation file + uint64_t Timestamps[8] = {}; + }; + + [[nodiscard]] std::vector GetActionHistory(int Limit = 100); + [[nodiscard]] std::vector GetQueueHistory(int QueueId, int Limit = 100); + + // Stats reporting + + struct ActionCounts + { + int Pending = 0; + int Running = 0; + int Completed = 0; + int 
ActiveQueues = 0; + }; + + [[nodiscard]] ActionCounts GetActionCounts(); + + void EmitStats(CbObjectWriter& Cbo); + + // Recording + + void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); + void StopRecording(); + +private: + void PostUpdate(RunnerAction* Action); + + friend class FunctionRunner; + friend struct RunnerAction; + + struct Impl; + std::unique_ptr m_Impl; +}; + +void computeservice_forcelink(); + +} // namespace zen::compute + +namespace zen { +const char* ToString(compute::ComputeServiceSession::SessionState State); +const char* ToString(compute::ComputeServiceSession::QueueState State); +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/functionservice.h b/src/zencompute/include/zencompute/functionservice.h deleted file mode 100644 index 1deb99fd5..000000000 --- a/src/zencompute/include/zencompute/functionservice.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include -# include -# include -# include - -# include - -namespace zen { -class ChunkResolver; -class CbObjectWriter; -} // namespace zen - -namespace zen::compute { - -class ActionRecorder; -class FunctionServiceSession; -class IActionResultHandler; -class LocalProcessRunner; -class RemoteHttpRunner; -struct RunnerAction; -struct SubmitResult; - -struct WorkerDesc -{ - CbPackage Descriptor; - IoHash WorkerId{IoHash::Zero}; - - inline operator bool() const { return WorkerId != IoHash::Zero; } -}; - -/** - * Lambda style compute function service - * - * The responsibility of this class is to accept function execution requests, and - * schedule them using one or more FunctionRunner instances. 
It will basically always - * accept requests, queueing them if necessary, and then hand them off to runners - * as they become available. - * - * This is typically fronted by an API service that handles communication with clients. - */ -class FunctionServiceSession final -{ -public: - FunctionServiceSession(ChunkResolver& InChunkResolver); - ~FunctionServiceSession(); - - void Shutdown(); - bool IsHealthy(); - - // Worker registration and discovery - - void RegisterWorker(CbPackage Worker); - [[nodiscard]] WorkerDesc GetWorkerDescriptor(const IoHash& WorkerId); - [[nodiscard]] std::vector GetKnownWorkerIds(); - - // Action runners - - void AddLocalRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath); - void AddRemoteRunner(ChunkResolver& InChunkResolver, std::filesystem::path BasePath, std::string_view HostName); - - // Action submission - - struct EnqueueResult - { - int Lsn; - CbObject ResponseMessage; - - inline operator bool() const { return Lsn != 0; } - }; - - [[nodiscard]] EnqueueResult EnqueueResolvedAction(WorkerDesc Worker, CbObject ActionObj, int Priority); - [[nodiscard]] EnqueueResult EnqueueAction(CbObject ActionObject, int Priority); - - // Completed action tracking - - [[nodiscard]] HttpResponseCode GetActionResult(int ActionLsn, CbPackage& OutResultPackage); - [[nodiscard]] HttpResponseCode FindActionResult(const IoHash& ActionId, CbPackage& ResultPackage); - - void GetCompleted(CbWriter&); - - // Action history tracking (note that this is separate from completed action tracking, and - // will include actions which have been retired and no longer have their results available) - - struct ActionHistoryEntry - { - int Lsn; - IoHash ActionId; - IoHash WorkerId; - CbObject ActionDescriptor; - bool Succeeded; - uint64_t Timestamps[5] = {}; - }; - - [[nodiscard]] std::vector GetActionHistory(int Limit = 100); - - // Stats reporting - - void EmitStats(CbObjectWriter& Cbo); - - // Recording - - void StartRecording(ChunkResolver& 
InResolver, const std::filesystem::path& RecordingPath); - void StopRecording(); - -private: - void PostUpdate(RunnerAction* Action); - - friend class FunctionRunner; - friend struct RunnerAction; - - struct Impl; - std::unique_ptr m_Impl; -}; - -void function_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h new file mode 100644 index 000000000..ee1cd2614 --- /dev/null +++ b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "zencompute/computeservice.h" + +# include + +# include +# include + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** + * HTTP interface for compute service + */ +class HttpComputeService : public HttpService, public IHttpStatsProvider +{ +public: + HttpComputeService(CidStore& InCidStore, + IHttpStatsService& StatsService, + const std::filesystem::path& BaseDir, + int32_t MaxConcurrentActions = 0); + ~HttpComputeService(); + + void Shutdown(); + + [[nodiscard]] ComputeServiceSession::ActionCounts GetActionCounts(); + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + + // IHttpStatsProvider + + void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +void httpcomputeservice_forcelink(); + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httpfunctionservice.h b/src/zencompute/include/zencompute/httpfunctionservice.h deleted file mode 100644 index 6e2344ae6..000000000 --- a/src/zencompute/include/zencompute/httpfunctionservice.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include - -#if !defined(ZEN_WITH_COMPUTE_SERVICES) -# define ZEN_WITH_COMPUTE_SERVICES 1 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "zencompute/functionservice.h" - -# include -# include -# include -# include -# include -# include - -# include -# include -# include - -namespace zen { -class CidStore; -} - -namespace zen::compute { - -class HttpFunctionService; -class FunctionService; - -/** - * HTTP interface for compute function service - */ -class HttpFunctionService : public HttpService, public IHttpStatsProvider -{ -public: - HttpFunctionService(CidStore& InCidStore, IHttpStatsService& StatsService, const std::filesystem::path& BaseDir); - ~HttpFunctionService(); - - void Shutdown(); - - virtual const char* BaseUri() const override; - virtual void HandleRequest(HttpServerRequest& Request) override; - - // IHttpStatsProvider - - virtual void HandleStatsRequest(HttpServerRequest& Request) override; - -protected: - CidStore& m_CidStore; - IHttpStatsService& m_StatsService; - LoggerRef Log() { return m_Log; } - -private: - LoggerRef m_Log; - std::filesystem ::path m_BaseDir; - HttpRequestRouter m_Router; - FunctionServiceSession m_FunctionService; - - // Metrics - - metrics::OperationTiming m_HttpRequests; -}; - -void httpfunction_forcelink(); - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/include/zencompute/httporchestrator.h b/src/zencompute/include/zencompute/httporchestrator.h index 168c6d7fe..da5c5dfc3 100644 --- a/src/zencompute/include/zencompute/httporchestrator.h +++ b/src/zencompute/include/zencompute/httporchestrator.h @@ -2,43 +2,100 @@ #pragma once +#include + #include #include -#include #include +#include +#include +#include +#include +#include +#include #include +#include + +#define ZEN_WITH_WEBSOCKETS 1 namespace zen::compute { +class OrchestratorService; + +// Experimental helper, to see if we can get rid of some error-prone +// boilerplate when declaring loggers as 
class members. + +class LoggerHelper +{ +public: + LoggerHelper(std::string_view Logger) : m_Log(logging::Get(Logger)) {} + + LoggerRef operator()() { return m_Log; } + +private: + LoggerRef m_Log; +}; + /** - * Mock orchestrator service, for testing dynamic provisioning + * Orchestrator HTTP service with WebSocket push support + * + * Normal HTTP requests are routed through the HttpRequestRouter as before. + * WebSocket clients connecting to /orch/ws receive periodic state broadcasts + * from a dedicated push thread, eliminating the need for polling. */ class HttpOrchestratorService : public HttpService +#if ZEN_WITH_WEBSOCKETS +, + public IWebSocketHandler +#endif { public: - HttpOrchestratorService(); + explicit HttpOrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); ~HttpOrchestratorService(); HttpOrchestratorService(const HttpOrchestratorService&) = delete; HttpOrchestratorService& operator=(const HttpOrchestratorService&) = delete; + /** + * Gracefully shut down the WebSocket push thread and release connections. + * Must be called while the ASIO io_context is still alive. The destructor + * also calls this, so it is safe (but not ideal) to omit the explicit call. 
+ */ + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; + // IWebSocketHandler +#if ZEN_WITH_WEBSOCKETS + void OnWebSocketOpen(Ref Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; +#endif + private: - HttpRequestRouter m_Router; - LoggerRef m_Log; + HttpRequestRouter m_Router; + LoggerHelper Log{"orch"}; + std::unique_ptr m_Service; + std::string m_Hostname; + + // WebSocket push - struct KnownWorker - { - std::string_view BaseUri; - Stopwatch LastSeen; - }; +#if ZEN_WITH_WEBSOCKETS + RwLock m_WsConnectionsLock; + std::vector> m_WsConnections; + std::thread m_PushThread; + std::atomic m_PushEnabled{false}; + Event m_PushEvent; + void PushThreadFunction(); - RwLock m_KnownWorkersLock; - std::unordered_map m_KnownWorkers; + // Worker WebSocket connections (worker→orchestrator persistent links) + RwLock m_WorkerWsLock; + std::unordered_map m_WorkerWsMap; // connection ptr → worker ID + std::string HandleWorkerWebSocketMessage(const WebSocketMessage& Msg); +#endif }; } // namespace zen::compute diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h new file mode 100644 index 000000000..521722e63 --- /dev/null +++ b/src/zencompute/include/zencompute/mockimds.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. + * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). 
+ * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. */ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). + std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. 
+ std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. + */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/include/zencompute/orchestratorservice.h b/src/zencompute/include/zencompute/orchestratorservice.h new file mode 100644 index 000000000..071e902b3 --- /dev/null +++ b/src/zencompute/include/zencompute/orchestratorservice.h @@ -0,0 +1,177 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include + +# include +# include +# include +# include +# include +# include +# include +# include + +namespace zen::compute { + +class WorkerTimelineStore; + +class OrchestratorService +{ +public: + explicit OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket = false); + ~OrchestratorService(); + + OrchestratorService(const OrchestratorService&) = delete; + OrchestratorService& operator=(const OrchestratorService&) = delete; + + struct WorkerAnnouncement + { + std::string_view Id; + std::string_view Uri; + std::string_view Hostname; + std::string_view Platform; // e.g. 
"windows", "wine", "linux", "macos" + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string_view Provisioner; // e.g. "horde", "nomad", or empty + }; + + struct ProvisioningEvent + { + enum class Type + { + Joined, + Left, + Returned + }; + Type EventType; + DateTime Timestamp; + std::string WorkerId; + std::string Hostname; + }; + + struct ClientAnnouncement + { + Oid SessionId; + std::string_view Hostname; + std::string_view Address; + CbObject Metadata; + }; + + struct ClientEvent + { + enum class Type + { + Connected, + Disconnected, + Updated + }; + Type EventType; + DateTime Timestamp; + std::string ClientId; + std::string Hostname; + }; + + CbObject GetWorkerList(); + void AnnounceWorker(const WorkerAnnouncement& Announcement); + + bool IsWorkerWebSocketEnabled() const; + void SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected); + + CbObject GetProvisioningHistory(int Limit = 100); + + CbObject GetWorkerTimeline(std::string_view WorkerId, std::optional From, std::optional To, int Limit); + + CbObject GetAllTimelines(DateTime From, DateTime To); + + std::string AnnounceClient(const ClientAnnouncement& Announcement); + bool UpdateClient(std::string_view ClientId, CbObject Metadata = {}); + bool CompleteClient(std::string_view ClientId); + CbObject GetClientList(); + CbObject GetClientHistory(int Limit = 100); + +private: + enum class ReachableState + { + Unknown, + Reachable, + Unreachable, + }; + + struct KnownWorker + { + std::string BaseUri; + Stopwatch LastSeen; + std::string Hostname; + std::string Platform; + int Cpus = 0; + float CpuUsagePercent = 0.0f; + uint64_t MemoryTotalBytes = 0; + uint64_t MemoryUsedBytes = 0; + uint64_t BytesReceived = 0; + uint64_t BytesSent = 0; + int ActionsPending = 0; + int 
ActionsRunning = 0; + int ActionsCompleted = 0; + int ActiveQueues = 0; + std::string Provisioner; + ReachableState Reachable = ReachableState::Unknown; + bool WsConnected = false; + Stopwatch LastProbed; + }; + + RwLock m_KnownWorkersLock; + std::unordered_map m_KnownWorkers; + std::unique_ptr m_TimelineStore; + + RwLock m_ProvisioningLogLock; + std::deque m_ProvisioningLog; + static constexpr size_t kMaxProvisioningEvents = 1000; + + void RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname); + + struct KnownClient + { + Oid SessionId; + std::string Hostname; + std::string Address; + Stopwatch LastSeen; + CbObject Metadata; + }; + + RwLock m_KnownClientsLock; + std::unordered_map m_KnownClients; + + RwLock m_ClientLogLock; + std::deque m_ClientLog; + static constexpr size_t kMaxClientEvents = 1000; + + void RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname); + + bool m_EnableWorkerWebSocket = false; + + std::thread m_ProbeThread; + std::atomic m_ProbeThreadEnabled{true}; + Event m_ProbeThreadEvent; + void ProbeThreadFunction(); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/include/zencompute/recordingreader.h b/src/zencompute/include/zencompute/recordingreader.h index bf1aff125..3f233fae0 100644 --- a/src/zencompute/include/zencompute/recordingreader.h +++ b/src/zencompute/include/zencompute/recordingreader.h @@ -2,7 +2,9 @@ #pragma once -#include +#include + +#include #include #include #include diff --git a/src/zencompute/include/zencompute/zencompute.h b/src/zencompute/include/zencompute/zencompute.h index 6dc32eeea..00be4d4a0 100644 --- a/src/zencompute/include/zencompute/zencompute.h +++ b/src/zencompute/include/zencompute/zencompute.h @@ -4,6 +4,10 @@ #include +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + namespace zen { void zencompute_forcelinktests(); diff --git 
a/src/zencompute/localrunner.cpp b/src/zencompute/localrunner.cpp deleted file mode 100644 index 9a27f3f3d..000000000 --- a/src/zencompute/localrunner.cpp +++ /dev/null @@ -1,722 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "localrunner.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include - -namespace zen::compute { - -using namespace std::literals; - -LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir) -: FunctionRunner(BaseDir) -, m_Log(logging::Get("local_exec")) -, m_ChunkResolver(Resolver) -, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) -, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) -{ - SystemMetrics Sm = GetSystemMetricsForReporting(); - - m_MaxRunningActions = Sm.LogicalProcessorCount * 2; - - ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); - - bool DidCleanup = false; - - if (std::filesystem::is_directory(m_ActionsPath)) - { - ZEN_INFO("Cleaning '{}'", m_ActionsPath); - - std::error_code Ec; - CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); - - if (Ec) - { - ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); - } - - DidCleanup = true; - } - - if (std::filesystem::is_directory(m_SandboxPath)) - { - ZEN_INFO("Cleaning '{}'", m_SandboxPath); - std::error_code Ec; - CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); - - if (Ec) - { - ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); - } - - DidCleanup = true; - } - - // We clean out all workers on startup since we can't know they are good. 
They could be bad - // due to tampering, malware (which I also mean to include AV and antimalware software) or - // other processes we have no control over - if (std::filesystem::is_directory(m_WorkerPath)) - { - ZEN_INFO("Cleaning '{}'", m_WorkerPath); - std::error_code Ec; - CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); - - if (Ec) - { - ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); - } - - DidCleanup = true; - } - - if (DidCleanup) - { - ZEN_INFO("Cleanup complete"); - } - - m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; - -# if ZEN_PLATFORM_WINDOWS - // Suppress any error dialogs caused by missing dependencies - UINT OldMode = ::SetErrorMode(0); - ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); -# endif - - m_AcceptNewActions = true; -} - -LocalProcessRunner::~LocalProcessRunner() -{ - try - { - Shutdown(); - } - catch (std::exception& Ex) - { - ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); - } -} - -void -LocalProcessRunner::Shutdown() -{ - m_AcceptNewActions = false; - - m_MonitorThreadEnabled = false; - m_MonitorThreadEvent.Set(); - if (m_MonitorThread.joinable()) - { - m_MonitorThread.join(); - } - - CancelRunningActions(); -} - -std::filesystem::path -LocalProcessRunner::CreateNewSandbox() -{ - std::string UniqueId = std::to_string(++m_SandboxCounter); - std::filesystem::path Path = m_SandboxPath / UniqueId; - zen::CreateDirectories(Path); - - return Path; -} - -void -LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) -{ - if (m_DumpActions) - { - CbObject WorkerDescriptor = WorkerPackage.GetObject(); - const IoHash& WorkerId = WorkerPackage.GetObjectHash(); - - std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); - std::filesystem::path Path = m_ActionsPath / UniqueId; - - zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); - - ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, 
CompressedBuffer& ChunkBuffer) { - std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); - zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); - }); - - ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); - } -} - -size_t -LocalProcessRunner::QueryCapacity() -{ - // Estimate how much more work we're ready to accept - - RwLock::SharedLockScope _{m_RunningLock}; - - if (!m_AcceptNewActions) - { - return 0; - } - - size_t RunningCount = m_RunningMap.size(); - - if (RunningCount >= size_t(m_MaxRunningActions)) - { - return 0; - } - - return m_MaxRunningActions - RunningCount; -} - -std::vector -LocalProcessRunner::SubmitActions(const std::vector>& Actions) -{ - std::vector Results; - - for (const Ref& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -SubmitResult -LocalProcessRunner::SubmitAction(Ref Action) -{ - // Verify whether we can accept more work - - { - RwLock::SharedLockScope _{m_RunningLock}; - - if (!m_AcceptNewActions) - { - return SubmitResult{.IsAccepted = false}; - } - - if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) - { - return SubmitResult{.IsAccepted = false}; - } - } - - using namespace std::literals; - - // Each enqueued action is assigned an integer index (logical sequence number), - // which we use as a key for tracking data structures and as an opaque id which - // may be used by clients to reference the scheduled action - - const int32_t ActionLsn = Action->ActionLsn; - const CbObject& ActionObj = Action->ActionObj; - const IoHash ActionId = ActionObj.GetHash(); - - MaybeDumpAction(ActionLsn, ActionObj); - - std::filesystem::path SandboxPath = CreateNewSandbox(); - - CbPackage WorkerPackage = Action->Worker.Descriptor; - - std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); - - // Write out action - - zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); - - // Manifest inputs in sandbox - - 
ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash Cid = Field.AsHash(); - std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; - IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); - - if (!DataBuffer) - { - throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); - } - - zen::WriteFile(FilePath, DataBuffer); - }); - -# if ZEN_PLATFORM_WINDOWS - // Set up environment variables - - StringBuilder<1024> EnvironmentBlock; - - CbObject WorkerDescription = WorkerPackage.GetObject(); - - for (auto& It : WorkerDescription["environment"sv]) - { - EnvironmentBlock.Append(It.AsString()); - EnvironmentBlock.Append('\0'); - } - EnvironmentBlock.Append('\0'); - EnvironmentBlock.Append('\0'); - - // Execute process - this spawns the child process immediately without waiting - // for completion - - std::string_view ExecPath = WorkerDescription["path"sv].AsString(); - std::filesystem::path ExePath = WorkerPath / std::filesystem::path(ExecPath).make_preferred(); - - ExtendableWideStringBuilder<512> CommandLine; - CommandLine.Append(L'"'); - CommandLine.Append(ExePath.c_str()); - CommandLine.Append(L'"'); - CommandLine.Append(L" -Build=build.action"); - - LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; - LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; - BOOL bInheritHandles = FALSE; - DWORD dwCreationFlags = 0; - - STARTUPINFO StartupInfo{}; - StartupInfo.cb = sizeof StartupInfo; - - PROCESS_INFORMATION ProcessInformation{}; - - ZEN_DEBUG("Executing: {}", WideToUtf8(CommandLine.c_str())); - - CommandLine.EnsureNulTerminated(); - - BOOL Success = CreateProcessW(nullptr, - CommandLine.Data(), - lpProcessAttributes, - lpThreadAttributes, - bInheritHandles, - dwCreationFlags, - (LPVOID)EnvironmentBlock.Data(), // Environment block - SandboxPath.c_str(), // Current directory - &StartupInfo, - /* out */ &ProcessInformation); - - if (!Success) - { - // TODO: this is probably not the best way to report failure. 
The return - // object should include a failure state and context - - zen::ThrowLastError("Unable to launch process" /* TODO: Add context */); - } - - CloseHandle(ProcessInformation.hThread); - - Ref NewAction{new RunningAction()}; - NewAction->Action = Action; - NewAction->ProcessHandle = ProcessInformation.hProcess; - NewAction->SandboxPath = std::move(SandboxPath); - - { - RwLock::ExclusiveLockScope _(m_RunningLock); - - m_RunningMap[ActionLsn] = std::move(NewAction); - } - - Action->SetActionState(RunnerAction::State::Running); -# else - ZEN_UNUSED(ActionId); - - ZEN_NOT_IMPLEMENTED(); - - int ExitCode = 0; -# endif - - return SubmitResult{.IsAccepted = true}; -} - -size_t -LocalProcessRunner::GetSubmittedActionCount() -{ - RwLock::SharedLockScope _(m_RunningLock); - return m_RunningMap.size(); -} - -std::filesystem::path -LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) -{ - RwLock::SharedLockScope _(m_WorkerLock); - - std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); - - if (!std::filesystem::exists(WorkerDir)) - { - _.ReleaseNow(); - - RwLock::ExclusiveLockScope $(m_WorkerLock); - - if (!std::filesystem::exists(WorkerDir)) - { - ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); - } - } - - return WorkerDir; -} - -void -LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, - CbObjectView FileEntry, - const std::filesystem::path& SandboxRootPath, - std::function& ChunkReferenceCallback) -{ - std::string_view Name = FileEntry["name"sv].AsString(); - const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); - const uint64_t Size = FileEntry["size"sv].AsUInt64(); - - CompressedBuffer Compressed; - - if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) - { - Compressed = Attachment->AsCompressedBinary(); - } - else - { - IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); - - if (!DataBuffer) - { - throw 
std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); - } - - uint64_t DataRawSize = 0; - IoHash DataRawHash; - Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); - - if (DataRawSize != Size) - { - throw std::runtime_error( - fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size)); - } - } - - ChunkReferenceCallback(ChunkHash, Compressed); - - std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; - - SharedBuffer Decompressed = Compressed.Decompress(); - zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); -} - -void -LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, - const std::filesystem::path& SandboxPath, - std::function&& ChunkReferenceCallback) -{ - CbObject WorkerDescription = WorkerPackage.GetObject(); - - // Manifest worker in Sandbox - - for (auto& It : WorkerDescription["executables"sv]) - { - DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); - } - - for (auto& It : WorkerDescription["dirs"sv]) - { - std::string_view Name = It.AsString(); - std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; - zen::CreateDirectories(DirPath); - } - - for (auto& It : WorkerDescription["files"sv]) - { - DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); - } - - WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); -} - -CbPackage -LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) -{ - std::filesystem::path OutputFile = SandboxPath / "build.output"; - FileContents OutputData = zen::ReadFile(OutputFile); - - if (OutputData.ErrorCode) - { - throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); - } - - CbPackage OutputPackage; - CbObject Output = 
zen::LoadCompactBinaryObject(OutputData.Flatten()); - - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalRawAttachmentBytes = 0; - - Output.IterateAttachments([&](CbFieldView Field) { - IoHash Hash = Field.AsHash(); - std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; - FileContents ChunkData = zen::ReadFile(OutputPath); - - if (ChunkData.ErrorCode) - { - throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); - } - - uint64_t ChunkDataRawSize = 0; - IoHash ChunkDataHash; - CompressedBuffer AttachmentBuffer = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); - - if (!AttachmentBuffer) - { - throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); - } - - TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); - TotalRawAttachmentBytes += ChunkDataRawSize; - - CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); - OutputPackage.AddAttachment(Attachment); - }); - - OutputPackage.SetObject(Output); - - ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", - OutputPackage.GetAttachments().size(), - NiceBytes(TotalAttachmentBytes), - NiceBytes(TotalRawAttachmentBytes)); - - return OutputPackage; -} - -void -LocalProcessRunner::MonitorThreadFunction() -{ - SetCurrentThreadName("LocalProcessRunner_Monitor"); - - auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); - - do - { - // On Windows it's possible to wait on process handles, so we wait for either a process to exit - // or for the monitor event to be signaled (which indicates we should check for cancellation - // or shutdown). This could be further improved by using a completion port and registering process - // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing - // with an enormous number of concurrent processes. 
- // - // On other platforms we just wait on the monitor event and poll for process exits at intervals. -# if ZEN_PLATFORM_WINDOWS - auto WaitOnce = [&] { - HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; - - uint32_t NumHandles = 0; - - WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); - - m_RunningLock.WithSharedLock([&] { - for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) - { - Ref Action = It->second; - - WaitHandles[NumHandles++] = Action->ProcessHandle; - } - }); - - DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); - - // return true if a handle was signaled - return (WaitResult <= NumHandles); - }; -# else - auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; -# endif - - while (!WaitOnce()) - { - if (m_MonitorThreadEnabled == false) - { - return; - } - - SweepRunningActions(); - } - - // Signal received - - SweepRunningActions(); - } while (m_MonitorThreadEnabled); -} - -void -LocalProcessRunner::CancelRunningActions() -{ - Stopwatch Timer; - std::unordered_map> RunningMap; - - m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); - - if (RunningMap.empty()) - { - return; - } - - ZEN_INFO("cancelling all running actions"); - - // For expedience we initiate the process termination for all known - // processes before attempting to wait for them to exit. 
- - std::vector TerminatedLsnList; - - for (const auto& Kv : RunningMap) - { - Ref Action = Kv.second; - - // Terminate running process - -# if ZEN_PLATFORM_WINDOWS - BOOL Success = TerminateProcess(Action->ProcessHandle, 222); - - if (Success) - { - TerminatedLsnList.push_back(Kv.first); - } - else - { - DWORD LastError = GetLastError(); - - if (LastError != ERROR_ACCESS_DENIED) - { - ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Action->Action->ActionLsn, GetSystemErrorAsString(LastError)); - } - } -# else - ZEN_NOT_IMPLEMENTED("need to implement process termination"); -# endif - } - - // We only post results for processes we have terminated, in order - // to avoid multiple results getting posted for the same action - - for (int Lsn : TerminatedLsnList) - { - if (auto It = RunningMap.find(Lsn); It != RunningMap.end()) - { - Ref Running = It->second; - -# if ZEN_PLATFORM_WINDOWS - if (Running->ProcessHandle != INVALID_HANDLE_VALUE) - { - DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); - - if (WaitResult != WAIT_OBJECT_0) - { - ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); - } - else - { - ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); - } - } -# endif - - // Clean up and post error result - - DeleteDirectories(Running->SandboxPath); - Running->Action->SetActionState(RunnerAction::State::Failed); - } - } - - ZEN_INFO("DONE - cancelled {} running processes (took {})", TerminatedLsnList.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); -} - -void -LocalProcessRunner::SweepRunningActions() -{ - std::vector> CompletedActions; - - m_RunningLock.WithExclusiveLock([&] { - // TODO: It would be good to not hold the exclusive lock while making - // system calls and other expensive operations. 
- - for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) - { - Ref Action = It->second; - -# if ZEN_PLATFORM_WINDOWS - DWORD ExitCode = 0; - BOOL IsSuccess = GetExitCodeProcess(Action->ProcessHandle, &ExitCode); - - if (IsSuccess && ExitCode != STILL_ACTIVE) - { - CloseHandle(Action->ProcessHandle); - Action->ProcessHandle = INVALID_HANDLE_VALUE; - - CompletedActions.push_back(std::move(Action)); - It = m_RunningMap.erase(It); - } - else - { - ++It; - } -# else - // TODO: implement properly for Mac/Linux - - ZEN_UNUSED(Action); -# endif - } - }); - - // Notify outer. Note that this has to be done without holding any local locks - // otherwise we may end up with deadlocks. - - for (Ref Running : CompletedActions) - { - const int ActionLsn = Running->Action->ActionLsn; - - if (Running->ExitCode == 0) - { - try - { - // Gather outputs - - CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); - - Running->Action->SetResult(std::move(OutputPackage)); - Running->Action->SetActionState(RunnerAction::State::Completed); - - // We can delete the files at this point - if (!DeleteDirectories(Running->SandboxPath)) - { - ZEN_WARN("Unable to delete directory '{}', this will continue to exist until service restart", Running->SandboxPath); - } - - // Success -- continue with next iteration of the loop - continue; - } - catch (std::exception& Ex) - { - ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); - } - } - - // Failed - for now this is indicated with an empty package in - // the results map. We can clean out the sandbox directory immediately. 
- - std::error_code Ec; - DeleteDirectories(Running->SandboxPath, Ec); - - if (Ec) - { - ZEN_WARN("Unable to delete sandbox directory '{}': {}", Running->SandboxPath, Ec.message()); - } - - Running->Action->SetActionState(RunnerAction::State::Failed); - } -} - -} // namespace zen::compute - -#endif diff --git a/src/zencompute/localrunner.h b/src/zencompute/localrunner.h deleted file mode 100644 index 35f464805..000000000 --- a/src/zencompute/localrunner.h +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "zencompute/functionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" - -# include -# include -# include -# include -# include - -# include -# include -# include - -namespace zen { -class CbPackage; -} - -namespace zen::compute { - -/** Direct process spawner - - This runner simply sets up a directory structure for each job and - creates a process to perform the computation in it. It is not very - efficient and is intended mostly for testing. 
- - */ - -class LocalProcessRunner : public FunctionRunner -{ - LocalProcessRunner(LocalProcessRunner&&) = delete; - LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; - -public: - LocalProcessRunner(ChunkResolver& Resolver, const std::filesystem::path& BaseDir); - ~LocalProcessRunner(); - - virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; - [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; - [[nodiscard]] virtual bool IsHealthy() override { return true; } - [[nodiscard]] virtual size_t GetSubmittedActionCount() override; - [[nodiscard]] virtual size_t QueryCapacity() override; - [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; - -protected: - LoggerRef Log() { return m_Log; } - - LoggerRef m_Log; - - struct RunningAction : public RefCounted - { - Ref Action; - void* ProcessHandle = nullptr; - int ExitCode = 0; - std::filesystem::path SandboxPath; - }; - - std::atomic_bool m_AcceptNewActions; - ChunkResolver& m_ChunkResolver; - RwLock m_WorkerLock; - std::filesystem::path m_WorkerPath; - std::atomic m_SandboxCounter = 0; - std::filesystem::path m_SandboxPath; - int32_t m_MaxRunningActions = 64; // arbitrary limit for testing - - // if used in conjuction with m_ResultsLock, this lock must be taken *after* - // m_ResultsLock to avoid deadlocks - RwLock m_RunningLock; - std::unordered_map> m_RunningMap; - - std::thread m_MonitorThread; - std::atomic m_MonitorThreadEnabled{true}; - Event m_MonitorThreadEvent; - void MonitorThreadFunction(); - void SweepRunningActions(); - void CancelRunningActions(); - - std::filesystem::path CreateNewSandbox(); - void ManifestWorker(const CbPackage& WorkerPackage, - const std::filesystem::path& SandboxPath, - std::function&& ChunkReferenceCallback); - std::filesystem::path ManifestWorker(const WorkerDesc& Worker); - CbPackage GatherActionOutputs(std::filesystem::path SandboxPath); - - void 
DecompressAttachmentToFile(const CbPackage& FromPackage, - CbObjectView FileEntry, - const std::filesystem::path& SandboxRootPath, - std::function& ChunkReferenceCallback); -}; - -} // namespace zen::compute - -#endif diff --git a/src/zencompute/orchestratorservice.cpp b/src/zencompute/orchestratorservice.cpp new file mode 100644 index 000000000..9ea695305 --- /dev/null +++ b/src/zencompute/orchestratorservice.cpp @@ -0,0 +1,710 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include + +# include "timeline/workertimeline.h" + +namespace zen::compute { + +OrchestratorService::OrchestratorService(std::filesystem::path DataDir, bool EnableWorkerWebSocket) +: m_TimelineStore(std::make_unique(DataDir / "timelines")) +, m_EnableWorkerWebSocket(EnableWorkerWebSocket) +{ + m_ProbeThread = std::thread{&OrchestratorService::ProbeThreadFunction, this}; +} + +OrchestratorService::~OrchestratorService() +{ + m_ProbeThreadEnabled = false; + m_ProbeThreadEvent.Set(); + if (m_ProbeThread.joinable()) + { + m_ProbeThread.join(); + } +} + +CbObject +OrchestratorService::GetWorkerList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerList"); + CbObjectWriter Cbo; + Cbo.BeginArray("workers"); + + m_KnownWorkersLock.WithSharedLock([&] { + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Cbo.BeginObject(); + Cbo << "id" << WorkerId; + Cbo << "uri" << Worker.BaseUri; + Cbo << "hostname" << Worker.Hostname; + if (!Worker.Platform.empty()) + { + Cbo << "platform" << std::string_view(Worker.Platform); + } + Cbo << "cpus" << Worker.Cpus; + Cbo << "cpu_usage" << Worker.CpuUsagePercent; + Cbo << "memory_total" << Worker.MemoryTotalBytes; + Cbo << "memory_used" << Worker.MemoryUsedBytes; + Cbo << "bytes_received" << Worker.BytesReceived; + Cbo << "bytes_sent" << Worker.BytesSent; + Cbo << "actions_pending" << Worker.ActionsPending; + Cbo << "actions_running" << Worker.ActionsRunning; + Cbo << 
"actions_completed" << Worker.ActionsCompleted; + Cbo << "active_queues" << Worker.ActiveQueues; + if (!Worker.Provisioner.empty()) + { + Cbo << "provisioner" << std::string_view(Worker.Provisioner); + } + if (Worker.Reachable != ReachableState::Unknown) + { + Cbo << "reachable" << (Worker.Reachable == ReachableState::Reachable); + } + if (Worker.WsConnected) + { + Cbo << "ws_connected" << true; + } + Cbo << "dt" << Worker.LastSeen.GetElapsedTimeMs(); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +void +OrchestratorService::AnnounceWorker(const WorkerAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceWorker"); + + bool IsNew = false; + std::string EvictedId; + std::string EvictedHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + IsNew = (m_KnownWorkers.find(std::string(Ann.Id)) == m_KnownWorkers.end()); + + // If a different worker ID already maps to the same URI, the old entry + // is stale (e.g. a previous Horde lease on the same machine). Remove it + // so the dashboard doesn't show duplicates. 
+ if (IsNew) + { + for (auto It = m_KnownWorkers.begin(); It != m_KnownWorkers.end(); ++It) + { + if (It->second.BaseUri == Ann.Uri && It->first != Ann.Id) + { + EvictedId = It->first; + EvictedHostname = It->second.Hostname; + m_KnownWorkers.erase(It); + break; + } + } + } + + auto& Worker = m_KnownWorkers[std::string(Ann.Id)]; + Worker.BaseUri = Ann.Uri; + Worker.Hostname = Ann.Hostname; + if (!Ann.Platform.empty()) + { + Worker.Platform = Ann.Platform; + } + Worker.Cpus = Ann.Cpus; + Worker.CpuUsagePercent = Ann.CpuUsagePercent; + Worker.MemoryTotalBytes = Ann.MemoryTotalBytes; + Worker.MemoryUsedBytes = Ann.MemoryUsedBytes; + Worker.BytesReceived = Ann.BytesReceived; + Worker.BytesSent = Ann.BytesSent; + Worker.ActionsPending = Ann.ActionsPending; + Worker.ActionsRunning = Ann.ActionsRunning; + Worker.ActionsCompleted = Ann.ActionsCompleted; + Worker.ActiveQueues = Ann.ActiveQueues; + if (!Ann.Provisioner.empty()) + { + Worker.Provisioner = Ann.Provisioner; + } + Worker.LastSeen.Reset(); + }); + + if (!EvictedId.empty()) + { + ZEN_INFO("worker {} superseded by {} (same endpoint)", EvictedId, Ann.Id); + RecordProvisioningEvent(ProvisioningEvent::Type::Left, EvictedId, EvictedHostname); + } + + if (IsNew) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Joined, Ann.Id, Ann.Hostname); + } +} + +bool +OrchestratorService::IsWorkerWebSocketEnabled() const +{ + return m_EnableWorkerWebSocket; +} + +void +OrchestratorService::SetWorkerWebSocketConnected(std::string_view WorkerId, bool Connected) +{ + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(std::string(WorkerId)); + if (It == m_KnownWorkers.end()) + { + return; + } + + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.WsConnected = Connected; + It->second.Reachable = Connected ? 
ReachableState::Reachable : ReachableState::Unreachable; + + if (Connected) + { + ZEN_INFO("worker {} WebSocket connected — marking reachable", WorkerId); + } + else + { + ZEN_WARN("worker {} WebSocket disconnected — marking unreachable", WorkerId); + } + }); + + // Record provisioning events for state transitions outside the lock + if (Connected && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, WorkerId, WorkerHostname); + } + else if (!Connected && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, WorkerId, WorkerHostname); + } +} + +CbObject +OrchestratorService::GetWorkerTimeline(std::string_view WorkerId, std::optional From, std::optional To, int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetWorkerTimeline"); + + Ref Timeline = m_TimelineStore->Find(WorkerId); + if (!Timeline) + { + return {}; + } + + std::vector Events; + + if (From || To) + { + DateTime StartTime = From.value_or(DateTime(0)); + DateTime EndTime = To.value_or(DateTime::Now()); + Events = Timeline->QueryTimeline(StartTime, EndTime); + } + else if (Limit > 0) + { + Events = Timeline->QueryRecent(Limit); + } + else + { + Events = Timeline->QueryRecent(); + } + + WorkerTimeline::TimeRange Range = Timeline->GetTimeRange(); + + CbObjectWriter Cbo; + Cbo << "worker_id" << WorkerId; + Cbo << "event_count" << static_cast(Timeline->GetEventCount()); + + if (Range) + { + Cbo.AddDateTime("time_first", Range.First); + Cbo.AddDateTime("time_last", Range.Last); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : Events) + { + Cbo.BeginObject(); + Cbo << "type" << WorkerTimeline::ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == WorkerTimeline::EventType::ActionStateChanged) + { + Cbo << "prev_state" << RunnerAction::ToString(Evt.PreviousState); + Cbo 
<< "state" << RunnerAction::ToString(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetAllTimelines(DateTime From, DateTime To) +{ + ZEN_TRACE_CPU("OrchestratorService::GetAllTimelines"); + + DateTime StartTime = From; + DateTime EndTime = To; + + auto AllInfo = m_TimelineStore->GetAllWorkerInfo(); + + CbObjectWriter Cbo; + Cbo.AddDateTime("from", StartTime); + Cbo.AddDateTime("to", EndTime); + + Cbo.BeginArray("workers"); + for (const auto& Info : AllInfo) + { + if (!Info.Range || Info.Range.Last < StartTime || Info.Range.First > EndTime) + { + continue; + } + + Cbo.BeginObject(); + Cbo << "worker_id" << Info.WorkerId; + Cbo.AddDateTime("time_first", Info.Range.First); + Cbo.AddDateTime("time_last", Info.Range.Last); + Cbo.EndObject(); + } + Cbo.EndArray(); + + return Cbo.Save(); +} + +void +OrchestratorService::RecordProvisioningEvent(ProvisioningEvent::Type Type, std::string_view WorkerId, std::string_view Hostname) +{ + ProvisioningEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .WorkerId = std::string(WorkerId), + .Hostname = std::string(Hostname), + }; + + m_ProvisioningLogLock.WithExclusiveLock([&] { + m_ProvisioningLog.push_back(std::move(Evt)); + while (m_ProvisioningLog.size() > kMaxProvisioningEvents) + { + m_ProvisioningLog.pop_front(); + } + }); +} + +CbObject +OrchestratorService::GetProvisioningHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetProvisioningHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("events"); + + m_ProvisioningLogLock.WithSharedLock([&] { + // Return last N events, newest first + int Count = 0; + for (auto It = m_ProvisioningLog.rbegin(); It != m_ProvisioningLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case 
ProvisioningEvent::Type::Joined: + Cbo << "type" + << "joined"; + break; + case ProvisioningEvent::Type::Left: + Cbo << "type" + << "left"; + break; + case ProvisioningEvent::Type::Returned: + Cbo << "type" + << "returned"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "worker_id" << std::string_view(Evt.WorkerId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +std::string +OrchestratorService::AnnounceClient(const ClientAnnouncement& Ann) +{ + ZEN_TRACE_CPU("OrchestratorService::AnnounceClient"); + + std::string ClientId = fmt::format("client-{}", Oid::NewOid().ToString()); + + bool IsNew = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(ClientId); + IsNew = (It == m_KnownClients.end()); + + auto& Client = m_KnownClients[ClientId]; + Client.SessionId = Ann.SessionId; + Client.Hostname = Ann.Hostname; + if (!Ann.Address.empty()) + { + Client.Address = Ann.Address; + } + if (Ann.Metadata) + { + Client.Metadata = Ann.Metadata; + } + Client.LastSeen.Reset(); + }); + + if (IsNew) + { + RecordClientEvent(ClientEvent::Type::Connected, ClientId, Ann.Hostname); + } + else + { + RecordClientEvent(ClientEvent::Type::Updated, ClientId, Ann.Hostname); + } + + return ClientId; +} + +bool +OrchestratorService::UpdateClient(std::string_view ClientId, CbObject Metadata) +{ + ZEN_TRACE_CPU("OrchestratorService::UpdateClient"); + + bool Found = false; + + m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + if (Metadata) + { + It->second.Metadata = std::move(Metadata); + } + It->second.LastSeen.Reset(); + } + }); + + return Found; +} + +bool +OrchestratorService::CompleteClient(std::string_view ClientId) +{ + ZEN_TRACE_CPU("OrchestratorService::CompleteClient"); + + std::string Hostname; + bool Found = false; + + 
m_KnownClientsLock.WithExclusiveLock([&] { + auto It = m_KnownClients.find(std::string(ClientId)); + if (It != m_KnownClients.end()) + { + Found = true; + Hostname = It->second.Hostname; + m_KnownClients.erase(It); + } + }); + + if (Found) + { + RecordClientEvent(ClientEvent::Type::Disconnected, ClientId, Hostname); + } + + return Found; +} + +CbObject +OrchestratorService::GetClientList() +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientList"); + CbObjectWriter Cbo; + Cbo.BeginArray("clients"); + + m_KnownClientsLock.WithSharedLock([&] { + for (const auto& [ClientId, Client] : m_KnownClients) + { + Cbo.BeginObject(); + Cbo << "id" << ClientId; + if (Client.SessionId) + { + Cbo << "session_id" << Client.SessionId; + } + Cbo << "hostname" << std::string_view(Client.Hostname); + if (!Client.Address.empty()) + { + Cbo << "address" << std::string_view(Client.Address); + } + Cbo << "dt" << Client.LastSeen.GetElapsedTimeMs(); + if (Client.Metadata) + { + Cbo << "metadata" << Client.Metadata; + } + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return Cbo.Save(); +} + +CbObject +OrchestratorService::GetClientHistory(int Limit) +{ + ZEN_TRACE_CPU("OrchestratorService::GetClientHistory"); + + if (Limit <= 0) + { + Limit = 100; + } + + CbObjectWriter Cbo; + Cbo.BeginArray("client_events"); + + m_ClientLogLock.WithSharedLock([&] { + int Count = 0; + for (auto It = m_ClientLog.rbegin(); It != m_ClientLog.rend() && Count < Limit; ++It, ++Count) + { + const auto& Evt = *It; + Cbo.BeginObject(); + + switch (Evt.EventType) + { + case ClientEvent::Type::Connected: + Cbo << "type" + << "connected"; + break; + case ClientEvent::Type::Disconnected: + Cbo << "type" + << "disconnected"; + break; + case ClientEvent::Type::Updated: + Cbo << "type" + << "updated"; + break; + } + + Cbo.AddDateTime("ts", Evt.Timestamp); + Cbo << "client_id" << std::string_view(Evt.ClientId); + Cbo << "hostname" << std::string_view(Evt.Hostname); + Cbo.EndObject(); + } + }); + + Cbo.EndArray(); + return 
Cbo.Save(); +} + +void +OrchestratorService::RecordClientEvent(ClientEvent::Type Type, std::string_view ClientId, std::string_view Hostname) +{ + ClientEvent Evt{ + .EventType = Type, + .Timestamp = DateTime::Now(), + .ClientId = std::string(ClientId), + .Hostname = std::string(Hostname), + }; + + m_ClientLogLock.WithExclusiveLock([&] { + m_ClientLog.push_back(std::move(Evt)); + while (m_ClientLog.size() > kMaxClientEvents) + { + m_ClientLog.pop_front(); + } + }); +} + +void +OrchestratorService::ProbeThreadFunction() +{ + ZEN_TRACE_CPU("OrchestratorService::ProbeThreadFunction"); + SetCurrentThreadName("orch_probe"); + + bool IsFirstProbe = true; + + do + { + if (!IsFirstProbe) + { + m_ProbeThreadEvent.Wait(5'000); + m_ProbeThreadEvent.Reset(); + } + else + { + IsFirstProbe = false; + } + + if (m_ProbeThreadEnabled == false) + { + return; + } + + m_ProbeThreadEvent.Reset(); + + // Snapshot worker IDs and URIs under shared lock + struct WorkerSnapshot + { + std::string Id; + std::string Uri; + bool WsConnected = false; + }; + std::vector Snapshots; + + m_KnownWorkersLock.WithSharedLock([&] { + Snapshots.reserve(m_KnownWorkers.size()); + for (const auto& [WorkerId, Worker] : m_KnownWorkers) + { + Snapshots.push_back({WorkerId, Worker.BaseUri, Worker.WsConnected}); + } + }); + + // Probe each worker outside the lock + for (const auto& Snap : Snapshots) + { + if (m_ProbeThreadEnabled == false) + { + return; + } + + // Workers with an active WebSocket connection are known-reachable; + // skip the HTTP health probe for them. 
+ if (Snap.WsConnected) + { + continue; + } + + ReachableState NewState = ReachableState::Unreachable; + + try + { + HttpClient Client(Snap.Uri, + {.ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{5000}}); + HttpClient::Response Response = Client.Get("/health/"); + if (Response.IsSuccess()) + { + NewState = ReachableState::Reachable; + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("probe failed for worker {} ({}): {}", Snap.Id, Snap.Uri, Ex.what()); + } + + ReachableState PrevState = ReachableState::Unknown; + std::string WorkerHostname; + + m_KnownWorkersLock.WithExclusiveLock([&] { + auto It = m_KnownWorkers.find(Snap.Id); + if (It != m_KnownWorkers.end()) + { + PrevState = It->second.Reachable; + WorkerHostname = It->second.Hostname; + It->second.Reachable = NewState; + It->second.LastProbed.Reset(); + + if (PrevState != NewState) + { + if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + ZEN_INFO("worker {} ({}) is reachable again", Snap.Id, Snap.Uri); + } + else if (NewState == ReachableState::Reachable) + { + ZEN_INFO("worker {} ({}) is now reachable", Snap.Id, Snap.Uri); + } + else if (PrevState == ReachableState::Reachable) + { + ZEN_WARN("worker {} ({}) is no longer reachable", Snap.Id, Snap.Uri); + } + else + { + ZEN_WARN("worker {} ({}) is not reachable", Snap.Id, Snap.Uri); + } + } + } + }); + + // Record provisioning events for state transitions outside the lock + if (PrevState != NewState) + { + if (NewState == ReachableState::Unreachable && PrevState == ReachableState::Reachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Left, Snap.Id, WorkerHostname); + } + else if (NewState == ReachableState::Reachable && PrevState == ReachableState::Unreachable) + { + RecordProvisioningEvent(ProvisioningEvent::Type::Returned, Snap.Id, WorkerHostname); + } + } + } + + // Sweep expired clients (5-minute timeout) + static constexpr int64_t kClientTimeoutMs = 5 * 60 * 
1000; + + struct ExpiredClient + { + std::string Id; + std::string Hostname; + }; + std::vector ExpiredClients; + + m_KnownClientsLock.WithExclusiveLock([&] { + for (auto It = m_KnownClients.begin(); It != m_KnownClients.end();) + { + if (It->second.LastSeen.GetElapsedTimeMs() > kClientTimeoutMs) + { + ExpiredClients.push_back({It->first, It->second.Hostname}); + It = m_KnownClients.erase(It); + } + else + { + ++It; + } + } + }); + + for (const auto& Expired : ExpiredClients) + { + ZEN_INFO("client {} timed out (no announcement for >5 minutes)", Expired.Id); + RecordClientEvent(ClientEvent::Type::Disconnected, Expired.Id, Expired.Hostname); + } + } while (m_ProbeThreadEnabled); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/recording/actionrecorder.cpp b/src/zencompute/recording/actionrecorder.cpp new file mode 100644 index 000000000..90141ca55 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.cpp @@ -0,0 +1,258 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "actionrecorder.h" + +#include "../runners/functionrunner.h" + +#include +#include +#include +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# include +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RecordingFileWriter::RecordingFileWriter() +{ +} + +RecordingFileWriter::~RecordingFileWriter() +{ + Close(); +} + +void +RecordingFileWriter::Open(std::filesystem::path FilePath) +{ + using namespace std::literals; + + m_File.Open(FilePath, BasicFile::Mode::kTruncate); + m_File.Write("----DDC2----DATA", 16, 0); + m_FileOffset = 16; + + std::filesystem::path TocPath = FilePath.replace_extension(".ztoc"); + m_TocFile.Open(TocPath, BasicFile::Mode::kTruncate); + + m_TocWriter << "version"sv << 1; + m_TocWriter.BeginArray("toc"sv); +} + +void +RecordingFileWriter::Close() +{ + m_TocWriter.EndArray(); + CbObject Toc = m_TocWriter.Save(); + + std::error_code Ec; + m_TocFile.WriteAll(Toc.GetBuffer().AsIoBuffer(), Ec); +} + +void +RecordingFileWriter::AppendObject(const CbObject& Object, const IoHash& ObjectHash) +{ + RwLock::ExclusiveLockScope _(m_FileLock); + + MemoryView ObjectView = Object.GetBuffer().GetView(); + + std::error_code Ec; + m_File.Write(ObjectView, m_FileOffset, Ec); + + if (Ec) + { + throw std::system_error(Ec, "failed writing to archive"); + } + + m_TocWriter.BeginArray(); + m_TocWriter.AddHash(ObjectHash); + m_TocWriter.AddInteger(m_FileOffset); + m_TocWriter.AddInteger(gsl::narrow(ObjectView.GetSize())); + m_TocWriter.EndArray(); + + m_FileOffset += ObjectView.GetSize(); +} + +////////////////////////////////////////////////////////////////////////// + +ActionRecorder::ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath) +: m_ChunkResolver(InChunkResolver) +, 
m_RecordingLogDir(RecordingLogPath) +{ + std::error_code Ec; + CreateDirectories(m_RecordingLogDir, Ec); + + if (Ec) + { + ZEN_WARN("Could not create directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + CleanDirectory(m_RecordingLogDir, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Could not clean directory '{}': {}", m_RecordingLogDir, Ec.message()); + } + + m_WorkersFile.Open(m_RecordingLogDir / "workers.zdat"); + m_ActionsFile.Open(m_RecordingLogDir / "actions.zdat"); + + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +ActionRecorder::~ActionRecorder() +{ + Shutdown(); +} + +void +ActionRecorder::Shutdown() +{ + m_CidStore.Flush(); +} + +void +ActionRecorder::RegisterWorker(const CbPackage& WorkerPackage) +{ + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + + m_WorkersFile.AppendObject(WorkerPackage.GetObject(), WorkerId); + + std::unordered_set AddedChunks; + uint64_t AddedBytes = 0; + + // First add all attachments from the worker package itself + + for (const CbAttachment& Attachment : WorkerPackage.GetAttachments()) + { + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + IoBuffer Data = Buffer.GetCompressed().Flatten().AsIoBuffer(); + + const IoHash ChunkHash = Buffer.DecodeRawHash(); + + CidStore::InsertResult Result = m_CidStore.AddChunk(Data, ChunkHash, CidStore::InsertMode::kCopyOnly); + + AddedChunks.insert(ChunkHash); + + if (Result.New) + { + AddedBytes += Data.GetSize(); + } + } + + // Not all attachments will be present in the worker package, so we need to add + // all referenced chunks to ensure that the recording is self-contained and not + // referencing data in the main CID store + + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + + WorkerDescriptor.IterateAttachments([&](const CbFieldView AttachmentField) { + const IoHash AttachmentCid = 
AttachmentField.GetValue().AsHash(); + + if (!AddedChunks.contains(AttachmentCid)) + { + IoBuffer AttachmentData = m_ChunkResolver.FindChunkByCid(AttachmentCid); + + if (AttachmentData) + { + CidStore::InsertResult Result = m_CidStore.AddChunk(AttachmentData, AttachmentCid, CidStore::InsertMode::kCopyOnly); + + if (Result.New) + { + AddedBytes += AttachmentData.GetSize(); + } + } + else + { + ZEN_WARN("RegisterWorker: could not resolve attachment chunk {} for worker {}", AttachmentCid, WorkerId); + } + + AddedChunks.insert(AttachmentCid); + } + }); + + ZEN_INFO("recorded worker {} with {} attachments ({} bytes)", WorkerId, AddedChunks.size(), AddedBytes); +} + +bool +ActionRecorder::RecordAction(Ref Action) +{ + bool AllGood = true; + + Action->ActionObj.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsHash(); + IoBuffer ChunkData = m_ChunkResolver.FindChunkByCid(AttachData); + + if (ChunkData) + { + if (ChunkData.GetContentType() == ZenContentType::kCompressedBinary) + { + IoHash DecompressedHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), /* out */ DecompressedHash, /* out*/ RawSize); + + OodleCompressor Compressor; + OodleCompressionLevel CompressionLevel; + uint64_t BlockSize = 0; + if (Compressed.TryGetCompressParameters(/* out */ Compressor, /* out */ CompressionLevel, /* out */ BlockSize)) + { + if (Compressor == OodleCompressor::NotSet) + { + CompositeBuffer Decompressed = Compressed.DecompressToComposite(); + CompressedBuffer NewCompressed = CompressedBuffer::Compress(std::move(Decompressed), + OodleCompressor::Mermaid, + OodleCompressionLevel::Fast, + BlockSize); + + ChunkData = NewCompressed.GetCompressed().Flatten().AsIoBuffer(); + } + } + } + + const uint64_t ChunkSize = ChunkData.GetSize(); + + m_CidStore.AddChunk(ChunkData, AttachData, CidStore::InsertMode::kCopyOnly); + ++m_ChunkCounter; + m_ChunkBytesCounter.fetch_add(ChunkSize); + } + else + { + 
AllGood = false; + + ZEN_WARN("could not resolve chunk {}", AttachData); + } + }); + + if (AllGood) + { + m_ActionsFile.AppendObject(Action->ActionObj, Action->ActionId); + ++m_ActionsCounter; + + return true; + } + else + { + return false; + } +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recording/actionrecorder.h b/src/zencompute/recording/actionrecorder.h new file mode 100644 index 000000000..2827b6ac7 --- /dev/null +++ b/src/zencompute/recording/actionrecorder.h @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace zen { +class CbObject; +class CbPackage; +struct IoHash; +} // namespace zen + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +////////////////////////////////////////////////////////////////////////// + +struct RecordingFileWriter +{ + RecordingFileWriter(RecordingFileWriter&&) = delete; + RecordingFileWriter& operator=(RecordingFileWriter&&) = delete; + + RwLock m_FileLock; + BasicFile m_File; + uint64_t m_FileOffset = 0; + CbObjectWriter m_TocWriter; + BasicFile m_TocFile; + + RecordingFileWriter(); + ~RecordingFileWriter(); + + void Open(std::filesystem::path FilePath); + void Close(); + void AppendObject(const CbObject& Object, const IoHash& ObjectHash); +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Recording "runner" implementation + * + * This class writes out all actions and their attachments to a recording directory + * in a format that can be read back by the RecordingReader. + * + * The contents of the recording directory will be self-contained, with all referenced + * attachments stored in the recording directory itself, so that the recording can be + * moved or shared without needing to maintain references to the main CID store. 
+ * + */ + +class ActionRecorder +{ +public: + ActionRecorder(ChunkResolver& InChunkResolver, const std::filesystem::path& RecordingLogPath); + ~ActionRecorder(); + + ActionRecorder(const ActionRecorder&) = delete; + ActionRecorder& operator=(const ActionRecorder&) = delete; + + void Shutdown(); + void RegisterWorker(const CbPackage& WorkerPackage); + bool RecordAction(Ref Action); + +private: + ChunkResolver& m_ChunkResolver; + std::filesystem::path m_RecordingLogDir; + + RecordingFileWriter m_WorkersFile; + RecordingFileWriter m_ActionsFile; + GcManager m_Gc; + CidStore m_CidStore{m_Gc}; + std::atomic m_ChunkCounter{0}; + std::atomic m_ChunkBytesCounter{0}; + std::atomic m_ActionsCounter{0}; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recording/recordingreader.cpp b/src/zencompute/recording/recordingreader.cpp new file mode 100644 index 000000000..1c1a119cf --- /dev/null +++ b/src/zencompute/recording/recordingreader.cpp @@ -0,0 +1,335 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "zencompute/recordingreader.h" + +#include +#include +#include +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# include +# define ZEN_CONCRT_AVAILABLE 1 +#else +# define ZEN_CONCRT_AVAILABLE 0 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_PLATFORM_WINDOWS +# define ZEN_BUILD_ACTION L"Build.action" +# define ZEN_WORKER_UCB L"worker.ucb" +# else +# define ZEN_BUILD_ACTION "Build.action" +# define ZEN_WORKER_UCB "worker.ucb" +# endif + +////////////////////////////////////////////////////////////////////////// + +struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor +{ + virtual void VisitFile(const std::filesystem::path& Parent, + const path_view& File, + uint64_t FileSize, + uint32_t NativeModeOrAttributes, + uint64_t NativeModificationTick) + { + ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick); + + if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0) + { + WorkDirs.push_back(Parent); + } + else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0) + { + WorkerDirs.push_back(Parent); + } + } + + virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) + { + ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes); + + return true; + } + + std::vector WorkerDirs; + std::vector WorkDirs; +}; + +////////////////////////////////////////////////////////////////////////// + +void +IterateOverArray(auto Array, auto Func, int TargetParallelism) +{ +# if ZEN_CONCRT_AVAILABLE + if (TargetParallelism > 1) + { + concurrency::simple_partitioner Chunker(Array.size() / TargetParallelism); + concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }); + + return; + } +# else + ZEN_UNUSED(TargetParallelism); +# endif + + for (const auto& Item : Array) + { + 
Func(Item); + } +} + +////////////////////////////////////////////////////////////////////////// + +RecordingReaderBase::~RecordingReaderBase() = default; + +////////////////////////////////////////////////////////////////////////// + +RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath) +{ + CidStoreConfiguration CidConfig; + CidConfig.RootDirectory = m_RecordingLogDir / "cid"; + CidConfig.HugeValueThreshold = 128 * 1024 * 1024; + + m_CidStore.Initialize(CidConfig); +} + +RecordingReader::~RecordingReader() +{ + m_CidStore.Flush(); +} + +size_t +RecordingReader::GetActionCount() const +{ + return m_Actions.size(); +} + +IoBuffer +RecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId)) + { + return Chunk; + } + + ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId); + + return {}; +} + +std::unordered_map +RecordingReader::ReadWorkers() +{ + std::unordered_map WorkerMap; + + { + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc"); + CbObject Toc = TocFile.Object; + + m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView Entry = It.AsArrayView(); + CbFieldViewIterator Vit = Entry.CreateViewIterator(); + + const IoHash WorkerId = Vit++->AsHash(); + const uint64_t Offset = Vit++->AsInt64(0); + const uint64_t Size = Vit++->AsInt64(0); + + IoBuffer WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size); + CbObject WorkerDesc = LoadCompactBinaryObject(WorkerRange); + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid); + + if 
(AttachmentData) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + } + }); + } + } + + // Scan actions as well (this should be called separately, ideally) + + ScanActions(); + + return WorkerMap; +} + +void +RecordingReader::ScanActions() +{ + CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc"); + CbObject Toc = TocFile.Object; + + m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead); + + ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); + + for (auto& It : Toc["toc"]) + { + CbArrayView ArrayEntry = It.AsArrayView(); + CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator(); + + ActionEntry Entry; + Entry.ActionId = Vit++->AsHash(); + Entry.Offset = Vit++->AsInt64(0); + Entry.Size = Vit++->AsInt64(0); + + m_Actions.push_back(Entry); + } +} + +void +RecordingReader::IterateActions(std::function&& Callback, int TargetParallelism) +{ + IterateOverArray( + m_Actions, + [&](const ActionEntry& Entry) { + CbObject ActionDesc = LoadCompactBinaryObject(m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size)); + + Callback(ActionDesc, Entry.ActionId); + }, + TargetParallelism); +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer +LocalResolver::FindChunkByCid(const IoHash& DecompressedId) +{ + RwLock::SharedLockScope _(MapLock); + if (auto It = Attachments.find(DecompressedId); It != Attachments.end()) + { + return It->second; + } + + return {}; +} + +void +LocalResolver::Add(const IoHash& Cid, IoBuffer Data) +{ + RwLock::ExclusiveLockScope _(MapLock); + Data.SetContentType(ZenContentType::kCompressedBinary); + Attachments[Cid] = Data; +} + +/// + +UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath) +{ +} + 
+UeRecordingReader::~UeRecordingReader() +{ +} + +size_t +UeRecordingReader::GetActionCount() const +{ + return m_WorkDirs.size(); +} + +IoBuffer +UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId) +{ + return m_LocalResolver.FindChunkByCid(DecompressedId); +} + +std::unordered_map +UeRecordingReader::ReadWorkers() +{ + std::unordered_map WorkerMap; + + FileSystemTraversal Traversal; + RecordingTreeVisitor Visitor; + Traversal.TraverseFileSystem(m_RecordingDir, Visitor); + + m_WorkDirs = std::move(Visitor.WorkDirs); + + for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs) + { + CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb"); + CbObject WorkerDesc = WorkerFile.Object; + const IoHash& WorkerId = WorkerFile.Hash; + CbPackage& WorkerPkg = WorkerMap[WorkerId]; + WorkerPkg.SetObject(WorkerDesc); + + WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten(); + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + } + + return WorkerMap; +} + +void +UeRecordingReader::IterateActions(std::function&& Callback, int ParallelismTarget) +{ + IterateOverArray( + m_WorkDirs, + [&](const std::filesystem::path& WorkDir) { + CbPackage WorkPackage = ReadAction(WorkDir); + CbObject ActionObject = WorkPackage.GetObject(); + const IoHash& ActionId = WorkPackage.GetObjectHash(); + + Callback(ActionObject, ActionId); + }, + ParallelismTarget); +} + +CbPackage +UeRecordingReader::ReadAction(std::filesystem::path WorkDir) +{ + CbPackage WorkPackage; + std::filesystem::path WorkDescPath = WorkDir / "Build.action"; + CbObjectFromFile ActionFile = 
LoadCompactBinaryObject(WorkDescPath); + CbObject& ActionObject = ActionFile.Object; + + WorkPackage.SetObject(ActionObject); + + ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { + const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); + IoBuffer AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten(); + + m_LocalResolver.Add(AttachmentCid, AttachmentData); + + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); + ZEN_ASSERT(AttachmentCid == RawHash); + WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash)); + }); + + return WorkPackage; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/recordingreader.cpp b/src/zencompute/recordingreader.cpp deleted file mode 100644 index 1c1a119cf..000000000 --- a/src/zencompute/recordingreader.cpp +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include "zencompute/recordingreader.h" - -#include -#include -#include -#include -#include -#include - -#if ZEN_PLATFORM_WINDOWS -# include -# define ZEN_CONCRT_AVAILABLE 1 -#else -# define ZEN_CONCRT_AVAILABLE 0 -#endif - -#if ZEN_WITH_COMPUTE_SERVICES - -namespace zen::compute { - -using namespace std::literals; - -////////////////////////////////////////////////////////////////////////// - -# if ZEN_PLATFORM_WINDOWS -# define ZEN_BUILD_ACTION L"Build.action" -# define ZEN_WORKER_UCB L"worker.ucb" -# else -# define ZEN_BUILD_ACTION "Build.action" -# define ZEN_WORKER_UCB "worker.ucb" -# endif - -////////////////////////////////////////////////////////////////////////// - -struct RecordingTreeVisitor : public FileSystemTraversal::TreeVisitor -{ - virtual void VisitFile(const std::filesystem::path& Parent, - const path_view& File, - uint64_t FileSize, - uint32_t NativeModeOrAttributes, - uint64_t NativeModificationTick) - { - ZEN_UNUSED(Parent, File, FileSize, NativeModeOrAttributes, NativeModificationTick); - - if (File.compare(path_view(ZEN_BUILD_ACTION)) == 0) - { - WorkDirs.push_back(Parent); - } - else if (File.compare(path_view(ZEN_WORKER_UCB)) == 0) - { - WorkerDirs.push_back(Parent); - } - } - - virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName, uint32_t NativeModeOrAttributes) - { - ZEN_UNUSED(Parent, DirectoryName, NativeModeOrAttributes); - - return true; - } - - std::vector WorkerDirs; - std::vector WorkDirs; -}; - -////////////////////////////////////////////////////////////////////////// - -void -IterateOverArray(auto Array, auto Func, int TargetParallelism) -{ -# if ZEN_CONCRT_AVAILABLE - if (TargetParallelism > 1) - { - concurrency::simple_partitioner Chunker(Array.size() / TargetParallelism); - concurrency::parallel_for_each(begin(Array), end(Array), [&](const auto& Item) { Func(Item); }); - - return; - } -# else - ZEN_UNUSED(TargetParallelism); -# endif - - for (const auto& Item : Array) - { - 
Func(Item); - } -} - -////////////////////////////////////////////////////////////////////////// - -RecordingReaderBase::~RecordingReaderBase() = default; - -////////////////////////////////////////////////////////////////////////// - -RecordingReader::RecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingLogDir(RecordingPath) -{ - CidStoreConfiguration CidConfig; - CidConfig.RootDirectory = m_RecordingLogDir / "cid"; - CidConfig.HugeValueThreshold = 128 * 1024 * 1024; - - m_CidStore.Initialize(CidConfig); -} - -RecordingReader::~RecordingReader() -{ - m_CidStore.Flush(); -} - -size_t -RecordingReader::GetActionCount() const -{ - return m_Actions.size(); -} - -IoBuffer -RecordingReader::FindChunkByCid(const IoHash& DecompressedId) -{ - if (IoBuffer Chunk = m_CidStore.FindChunkByCid(DecompressedId)) - { - return Chunk; - } - - ZEN_ERROR("failed lookup of chunk with CID '{}'", DecompressedId); - - return {}; -} - -std::unordered_map -RecordingReader::ReadWorkers() -{ - std::unordered_map WorkerMap; - - { - CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "workers.ztoc"); - CbObject Toc = TocFile.Object; - - m_WorkerDataFile.Open(m_RecordingLogDir / "workers.zdat", BasicFile::Mode::kRead); - - ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); - - for (auto& It : Toc["toc"]) - { - CbArrayView Entry = It.AsArrayView(); - CbFieldViewIterator Vit = Entry.CreateViewIterator(); - - const IoHash WorkerId = Vit++->AsHash(); - const uint64_t Offset = Vit++->AsInt64(0); - const uint64_t Size = Vit++->AsInt64(0); - - IoBuffer WorkerRange = m_WorkerDataFile.ReadRange(Offset, Size); - CbObject WorkerDesc = LoadCompactBinaryObject(WorkerRange); - CbPackage& WorkerPkg = WorkerMap[WorkerId]; - WorkerPkg.SetObject(WorkerDesc); - - WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { - const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - IoBuffer AttachmentData = m_CidStore.FindChunkByCid(AttachmentCid); - - if 
(AttachmentData) - { - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); - WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); - } - }); - } - } - - // Scan actions as well (this should be called separately, ideally) - - ScanActions(); - - return WorkerMap; -} - -void -RecordingReader::ScanActions() -{ - CbObjectFromFile TocFile = LoadCompactBinaryObject(m_RecordingLogDir / "actions.ztoc"); - CbObject Toc = TocFile.Object; - - m_ActionDataFile.Open(m_RecordingLogDir / "actions.zdat", BasicFile::Mode::kRead); - - ZEN_ASSERT(Toc["version"sv].AsInt32() == 1); - - for (auto& It : Toc["toc"]) - { - CbArrayView ArrayEntry = It.AsArrayView(); - CbFieldViewIterator Vit = ArrayEntry.CreateViewIterator(); - - ActionEntry Entry; - Entry.ActionId = Vit++->AsHash(); - Entry.Offset = Vit++->AsInt64(0); - Entry.Size = Vit++->AsInt64(0); - - m_Actions.push_back(Entry); - } -} - -void -RecordingReader::IterateActions(std::function&& Callback, int TargetParallelism) -{ - IterateOverArray( - m_Actions, - [&](const ActionEntry& Entry) { - CbObject ActionDesc = LoadCompactBinaryObject(m_ActionDataFile.ReadRange(Entry.Offset, Entry.Size)); - - Callback(ActionDesc, Entry.ActionId); - }, - TargetParallelism); -} - -////////////////////////////////////////////////////////////////////////// - -IoBuffer -LocalResolver::FindChunkByCid(const IoHash& DecompressedId) -{ - RwLock::SharedLockScope _(MapLock); - if (auto It = Attachments.find(DecompressedId); It != Attachments.end()) - { - return It->second; - } - - return {}; -} - -void -LocalResolver::Add(const IoHash& Cid, IoBuffer Data) -{ - RwLock::ExclusiveLockScope _(MapLock); - Data.SetContentType(ZenContentType::kCompressedBinary); - Attachments[Cid] = Data; -} - -/// - -UeRecordingReader::UeRecordingReader(const std::filesystem::path& RecordingPath) : m_RecordingDir(RecordingPath) -{ -} - 
-UeRecordingReader::~UeRecordingReader() -{ -} - -size_t -UeRecordingReader::GetActionCount() const -{ - return m_WorkDirs.size(); -} - -IoBuffer -UeRecordingReader::FindChunkByCid(const IoHash& DecompressedId) -{ - return m_LocalResolver.FindChunkByCid(DecompressedId); -} - -std::unordered_map -UeRecordingReader::ReadWorkers() -{ - std::unordered_map WorkerMap; - - FileSystemTraversal Traversal; - RecordingTreeVisitor Visitor; - Traversal.TraverseFileSystem(m_RecordingDir, Visitor); - - m_WorkDirs = std::move(Visitor.WorkDirs); - - for (const std::filesystem::path& WorkerDir : Visitor.WorkerDirs) - { - CbObjectFromFile WorkerFile = LoadCompactBinaryObject(WorkerDir / "worker.ucb"); - CbObject WorkerDesc = WorkerFile.Object; - const IoHash& WorkerId = WorkerFile.Hash; - CbPackage& WorkerPkg = WorkerMap[WorkerId]; - WorkerPkg.SetObject(WorkerDesc); - - WorkerDesc.IterateAttachments([&](const zen::CbFieldView AttachmentField) { - const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - IoBuffer AttachmentData = ReadFile(WorkerDir / "chunks" / AttachmentCid.ToHexString()).Flatten(); - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); - WorkerPkg.AddAttachment(CbAttachment(CompressedData, RawHash)); - }); - } - - return WorkerMap; -} - -void -UeRecordingReader::IterateActions(std::function&& Callback, int ParallelismTarget) -{ - IterateOverArray( - m_WorkDirs, - [&](const std::filesystem::path& WorkDir) { - CbPackage WorkPackage = ReadAction(WorkDir); - CbObject ActionObject = WorkPackage.GetObject(); - const IoHash& ActionId = WorkPackage.GetObjectHash(); - - Callback(ActionObject, ActionId); - }, - ParallelismTarget); -} - -CbPackage -UeRecordingReader::ReadAction(std::filesystem::path WorkDir) -{ - CbPackage WorkPackage; - std::filesystem::path WorkDescPath = WorkDir / "Build.action"; - CbObjectFromFile ActionFile = 
LoadCompactBinaryObject(WorkDescPath); - CbObject& ActionObject = ActionFile.Object; - - WorkPackage.SetObject(ActionObject); - - ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { - const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - IoBuffer AttachmentData = ReadFile(WorkDir / "inputs" / AttachmentCid.ToHexString()).Flatten(); - - m_LocalResolver.Add(AttachmentCid, AttachmentData); - - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); - ZEN_ASSERT(AttachmentCid == RawHash); - WorkPackage.AddAttachment(CbAttachment(CompressedData, RawHash)); - }); - - return WorkPackage; -} - -} // namespace zen::compute - -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/remotehttprunner.cpp deleted file mode 100644 index 98ced5fe8..000000000 --- a/src/zencompute/remotehttprunner.cpp +++ /dev/null @@ -1,457 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include "remotehttprunner.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include - -////////////////////////////////////////////////////////////////////////// - -namespace zen::compute { - -using namespace std::literals; - -////////////////////////////////////////////////////////////////////////// - -RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName) -: FunctionRunner(BaseDir) -, m_Log(logging::Get("http_exec")) -, m_ChunkResolver{InChunkResolver} -, m_BaseUrl{fmt::format("{}/apply", HostName)} -, m_Http(m_BaseUrl) -{ - m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; -} - -RemoteHttpRunner::~RemoteHttpRunner() -{ - Shutdown(); -} - -void -RemoteHttpRunner::Shutdown() -{ - // TODO: should cleanly drain/cancel pending work - - m_MonitorThreadEnabled = false; - m_MonitorThreadEvent.Set(); - if (m_MonitorThread.joinable()) - { - m_MonitorThread.join(); - } -} - -void -RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) -{ - const IoHash WorkerId = WorkerPackage.GetObjectHash(); - CbPackage WorkerDesc = WorkerPackage; - - std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); - - HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); - - if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) - { - HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); - - if (DescResponse.StatusCode == HttpResponseCode::NotFound) - { - CbPackage Pkg = WorkerDesc; - - // Build response package by sending only the attachments - // the other end needs. We start with the full package and - // remove the attachments which are not needed. 
- - { - std::unordered_set Needed; - - CbObject Response = DescResponse.AsObject(); - - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - Needed.insert(NeedHash); - } - - std::unordered_set ToRemove; - - for (const CbAttachment& Attachment : Pkg.GetAttachments()) - { - const IoHash& Hash = Attachment.GetHash(); - - if (Needed.find(Hash) == Needed.end()) - { - ToRemove.insert(Hash); - } - } - - for (const IoHash& Hash : ToRemove) - { - int RemovedCount = Pkg.RemoveAttachment(Hash); - - ZEN_ASSERT(RemovedCount == 1); - } - } - - // Post resulting package - - HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); - - if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) - { - ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error - } - } - else if (!IsHttpSuccessCode(DescResponse.StatusCode)) - { - ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error - } - else - { - ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); - } - } - else if (WorkerResponse.StatusCode == HttpResponseCode::OK) - { - // Already known from a previous run - } - else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) - { - ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", - WorkerId, - m_Http.GetBaseUri(), - WorkerUrl, - (int)WorkerResponse.StatusCode, - ToString(WorkerResponse.StatusCode)); - - // TODO: propagate error - } -} - -size_t -RemoteHttpRunner::QueryCapacity() -{ - // Estimate how much more work we're ready to accept - - RwLock::SharedLockScope _{m_RunningLock}; - - size_t RunningCount = m_RemoteRunningMap.size(); - - if (RunningCount >= size_t(m_MaxRunningActions)) - { - return 0; - } - - return m_MaxRunningActions - RunningCount; -} - -std::vector -RemoteHttpRunner::SubmitActions(const std::vector>& Actions) -{ - std::vector Results; - - 
for (const Ref& Action : Actions) - { - Results.push_back(SubmitAction(Action)); - } - - return Results; -} - -SubmitResult -RemoteHttpRunner::SubmitAction(Ref Action) -{ - // Verify whether we can accept more work - - { - RwLock::SharedLockScope _{m_RunningLock}; - if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) - { - return SubmitResult{.IsAccepted = false}; - } - } - - using namespace std::literals; - - // Each enqueued action is assigned an integer index (logical sequence number), - // which we use as a key for tracking data structures and as an opaque id which - // may be used by clients to reference the scheduled action - - const int32_t ActionLsn = Action->ActionLsn; - const CbObject& ActionObj = Action->ActionObj; - const IoHash ActionId = ActionObj.GetHash(); - - MaybeDumpAction(ActionLsn, ActionObj); - - // Enqueue job - - CbObject Result; - - HttpClient::Response WorkResponse = m_Http.Post("/jobs", ActionObj); - HttpResponseCode WorkResponseCode = WorkResponse.StatusCode; - - if (WorkResponseCode == HttpResponseCode::OK) - { - Result = WorkResponse.AsObject(); - } - else if (WorkResponseCode == HttpResponseCode::NotFound) - { - // Not all attachments are present - - // Build response package including all required attachments - - CbPackage Pkg; - Pkg.SetObject(ActionObj); - - CbObject Response = WorkResponse.AsObject(); - - for (auto& Item : Response["need"sv]) - { - const IoHash NeedHash = Item.AsHash(); - - if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) - { - uint64_t DataRawSize = 0; - IoHash DataRawHash; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); - - ZEN_ASSERT(DataRawHash == NeedHash); - - Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); - } - else - { - // No such attachment - - return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; - } - } - - // Post resulting package - - 
HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg); - - if (!PayloadResponse) - { - ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri()); - - // TODO: include more information about the failure in the response - - return {.IsAccepted = false, .Reason = "HTTP request failed"}; - } - else if (PayloadResponse.StatusCode == HttpResponseCode::OK) - { - Result = PayloadResponse.AsObject(); - } - else - { - // Unexpected response - - const int ResponseStatusCode = (int)PayloadResponse.StatusCode; - - ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})", - ActionId, - m_Http.GetBaseUri(), - ResponseStatusCode, - ToString(ResponseStatusCode)); - - return {.IsAccepted = false, - .Reason = fmt::format("unexpected response code {} {} from {}/jobs", - ResponseStatusCode, - ToString(ResponseStatusCode), - m_Http.GetBaseUri())}; - } - } - - if (Result) - { - if (const int32_t LsnField = Result["lsn"].AsInt32(0)) - { - HttpRunningAction NewAction; - NewAction.Action = Action; - NewAction.RemoteActionLsn = LsnField; - - { - RwLock::ExclusiveLockScope _(m_RunningLock); - - m_RemoteRunningMap[LsnField] = std::move(NewAction); - } - - ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); - - Action->SetActionState(RunnerAction::State::Running); - - return SubmitResult{.IsAccepted = true}; - } - } - - return {}; -} - -bool -RemoteHttpRunner::IsHealthy() -{ - if (HttpClient::Response Ready = m_Http.Get("/ready")) - { - return true; - } - else - { - // TODO: use response to propagate context - return false; - } -} - -size_t -RemoteHttpRunner::GetSubmittedActionCount() -{ - RwLock::SharedLockScope _(m_RunningLock); - return m_RemoteRunningMap.size(); -} - -void -RemoteHttpRunner::MonitorThreadFunction() -{ - SetCurrentThreadName("RemoteHttpRunner_Monitor"); - - do - { - const int NormalWaitingTime = 1000; - int WaitTimeMs = NormalWaitingTime; - auto WaitOnce = 
[&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; - auto SweepOnce = [&] { - const size_t RetiredCount = SweepRunningActions(); - - m_RunningLock.WithSharedLock([&] { - if (m_RemoteRunningMap.size() > 16) - { - WaitTimeMs = NormalWaitingTime / 4; - } - else - { - if (RetiredCount) - { - WaitTimeMs = NormalWaitingTime / 2; - } - else - { - WaitTimeMs = NormalWaitingTime; - } - } - }); - }; - - while (!WaitOnce()) - { - SweepOnce(); - } - - // Signal received - this may mean we should quit - - SweepOnce(); - } while (m_MonitorThreadEnabled); -} - -size_t -RemoteHttpRunner::SweepRunningActions() -{ - std::vector CompletedActions; - - // Poll remote for list of completed actions - - HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv); - - if (CbObject Completed = ResponseCompleted.AsObject()) - { - for (auto& FieldIt : Completed["completed"sv]) - { - const int32_t CompleteLsn = FieldIt.AsInt32(); - - if (HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn))) - { - m_RunningLock.WithExclusiveLock([&] { - if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) - { - HttpRunningAction CompletedAction = std::move(CompleteIt->second); - CompletedAction.ActionResults = ResponseJob.AsPackage(); - CompletedAction.Success = true; - - CompletedActions.push_back(std::move(CompletedAction)); - m_RemoteRunningMap.erase(CompleteIt); - } - else - { - // we received a completion notice for an action we don't know about, - // this can happen if the runner is used by multiple upstream schedulers, - // or if this compute node was recently restarted and lost track of - // previously scheduled actions - } - }); - } - } - - if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) - { - // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0)) - if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0)) - { - const int32_t NewCap = zen::Max(4, CpuCount); - - if 
(m_MaxRunningActions > NewCap) - { - ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions); - - m_MaxRunningActions = NewCap; - } - } - } - } - - // Notify outer. Note that this has to be done without holding any local locks - // otherwise we may end up with deadlocks. - - for (HttpRunningAction& HttpAction : CompletedActions) - { - const int ActionLsn = HttpAction.Action->ActionLsn; - - if (HttpAction.Success) - { - ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", HttpAction.Action->ActionId, ActionLsn, HttpAction.RemoteActionLsn); - - HttpAction.Action->SetActionState(RunnerAction::State::Completed); - - HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); - } - else - { - HttpAction.Action->SetActionState(RunnerAction::State::Failed); - } - } - - return CompletedActions.size(); -} - -} // namespace zen::compute - -#endif diff --git a/src/zencompute/remotehttprunner.h b/src/zencompute/remotehttprunner.h deleted file mode 100644 index 1e885da3d..000000000 --- a/src/zencompute/remotehttprunner.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include "zencompute/functionservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include "functionrunner.h" - -# include -# include -# include -# include - -# include -# include -# include - -namespace zen { -class CidStore; -} - -namespace zen::compute { - -/** HTTP-based runner - - This implements a DDC remote compute execution strategy via REST API - - */ - -class RemoteHttpRunner : public FunctionRunner -{ - RemoteHttpRunner(RemoteHttpRunner&&) = delete; - RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; - -public: - RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName); - ~RemoteHttpRunner(); - - virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; - [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; - [[nodiscard]] virtual bool IsHealthy() override; - [[nodiscard]] virtual size_t GetSubmittedActionCount() override; - [[nodiscard]] virtual size_t QueryCapacity() override; - [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; - -protected: - LoggerRef Log() { return m_Log; } - -private: - LoggerRef m_Log; - ChunkResolver& m_ChunkResolver; - std::string m_BaseUrl; - HttpClient m_Http; - - int32_t m_MaxRunningActions = 256; // arbitrary limit for testing - - struct HttpRunningAction - { - Ref Action; - int RemoteActionLsn = 0; // Remote LSN - bool Success = false; - CbPackage ActionResults; - }; - - RwLock m_RunningLock; - std::unordered_map m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn - - std::thread m_MonitorThread; - std::atomic m_MonitorThreadEnabled{true}; - Event m_MonitorThreadEvent; - void MonitorThreadFunction(); - size_t SweepRunningActions(); -}; - -} // namespace zen::compute - -#endif diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp new file mode 100644 index 000000000..00977d9fa --- 
/dev/null +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -0,0 +1,336 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "deferreddeleter.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include + +# include +# include + +namespace zen::compute { + +using namespace std::chrono_literals; + +using Clock = std::chrono::steady_clock; + +// Default deferral: how long to wait before attempting deletion. +// This gives memory-mapped file handles time to close naturally. +static constexpr auto DeferralPeriod = 60s; + +// Shortened deferral after MarkReady(): the client has collected results +// so handles should be released soon, but we still wait briefly. +static constexpr auto ReadyGracePeriod = 5s; + +// Interval between retry attempts for directories that failed deletion. +static constexpr auto RetryInterval = 5s; + +static constexpr int MaxRetries = 10; + +DeferredDirectoryDeleter::DeferredDirectoryDeleter() : m_Thread(&DeferredDirectoryDeleter::ThreadFunction, this) +{ +} + +DeferredDirectoryDeleter::~DeferredDirectoryDeleter() +{ + Shutdown(); +} + +void +DeferredDirectoryDeleter::Enqueue(int ActionLsn, std::filesystem::path Path) +{ + { + std::lock_guard Lock(m_Mutex); + m_Queue.push_back({ActionLsn, std::move(Path)}); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::MarkReady(int ActionLsn) +{ + { + std::lock_guard Lock(m_Mutex); + m_ReadyLsns.push_back(ActionLsn); + } + m_Cv.notify_one(); +} + +void +DeferredDirectoryDeleter::Shutdown() +{ + { + std::lock_guard Lock(m_Mutex); + m_Done = true; + } + m_Cv.notify_one(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } +} + +void +DeferredDirectoryDeleter::ThreadFunction() +{ + SetCurrentThreadName("ZenDirCleanup"); + + struct PendingEntry + { + int ActionLsn; + std::filesystem::path Path; + Clock::time_point ReadyTime; + int Attempts = 0; + }; + + std::vector PendingList; + + auto TryDelete = [](PendingEntry& Entry) -> bool { + std::error_code Ec; + 
std::filesystem::remove_all(Entry.Path, Ec); + return !Ec; + }; + + for (;;) + { + bool Shutting = false; + + // Drain the incoming queue and process MarkReady signals + + { + std::unique_lock Lock(m_Mutex); + + if (m_Queue.empty() && m_ReadyLsns.empty() && !m_Done) + { + if (PendingList.empty()) + { + m_Cv.wait(Lock, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + else + { + auto NextReady = PendingList.front().ReadyTime; + for (const auto& Entry : PendingList) + { + if (Entry.ReadyTime < NextReady) + { + NextReady = Entry.ReadyTime; + } + } + + m_Cv.wait_until(Lock, NextReady, [this] { return !m_Queue.empty() || !m_ReadyLsns.empty() || m_Done; }); + } + } + + // Move new items into PendingList with the full deferral deadline + auto Now = Clock::now(); + for (auto& Entry : m_Queue) + { + PendingList.push_back({Entry.ActionLsn, std::move(Entry.Path), Now + DeferralPeriod, 0}); + } + m_Queue.clear(); + + // Apply MarkReady: shorten ReadyTime for matching entries + for (int Lsn : m_ReadyLsns) + { + for (auto& Entry : PendingList) + { + if (Entry.ActionLsn == Lsn) + { + auto NewReady = Now + ReadyGracePeriod; + if (NewReady < Entry.ReadyTime) + { + Entry.ReadyTime = NewReady; + } + } + } + } + m_ReadyLsns.clear(); + + Shutting = m_Done; + } + + // Process items whose deferral period has elapsed (or all items on shutdown) + + auto Now = Clock::now(); + + for (size_t i = 0; i < PendingList.size();) + { + auto& Entry = PendingList[i]; + + if (!Shutting && Now < Entry.ReadyTime) + { + ++i; + continue; + } + + if (TryDelete(Entry)) + { + if (Entry.Attempts > 0) + { + ZEN_INFO("Retry succeeded for directory '{}'", Entry.Path); + } + + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); + } + else + { + ++Entry.Attempts; + + if (Entry.Attempts >= MaxRetries) + { + ZEN_WARN("Giving up on deleting '{}' after {} attempts", Entry.Path, Entry.Attempts); + PendingList[i] = std::move(PendingList.back()); + PendingList.pop_back(); 
+ } + else + { + ZEN_WARN("Unable to delete directory '{}' (attempt {}), will retry", Entry.Path, Entry.Attempts); + Entry.ReadyTime = Now + RetryInterval; + ++i; + } + } + } + + // Exit once shutdown is requested and nothing remains + + if (Shutting && PendingList.empty()) + { + return; + } + } +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS + +# include + +namespace zen::compute { + +void +deferreddeleter_forcelink() +{ +} + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include + +namespace zen::compute { + +TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path DirToDelete = TempDir.Path() / "subdir"; + CreateDirectories(DirToDelete / "nested"); + + CHECK(std::filesystem::exists(DirToDelete)); + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(1, DirToDelete); + } + + CHECK(!std::filesystem::exists(DirToDelete)); +} + +TEST_CASE("DeferredDirectoryDeleter.DeletesMultipleDirectories") +{ + ScopedTemporaryDirectory TempDir; + + constexpr int NumDirs = 10; + std::vector Dirs; + + for (int i = 0; i < NumDirs; ++i) + { + auto Dir = TempDir.Path() / std::to_string(i); + CreateDirectories(Dir / "child"); + Dirs.push_back(std::move(Dir)); + } + + { + DeferredDirectoryDeleter Deleter; + for (int i = 0; i < NumDirs; ++i) + { + CHECK(std::filesystem::exists(Dirs[i])); + Deleter.Enqueue(100 + i, Dirs[i]); + } + } + + for (const auto& Dir : Dirs) + { + CHECK(!std::filesystem::exists(Dir)); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ShutdownIsIdempotent") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "idempotent"; + CreateDirectories(Dir); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(42, Dir); + Deleter.Shutdown(); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.HandlesNonExistentPath") +{ + ScopedTemporaryDirectory TempDir; 
+ std::filesystem::path NoSuchDir = TempDir.Path() / "does_not_exist"; + + { + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(99, NoSuchDir); + } +} + +TEST_CASE("DeferredDirectoryDeleter.ExplicitShutdownBeforeDestruction") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "explicit"; + CreateDirectories(Dir / "inner"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(7, Dir); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral") +{ + ScopedTemporaryDirectory TempDir; + std::filesystem::path Dir = TempDir.Path() / "markready"; + CreateDirectories(Dir / "child"); + + DeferredDirectoryDeleter Deleter; + Deleter.Enqueue(50, Dir); + + // Without MarkReady the full deferral (60s) would apply. + // MarkReady shortens it to 5s, and shutdown bypasses even that. + Deleter.MarkReady(50); + Deleter.Shutdown(); + + CHECK(!std::filesystem::exists(Dir)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/runners/deferreddeleter.h b/src/zencompute/runners/deferreddeleter.h new file mode 100644 index 000000000..9b010aa0f --- /dev/null +++ b/src/zencompute/runners/deferreddeleter.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include + +namespace zen::compute { + +/// Deletes directories on a background thread to avoid blocking callers. +/// Useful when DeleteDirectories may stall (e.g. Wine's deferred-unlink semantics). +/// +/// Enqueued directories wait for a deferral period before deletion, giving +/// file handles time to close. Call MarkReady() with the ActionLsn to shorten +/// the wait to a brief grace period (e.g. once a client has collected results). 
+/// On shutdown, all pending directories are deleted immediately. +class DeferredDirectoryDeleter +{ + DeferredDirectoryDeleter(const DeferredDirectoryDeleter&) = delete; + DeferredDirectoryDeleter& operator=(const DeferredDirectoryDeleter&) = delete; + +public: + DeferredDirectoryDeleter(); + ~DeferredDirectoryDeleter(); + + /// Enqueue a directory for deferred deletion, associated with an action LSN. + void Enqueue(int ActionLsn, std::filesystem::path Path); + + /// Signal that the action result has been consumed and the directory + /// can be deleted after a short grace period instead of the full deferral. + void MarkReady(int ActionLsn); + + /// Drain the queue and join the background thread. Idempotent. + void Shutdown(); + +private: + struct QueueEntry + { + int ActionLsn; + std::filesystem::path Path; + }; + + std::mutex m_Mutex; + std::condition_variable m_Cv; + std::deque m_Queue; + std::vector m_ReadyLsns; + bool m_Done = false; + std::thread m_Thread; + void ThreadFunction(); +}; + +} // namespace zen::compute + +#endif + +#if ZEN_WITH_TESTS +namespace zen::compute { +void deferreddeleter_forcelink(); // internal +} // namespace zen::compute +#endif diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp new file mode 100644 index 000000000..768cdf1e1 --- /dev/null +++ b/src/zencompute/runners/functionrunner.cpp @@ -0,0 +1,365 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include + +# include +# include + +namespace zen::compute { + +FunctionRunner::FunctionRunner(std::filesystem::path BasePath) : m_ActionsPath(BasePath / "actions") +{ +} + +FunctionRunner::~FunctionRunner() = default; + +size_t +FunctionRunner::QueryCapacity() +{ + return 1; +} + +std::vector +FunctionRunner::SubmitActions(const std::vector>& Actions) +{ + std::vector Results; + Results.reserve(Actions.size()); + + for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +void +FunctionRunner::MaybeDumpAction(int ActionLsn, const CbObject& ActionObject) +{ + if (m_DumpActions) + { + std::string UniqueId = fmt::format("{}.ddb", ActionLsn); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path, IoBuffer(ActionObject.GetBuffer().AsIoBuffer())); + } +} + +////////////////////////////////////////////////////////////////////////// + +void +BaseRunnerGroup::AddRunnerInternal(FunctionRunner* Runner) +{ + m_RunnersLock.WithExclusiveLock([&] { m_Runners.emplace_back(Runner); }); +} + +size_t +BaseRunnerGroup::QueryCapacity() +{ + size_t TotalCapacity = 0; + m_RunnersLock.WithSharedLock([&] { + for (const auto& Runner : m_Runners) + { + TotalCapacity += Runner->QueryCapacity(); + } + }); + return TotalCapacity; +} + +SubmitResult +BaseRunnerGroup::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitAction"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int InitialIndex = m_NextSubmitIndex.load(std::memory_order_acquire); + int Index = InitialIndex; + const int RunnerCount = gsl::narrow(m_Runners.size()); + + if (RunnerCount == 0) + { + return {.IsAccepted = false, .Reason = "No runners available"}; + } + + do + { + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + + auto& Runner = m_Runners[Index++]; + + SubmitResult Result = Runner->SubmitAction(Action); + + if 
(Result.IsAccepted == true) + { + m_NextSubmitIndex = Index % RunnerCount; + + return Result; + } + + while (Index >= RunnerCount) + { + Index -= RunnerCount; + } + } while (Index != InitialIndex); + + return {.IsAccepted = false}; +} + +std::vector +BaseRunnerGroup::SubmitActions(const std::vector>& Actions) +{ + ZEN_TRACE_CPU("BaseRunnerGroup::SubmitActions"); + RwLock::SharedLockScope _(m_RunnersLock); + + const int RunnerCount = gsl::narrow(m_Runners.size()); + + if (RunnerCount == 0) + { + return std::vector(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No runners available"}); + } + + // Query capacity per runner and compute total + std::vector Capacities(RunnerCount); + size_t TotalCapacity = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + Capacities[i] = m_Runners[i]->QueryCapacity(); + TotalCapacity += Capacities[i]; + } + + if (TotalCapacity == 0) + { + return std::vector(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "No capacity"}); + } + + // Distribute actions across runners proportionally to their available capacity + std::vector>> PerRunnerActions(RunnerCount); + std::vector ActionRunnerIndex(Actions.size()); + size_t ActionIdx = 0; + + for (int i = 0; i < RunnerCount; ++i) + { + if (Capacities[i] == 0) + { + continue; + } + + size_t Share = (Actions.size() * Capacities[i] + TotalCapacity - 1) / TotalCapacity; + Share = std::min(Share, Capacities[i]); + + for (size_t j = 0; j < Share && ActionIdx < Actions.size(); ++j, ++ActionIdx) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + } + } + + // Assign any remaining actions to runners with capacity (round-robin) + for (int i = 0; ActionIdx < Actions.size(); i = (i + 1) % RunnerCount) + { + if (Capacities[i] > PerRunnerActions[i].size()) + { + PerRunnerActions[i].push_back(Actions[ActionIdx]); + ActionRunnerIndex[ActionIdx] = i; + ++ActionIdx; + } + } + + // Submit batches per runner + std::vector> 
PerRunnerResults(RunnerCount); + + for (int i = 0; i < RunnerCount; ++i) + { + if (!PerRunnerActions[i].empty()) + { + PerRunnerResults[i] = m_Runners[i]->SubmitActions(PerRunnerActions[i]); + } + } + + // Reassemble results in original action order + std::vector Results(Actions.size()); + std::vector PerRunnerIdx(RunnerCount, 0); + + for (size_t i = 0; i < Actions.size(); ++i) + { + size_t RunnerIdx = ActionRunnerIndex[i]; + size_t Idx = PerRunnerIdx[RunnerIdx]++; + Results[i] = std::move(PerRunnerResults[RunnerIdx][Idx]); + } + + return Results; +} + +size_t +BaseRunnerGroup::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + size_t TotalCount = 0; + + for (const auto& Runner : m_Runners) + { + TotalCount += Runner->GetSubmittedActionCount(); + } + + return TotalCount; +} + +void +BaseRunnerGroup::RegisterWorker(CbPackage Worker) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->RegisterWorker(Worker); + } +} + +void +BaseRunnerGroup::Shutdown() +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->Shutdown(); + } +} + +bool +BaseRunnerGroup::CancelAction(int ActionLsn) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + if (Runner->CancelAction(ActionLsn)) + { + return true; + } + } + + return false; +} + +void +BaseRunnerGroup::CancelRemoteQueue(int QueueId) +{ + RwLock::SharedLockScope _(m_RunnersLock); + + for (auto& Runner : m_Runners) + { + Runner->CancelRemoteQueue(QueueId); + } +} + +////////////////////////////////////////////////////////////////////////// + +RunnerAction::RunnerAction(ComputeServiceSession* OwnerSession) : m_OwnerSession(OwnerSession) +{ + this->Timestamps[static_cast(State::New)] = DateTime::Now().GetTicks(); +} + +RunnerAction::~RunnerAction() +{ +} + +bool +RunnerAction::ResetActionStateToPending() +{ + // Only allow reset from Failed or Abandoned states + State CurrentState = 
m_ActionState.load(); + + if (CurrentState != State::Failed && CurrentState != State::Abandoned) + { + return false; + } + + if (!m_ActionState.compare_exchange_strong(CurrentState, State::Pending)) + { + return false; + } + + // Clear timestamps from Submitting through _Count + for (int i = static_cast(State::Submitting); i < static_cast(State::_Count); ++i) + { + this->Timestamps[i] = 0; + } + + // Record new Pending timestamp + this->Timestamps[static_cast(State::Pending)] = DateTime::Now().GetTicks(); + + // Clear execution fields + ExecutionLocation.clear(); + CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); + CpuSeconds.store(0.0f, std::memory_order_relaxed); + + // Increment retry count + RetryCount.fetch_add(1, std::memory_order_relaxed); + + // Re-enter the scheduler pipeline + m_OwnerSession->PostUpdate(this); + + return true; +} + +void +RunnerAction::SetActionState(State NewState) +{ + ZEN_ASSERT(NewState < State::_Count); + this->Timestamps[static_cast(NewState)] = DateTime::Now().GetTicks(); + + do + { + if (State CurrentState = m_ActionState.load(); CurrentState == NewState) + { + // No state change + return; + } + else + { + if (NewState <= CurrentState) + { + // Cannot transition to an earlier or same state + return; + } + + if (m_ActionState.compare_exchange_strong(CurrentState, NewState)) + { + // Successful state change + + m_OwnerSession->PostUpdate(this); + + return; + } + } + } while (true); +} + +void +RunnerAction::SetResult(CbPackage&& Result) +{ + m_Result = std::move(Result); +} + +CbPackage& +RunnerAction::GetResult() +{ + ZEN_ASSERT(IsCompleted()); + return m_Result; +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h new file mode 100644 index 000000000..f67414dbb --- /dev/null +++ b/src/zencompute/runners/functionrunner.h @@ -0,0 +1,214 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include + +namespace zen::compute { + +struct SubmitResult +{ + bool IsAccepted = false; + std::string Reason; +}; + +/** Base interface for classes implementing a remote execution "runner" + */ +class FunctionRunner : public RefCounted +{ + FunctionRunner(FunctionRunner&&) = delete; + FunctionRunner& operator=(FunctionRunner&&) = delete; + +public: + FunctionRunner(std::filesystem::path BasePath); + virtual ~FunctionRunner() = 0; + + virtual void Shutdown() = 0; + virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + + [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) = 0; + [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; + [[nodiscard]] virtual bool IsHealthy() = 0; + [[nodiscard]] virtual size_t QueryCapacity(); + [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions); + + // Best-effort cancellation of a specific in-flight action. Returns true if the + // cancellation signal was successfully sent. The action will transition to Cancelled + // asynchronously once the platform-level termination completes. + virtual bool CancelAction(int /*ActionLsn*/) { return false; } + + // Cancel the remote queue corresponding to the given local QueueId. + // Only meaningful for remote runners; local runners ignore this. + virtual void CancelRemoteQueue(int /*QueueId*/) {} + +protected: + std::filesystem::path m_ActionsPath; + bool m_DumpActions = false; + void MaybeDumpAction(int ActionLsn, const CbObject& ActionObject); +}; + +/** Base class for RunnerGroup that operates on generic FunctionRunner references. + * All scheduling, capacity, and lifecycle logic lives here. 
+ */ +class BaseRunnerGroup +{ +public: + size_t QueryCapacity(); + SubmitResult SubmitAction(Ref Action); + std::vector SubmitActions(const std::vector>& Actions); + size_t GetSubmittedActionCount(); + void RegisterWorker(CbPackage Worker); + void Shutdown(); + bool CancelAction(int ActionLsn); + void CancelRemoteQueue(int QueueId); + + size_t GetRunnerCount() + { + return m_RunnersLock.WithSharedLock([this] { return m_Runners.size(); }); + } + +protected: + void AddRunnerInternal(FunctionRunner* Runner); + + RwLock m_RunnersLock; + std::vector> m_Runners; + std::atomic m_NextSubmitIndex{0}; +}; + +/** Typed RunnerGroup that adds type-safe runner addition and predicate-based removal. + */ +template +struct RunnerGroup : public BaseRunnerGroup +{ + void AddRunner(RunnerType* Runner) { AddRunnerInternal(Runner); } + + template + size_t RemoveRunnerIf(Predicate&& Pred) + { + size_t RemovedCount = 0; + m_RunnersLock.WithExclusiveLock([&] { + auto It = m_Runners.begin(); + while (It != m_Runners.end()) + { + if (Pred(static_cast(**It))) + { + (*It)->Shutdown(); + It = m_Runners.erase(It); + ++RemovedCount; + } + else + { + ++It; + } + } + }); + return RemovedCount; + } +}; + +/** + * This represents an action going through different stages of scheduling and execution. + */ +struct RunnerAction : public RefCounted +{ + explicit RunnerAction(ComputeServiceSession* OwnerSession); + ~RunnerAction(); + + int ActionLsn = 0; + int QueueId = 0; + WorkerDesc Worker; + IoHash ActionId; + CbObject ActionObj; + int Priority = 0; + std::string ExecutionLocation; // "local" or remote hostname + + // CPU usage and total CPU time of the running process, sampled periodically by the local runner. + // CpuUsagePercent: -1.0 means not yet sampled; >=0.0 is the most recent reading as a percentage. + // CpuSeconds: total CPU time (user+system) consumed since process start, in seconds. 0.0 if not yet sampled. 
+ std::atomic CpuUsagePercent{-1.0f}; + std::atomic CpuSeconds{0.0f}; + std::atomic RetryCount{0}; + + enum class State + { + New, + Pending, + Submitting, + Running, + Completed, + Failed, + Abandoned, + Cancelled, + _Count + }; + + static const char* ToString(State _) + { + switch (_) + { + case State::New: + return "New"; + case State::Pending: + return "Pending"; + case State::Submitting: + return "Submitting"; + case State::Running: + return "Running"; + case State::Completed: + return "Completed"; + case State::Failed: + return "Failed"; + case State::Abandoned: + return "Abandoned"; + case State::Cancelled: + return "Cancelled"; + default: + return "Unknown"; + } + } + + static State FromString(std::string_view Name, State Default = State::Failed) + { + for (int i = 0; i < static_cast(State::_Count); ++i) + { + if (Name == ToString(static_cast(i))) + { + return static_cast(i); + } + } + return Default; + } + + uint64_t Timestamps[static_cast(State::_Count)] = {}; + + State ActionState() const { return m_ActionState; } + void SetActionState(State NewState); + + bool IsSuccess() const { return ActionState() == State::Completed; } + bool ResetActionStateToPending(); + bool IsCompleted() const + { + return ActionState() == State::Completed || ActionState() == State::Failed || ActionState() == State::Abandoned || + ActionState() == State::Cancelled; + } + + void SetResult(CbPackage&& Result); + CbPackage& GetResult(); + + ComputeServiceSession* GetOwnerSession() const { return m_OwnerSession; } + +private: + std::atomic m_ActionState = State::New; + ComputeServiceSession* m_OwnerSession = nullptr; + CbPackage m_Result; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES \ No newline at end of file diff --git a/src/zencompute/runners/linuxrunner.cpp b/src/zencompute/runners/linuxrunner.cpp new file mode 100644 index 000000000..e79a6c90f --- /dev/null +++ b/src/zencompute/runners/linuxrunner.cpp @@ -0,0 +1,734 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#include "linuxrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include +# include +# include +# include +# include +# include + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. + + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + int MkdirIfNeeded(const char* Path, mode_t Mode) + { + if (mkdir(Path, Mode) != 0 && errno != EEXIST) + { + return -1; + } + return 0; + } + + int BindMountReadOnly(const char* Src, const char* Dst) + { + if (mount(Src, Dst, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + return -1; + } + + // Remount read-only + if (mount(nullptr, Dst, nullptr, MS_REMOUNT | MS_BIND | MS_RDONLY | MS_REC, nullptr) != 0) + { + return -1; + } + + return 0; + } + + // Set up namespace-based sandbox isolation in the child process. + // This is called after fork(), before execve(). All operations must be + // async-signal-safe. 
+ // + // The sandbox layout after pivot_root: + // / -> the sandbox directory (tmpfs-like, was SandboxPath) + // /usr -> bind-mount of host /usr (read-only) + // /lib -> bind-mount of host /lib (read-only) + // /lib64 -> bind-mount of host /lib64 (read-only, optional) + // /etc -> bind-mount of host /etc (read-only) + // /worker -> bind-mount of worker directory (read-only) + // /proc -> proc filesystem + // /dev -> tmpfs with null, zero, urandom + void SetupNamespaceSandbox(const char* SandboxPath, uid_t Uid, gid_t Gid, const char* WorkerPath, int ErrorPipeFd) + { + // 1. Unshare user, mount, and network namespaces + if (unshare(CLONE_NEWUSER | CLONE_NEWNS | CLONE_NEWNET) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "unshare() failed", errno); + } + + // 2. Write UID/GID mappings + // Must deny setgroups first (required by kernel for unprivileged user namespaces) + { + int Fd = open("/proc/self/setgroups", O_WRONLY); + if (Fd >= 0) + { + WriteToFd(Fd, "deny", 4); + close(Fd); + } + // setgroups file may not exist on older kernels; not fatal + } + + { + // uid_map: map our UID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast(Uid)); + + int Fd = open("/proc/self/uid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open uid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast(Len)); + close(Fd); + } + + { + // gid_map: map our GID to 0 inside the namespace + char Buf[64]; + int Len = snprintf(Buf, sizeof(Buf), "0 %u 1\n", static_cast(Gid)); + + int Fd = open("/proc/self/gid_map", O_WRONLY); + if (Fd < 0) + { + WriteErrorAndExit(ErrorPipeFd, "open gid_map failed", errno); + } + WriteToFd(Fd, Buf, static_cast(Len)); + close(Fd); + } + + // 3. Privatize the entire mount tree so our mounts don't propagate + if (mount(nullptr, "/", nullptr, MS_REC | MS_PRIVATE, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount MS_PRIVATE failed", errno); + } + + // 4. 
Create mount points inside the sandbox and bind-mount system directories + + // Helper macro-like pattern for building paths inside sandbox + // We use stack buffers since we can't allocate heap memory safely + char MountPoint[4096]; + + auto BuildPath = [&](const char* Suffix) -> const char* { + snprintf(MountPoint, sizeof(MountPoint), "%s/%s", SandboxPath, Suffix); + return MountPoint; + }; + + // /usr (required) + if (MkdirIfNeeded(BuildPath("usr"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/usr failed", errno); + } + if (BindMountReadOnly("/usr", BuildPath("usr")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /usr failed", errno); + } + + // /lib (required) + if (MkdirIfNeeded(BuildPath("lib"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/lib failed", errno); + } + if (BindMountReadOnly("/lib", BuildPath("lib")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /lib failed", errno); + } + + // /lib64 (optional — not all distros have it) + { + struct stat St; + if (stat("/lib64", &St) == 0 && S_ISDIR(St.st_mode)) + { + if (MkdirIfNeeded(BuildPath("lib64"), 0755) == 0) + { + BindMountReadOnly("/lib64", BuildPath("lib64")); + // Failure is non-fatal for lib64 + } + } + } + + // /etc (required — for resolv.conf, ld.so.cache, etc.) + if (MkdirIfNeeded(BuildPath("etc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/etc failed", errno); + } + if (BindMountReadOnly("/etc", BuildPath("etc")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount /etc failed", errno); + } + + // /worker — bind-mount worker directory (contains the executable) + if (MkdirIfNeeded(BuildPath("worker"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/worker failed", errno); + } + if (BindMountReadOnly(WorkerPath, BuildPath("worker")) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount worker dir failed", errno); + } + + // 5. 
Mount /proc inside sandbox + if (MkdirIfNeeded(BuildPath("proc"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/proc failed", errno); + } + if (mount("proc", BuildPath("proc"), "proc", MS_NOSUID | MS_NOEXEC | MS_NODEV, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount /proc failed", errno); + } + + // 6. Mount tmpfs /dev and bind-mount essential device nodes + if (MkdirIfNeeded(BuildPath("dev"), 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir sandbox/dev failed", errno); + } + if (mount("tmpfs", BuildPath("dev"), "tmpfs", MS_NOSUID | MS_NOEXEC, "size=64k,mode=0755") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mount tmpfs /dev failed", errno); + } + + // Bind-mount /dev/null, /dev/zero, /dev/urandom + { + char DevSrc[64]; + char DevDst[4096]; + + auto BindDev = [&](const char* Name) { + snprintf(DevSrc, sizeof(DevSrc), "/dev/%s", Name); + snprintf(DevDst, sizeof(DevDst), "%s/dev/%s", SandboxPath, Name); + + // Create the file to mount over + int Fd = open(DevDst, O_WRONLY | O_CREAT, 0666); + if (Fd >= 0) + { + close(Fd); + } + mount(DevSrc, DevDst, nullptr, MS_BIND, nullptr); + // Non-fatal if individual devices fail + }; + + BindDev("null"); + BindDev("zero"); + BindDev("urandom"); + } + + // 7. pivot_root to sandbox + // pivot_root requires the new root and put_old to be mount points. + // Bind-mount sandbox onto itself to make it a mount point. + if (mount(SandboxPath, SandboxPath, nullptr, MS_BIND | MS_REC, nullptr) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "bind mount sandbox onto itself failed", errno); + } + + // Create .pivot_old inside sandbox + char PivotOld[4096]; + snprintf(PivotOld, sizeof(PivotOld), "%s/.pivot_old", SandboxPath); + if (MkdirIfNeeded(PivotOld, 0755) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "mkdir .pivot_old failed", errno); + } + + if (syscall(SYS_pivot_root, SandboxPath, PivotOld) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "pivot_root failed", errno); + } + + // 8. Now inside new root. 
Clean up old root. + if (chdir("/") != 0) + { + WriteErrorAndExit(ErrorPipeFd, "chdir / failed", errno); + } + + if (umount2("/.pivot_old", MNT_DETACH) != 0) + { + WriteErrorAndExit(ErrorPipeFd, "umount2 .pivot_old failed", errno); + } + + rmdir("/.pivot_old"); + } + +} // anonymous namespace + +LinuxProcessRunner::LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("namespace sandboxing enabled for child processes"); + } +} + +SubmitResult +LinuxProcessRunner::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SubmitAction"); + std::optional Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: -Build=build.action + // Pre-compute all path strings before fork() for async-signal-safety. 
+ + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::string ExePathStr; + std::string SandboxedExePathStr; + + if (m_Sandboxed) + { + // After pivot_root, the worker dir is at /worker inside the new root + std::filesystem::path SandboxedExePath = std::filesystem::path("/worker") / std::filesystem::path(ExecPath); + SandboxedExePathStr = SandboxedExePath.string(); + // We still need the real path for logging + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + else + { + ExePathStr = (Prepared->WorkerPath / std::filesystem::path(ExecPath)).string(); + } + + std::string BuildArg = "-Build=build.action"; + + // argv[0] should be the path the child will see + const std::string& ChildExePath = m_Sandboxed ? SandboxedExePathStr : ExePathStr; + + std::vector ArgV; + ArgV.push_back(const_cast(ChildExePath.data())); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: get uid/gid for namespace mapping, create error pipe + uid_t CurrentUid = 0; + gid_t CurrentGid = 0; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + CurrentUid = getuid(); + CurrentGid = getgid(); + + if (pipe2(ErrorPipe, O_CLOEXEC) != 0) + { + throw zen::runtime_error("pipe2() for sandbox error pipe failed: {}", strerror(errno)); + } + } + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + SetupNamespaceSandbox(SandboxPathStr.c_str(), CurrentUid, CurrentGid, WorkerPathStr.c_str(), 
ErrorPipe[1]); + + // After pivot_root, CWD is "/" which is the sandbox root. + // execve with the sandboxed path. + execve(SandboxedExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. + char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast(static_cast(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +LinuxProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::SweepRunningActions"); + std::vector> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = 
begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref Running = It->second; + + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +LinuxProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. 
+ + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +LinuxProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("LinuxProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast(reinterpret_cast(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +static uint64_t +ReadProcStatCpuTicks(pid_t Pid) +{ + char Path[64]; + snprintf(Path, sizeof(Path), "/proc/%d/stat", static_cast(Pid)); + + char Buf[256]; + int Fd = open(Path, O_RDONLY); + if (Fd < 0) + { + return 0; + } + + ssize_t Len = read(Fd, Buf, sizeof(Buf) - 1); + close(Fd); + + if (Len <= 0) + { + return 0; + } + + Buf[Len] = '\0'; + + // Skip past "pid (name) " — find last ')' to handle names containing spaces or parens + const char* P = strrchr(Buf, ')'); + if (!P) + { + return 0; + } + + P += 2; // skip ') ' + + // Remaining fields (space-separated, 0-indexed from here): + // 0:state 1:ppid 2:pgrp 3:session 4:tty_nr 5:tty_pgrp 6:flags + // 7:minflt 8:cminflt 9:majflt 10:cmajflt 11:utime 12:stime + unsigned long UTime = 0; + unsigned long STime = 0; + sscanf(P, "%*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &UTime, &STime); + return UTime + STime; +} + +void +LinuxProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + static const long ClkTck = sysconf(_SC_CLK_TCK); + + const pid_t Pid = static_cast(reinterpret_cast(Running.ProcessHandle)); + + const uint64_t NowTicks = GetHifreqTimerValue(); + const uint64_t CurrentOsTicks = ReadProcStatCpuTicks(Pid); + + if (CurrentOsTicks == 0) + { + // Process gone or /proc entry unreadable — record timestamp without updating usage + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = 0; + return; + } + + // Cumulative CPU seconds (absolute, available from first sample) + Running.Action->CpuSeconds.store(static_cast(static_cast(CurrentOsTicks) / ClkTck), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + const float CpuPct = static_cast(static_cast(DeltaOsTicks) * 1000.0 / ClkTck / ElapsedMs * 100.0); + 
Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/linuxrunner.h b/src/zencompute/runners/linuxrunner.h new file mode 100644 index 000000000..266de366b --- /dev/null +++ b/src/zencompute/runners/linuxrunner.h @@ -0,0 +1,44 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +namespace zen::compute { + +/** Native Linux process runner for executing Linux worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using Linux namespaces: + user, mount, and network namespaces are unshared so the child has no network + access and can only see the sandbox directory (with system libraries bind-mounted + read-only). This requires no special privileges thanks to user namespaces. 
+ */ +class LinuxProcessRunner : public LocalProcessRunner +{ +public: + LinuxProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp new file mode 100644 index 000000000..7aaefb06e --- /dev/null +++ b/src/zencompute/runners/localrunner.cpp @@ -0,0 +1,674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +namespace zen::compute { + +using namespace std::literals; + +LocalProcessRunner::LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("local_exec")) +, m_ChunkResolver(Resolver) +, m_WorkerPath(std::filesystem::weakly_canonical(BaseDir / "workers")) +, m_SandboxPath(std::filesystem::weakly_canonical(BaseDir / "scratch")) +, m_DeferredDeleter(Deleter) +, m_WorkerPool(WorkerPool) +{ + SystemMetrics Sm = GetSystemMetricsForReporting(); + + m_MaxRunningActions = Sm.LogicalProcessorCount * 2; + + if (MaxConcurrentActions > 0) + { + m_MaxRunningActions = MaxConcurrentActions; + } + + ZEN_INFO("Max concurrent action count: {}", m_MaxRunningActions); + + bool DidCleanup = false; + + if 
(std::filesystem::is_directory(m_ActionsPath)) + { + ZEN_INFO("Cleaning '{}'", m_ActionsPath); + + std::error_code Ec; + CleanDirectory(m_ActionsPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_ActionsPath, Ec.message()); + } + + DidCleanup = true; + } + + if (std::filesystem::is_directory(m_SandboxPath)) + { + ZEN_INFO("Cleaning '{}'", m_SandboxPath); + std::error_code Ec; + CleanDirectory(m_SandboxPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_SandboxPath, Ec.message()); + } + + DidCleanup = true; + } + + // We clean out all workers on startup since we can't know they are good. They could be bad + // due to tampering, malware (which I also mean to include AV and antimalware software) or + // other processes we have no control over + if (std::filesystem::is_directory(m_WorkerPath)) + { + ZEN_INFO("Cleaning '{}'", m_WorkerPath); + std::error_code Ec; + CleanDirectory(m_WorkerPath, /* ForceRemoveReadOnlyFiles */ true, Ec); + + if (Ec) + { + ZEN_WARN("Unable to clean '{}': {}", m_WorkerPath, Ec.message()); + } + + DidCleanup = true; + } + + if (DidCleanup) + { + ZEN_INFO("Cleanup complete"); + } + + m_MonitorThread = std::thread{&LocalProcessRunner::MonitorThreadFunction, this}; + +# if ZEN_PLATFORM_WINDOWS + // Suppress any error dialogs caused by missing dependencies + UINT OldMode = ::SetErrorMode(0); + ::SetErrorMode(OldMode | SEM_FAILCRITICALERRORS); +# endif + + m_AcceptNewActions = true; +} + +LocalProcessRunner::~LocalProcessRunner() +{ + try + { + Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception during local process runner shutdown: {}", Ex.what()); + } +} + +void +LocalProcessRunner::Shutdown() +{ + ZEN_TRACE_CPU("LocalProcessRunner::Shutdown"); + m_AcceptNewActions = false; + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } + + 
CancelRunningActions(); +} + +std::filesystem::path +LocalProcessRunner::CreateNewSandbox() +{ + ZEN_TRACE_CPU("LocalProcessRunner::CreateNewSandbox"); + std::string UniqueId = std::to_string(++m_SandboxCounter); + std::filesystem::path Path = m_SandboxPath / UniqueId; + zen::CreateDirectories(Path); + + return Path; +} + +void +LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); + if (m_DumpActions) + { + CbObject WorkerDescriptor = WorkerPackage.GetObject(); + const IoHash& WorkerId = WorkerPackage.GetObjectHash(); + + std::string UniqueId = fmt::format("worker_{}"sv, WorkerId); + std::filesystem::path Path = m_ActionsPath / UniqueId; + + zen::WriteFile(Path / "worker.ucb", WorkerDescriptor.GetBuffer().AsIoBuffer()); + + ManifestWorker(WorkerPackage, Path / "tree", [&](const IoHash& Cid, CompressedBuffer& ChunkBuffer) { + std::filesystem::path ChunkPath = Path / "chunks" / Cid.ToHexString(); + zen::WriteFile(ChunkPath, ChunkBuffer.GetCompressed()); + }); + + ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); + } +} + +size_t +LocalProcessRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return 0; + } + + const size_t InFlightCount = m_RunningMap.size() + m_SubmittingCount.load(std::memory_order_relaxed); + + if (const size_t MaxRunningActions = m_MaxRunningActions; InFlightCount >= MaxRunningActions) + { + return 0; + } + else + { + return MaxRunningActions - InFlightCount; + } +} + +std::vector +LocalProcessRunner::SubmitActions(const std::vector>& Actions) +{ + if (Actions.size() <= 1) + { + std::vector Results; + + for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For nontrivial batches, check capacity upfront and accept what fits. 
+ // Accepted actions are transitioned to Submitting and dispatched to the + // worker pool as fire-and-forget, so SubmitActions returns immediately + // and the scheduler thread is free to handle completions and updates. + + size_t Available = QueryCapacity(); + + std::vector Results(Actions.size()); + + size_t AcceptCount = std::min(Available, Actions.size()); + + for (size_t i = 0; i < AcceptCount; ++i) + { + const Ref& Action = Actions[i]; + + Action->SetActionState(RunnerAction::State::Submitting); + m_SubmittingCount.fetch_add(1, std::memory_order_relaxed); + + Results[i] = SubmitResult{.IsAccepted = true}; + + m_WorkerPool.ScheduleWork( + [this, Action]() { + auto CountGuard = MakeGuard([this] { m_SubmittingCount.fetch_sub(1, std::memory_order_relaxed); }); + + SubmitResult Result = SubmitAction(Action); + + if (!Result.IsAccepted) + { + // This might require another state? We should + // distinguish between outright rejections (e.g. invalid action) + // and transient failures (e.g. 
failed to launch process) which might + // be retried by the scheduler, but for now just fail the action + Action->SetActionState(RunnerAction::State::Failed); + } + }, + WorkerThreadPool::EMode::EnableBacklog); + } + + for (size_t i = AcceptCount; i < Actions.size(); ++i) + { + Results[i] = SubmitResult{.IsAccepted = false}; + } + + return Results; +} + +std::optional +LocalProcessRunner::PrepareActionSubmission(Ref Action) +{ + ZEN_TRACE_CPU("LocalProcessRunner::PrepareActionSubmission"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + + if (!m_AcceptNewActions) + { + return std::nullopt; + } + + if (m_RunningMap.size() >= size_t(m_MaxRunningActions)) + { + return std::nullopt; + } + } + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + + MaybeDumpAction(ActionLsn, ActionObj); + + std::filesystem::path SandboxPath = CreateNewSandbox(); + + // Ensure the sandbox directory is cleaned up if any subsequent step throws + auto SandboxGuard = MakeGuard([&] { m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(SandboxPath)); }); + + CbPackage WorkerPackage = Action->Worker.Descriptor; + + std::filesystem::path WorkerPath = ManifestWorker(Action->Worker); + + // Write out action + + zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + + // Manifest inputs in sandbox + + ActionObj.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + std::filesystem::path FilePath{SandboxPath / "Inputs"sv / Cid.ToHexString()}; + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(Cid); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("input CID chunk '{}' missing", Cid)); + } + + 
zen::WriteFile(FilePath, DataBuffer); + }); + + Action->ExecutionLocation = "local"; + + SandboxGuard.Dismiss(); + + return PreparedAction{ + .ActionLsn = ActionLsn, + .SandboxPath = std::move(SandboxPath), + .WorkerPath = std::move(WorkerPath), + .WorkerPackage = std::move(WorkerPackage), + }; +} + +SubmitResult +LocalProcessRunner::SubmitAction(Ref Action) +{ + // Base class is not directly usable — platform subclasses override this + ZEN_UNUSED(Action); + return SubmitResult{.IsAccepted = false}; +} + +size_t +LocalProcessRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RunningMap.size(); +} + +std::filesystem::path +LocalProcessRunner::ManifestWorker(const WorkerDesc& Worker) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ManifestWorker"); + RwLock::SharedLockScope _(m_WorkerLock); + + std::filesystem::path WorkerDir = m_WorkerPath / fmt::format("runner_{}", Worker.WorkerId); + + if (!std::filesystem::exists(WorkerDir)) + { + _.ReleaseNow(); + + RwLock::ExclusiveLockScope $(m_WorkerLock); + + if (!std::filesystem::exists(WorkerDir)) + { + ManifestWorker(Worker.Descriptor, WorkerDir, [](const IoHash&, CompressedBuffer&) {}); + } + } + + return WorkerDir; +} + +void +LocalProcessRunner::DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function& ChunkReferenceCallback) +{ + std::string_view Name = FileEntry["name"sv].AsString(); + const IoHash ChunkHash = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + CompressedBuffer Compressed; + + if (const CbAttachment* Attachment = FromPackage.FindAttachment(ChunkHash)) + { + Compressed = Attachment->AsCompressedBinary(); + } + else + { + IoBuffer DataBuffer = m_ChunkResolver.FindChunkByCid(ChunkHash); + + if (!DataBuffer) + { + throw std::runtime_error(fmt::format("worker chunk '{}' missing", ChunkHash)); + } + + uint64_t DataRawSize = 0; + IoHash DataRawHash; 
+ Compressed = CompressedBuffer::FromCompressed(SharedBuffer{DataBuffer}, DataRawHash, DataRawSize); + + if (DataRawSize != Size) + { + throw std::runtime_error( + fmt::format("worker chunk '{}' size: {}, action spec expected {}", ChunkHash, DataBuffer.Size(), Size)); + } + } + + ChunkReferenceCallback(ChunkHash, Compressed); + + std::filesystem::path FilePath{SandboxRootPath / std::filesystem::path(Name).make_preferred()}; + + // Validate the resolved path stays within the sandbox to prevent directory traversal + // via malicious names like "../../etc/evil" + // + // This might be worth revisiting to frontload the validation and eliminate some memory + // allocations in the future. + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxRootPath); + std::filesystem::path CanonicalFile = std::filesystem::weakly_canonical(FilePath); + std::string RootStr = CanonicalRoot.string(); + std::string FileStr = CanonicalFile.string(); + + if (FileStr.size() < RootStr.size() || FileStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: '{}' escapes sandbox root '{}'", Name, SandboxRootPath); + } + } + + SharedBuffer Decompressed = Compressed.Decompress(); + zen::WriteFile(FilePath, Decompressed.AsIoBuffer()); +} + +void +LocalProcessRunner::ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function&& ChunkReferenceCallback) +{ + CbObject WorkerDescription = WorkerPackage.GetObject(); + + // Manifest worker in Sandbox + + for (auto& It : WorkerDescription["executables"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); +# if !ZEN_PLATFORM_WINDOWS + std::string_view ExeName = It.AsObjectView()["name"sv].AsString(); + std::filesystem::path ExePath{SandboxPath / std::filesystem::path(ExeName).make_preferred()}; + std::filesystem::permissions( + ExePath, + std::filesystem::perms::owner_exec | 
std::filesystem::perms::group_exec | std::filesystem::perms::others_exec, + std::filesystem::perm_options::add); +# endif + } + + for (auto& It : WorkerDescription["dirs"sv]) + { + std::string_view Name = It.AsString(); + std::filesystem::path DirPath{SandboxPath / std::filesystem::path(Name).make_preferred()}; + + // Validate dir path stays within sandbox + { + std::filesystem::path CanonicalRoot = std::filesystem::weakly_canonical(SandboxPath); + std::filesystem::path CanonicalDir = std::filesystem::weakly_canonical(DirPath); + std::string RootStr = CanonicalRoot.string(); + std::string DirStr = CanonicalDir.string(); + + if (DirStr.size() < RootStr.size() || DirStr.compare(0, RootStr.size(), RootStr) != 0) + { + throw zen::runtime_error("path traversal detected: dir '{}' escapes sandbox root '{}'", Name, SandboxPath); + } + } + + zen::CreateDirectories(DirPath); + } + + for (auto& It : WorkerDescription["files"sv]) + { + DecompressAttachmentToFile(WorkerPackage, It.AsObjectView(), SandboxPath, ChunkReferenceCallback); + } + + WriteFile(SandboxPath / "worker.zcb", WorkerDescription.GetBuffer().AsIoBuffer()); +} + +CbPackage +LocalProcessRunner::GatherActionOutputs(std::filesystem::path SandboxPath) +{ + ZEN_TRACE_CPU("LocalProcessRunner::GatherActionOutputs"); + std::filesystem::path OutputFile = SandboxPath / "build.output"; + FileContents OutputData = zen::ReadFile(OutputFile); + + if (OutputData.ErrorCode) + { + throw std::system_error(OutputData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputFile)); + } + + CbPackage OutputPackage; + CbObject Output = zen::LoadCompactBinaryObject(OutputData.Flatten()); + + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalRawAttachmentBytes = 0; + + Output.IterateAttachments([&](CbFieldView Field) { + IoHash Hash = Field.AsHash(); + std::filesystem::path OutputPath{SandboxPath / "Outputs" / Hash.ToHexString()}; + FileContents ChunkData = zen::ReadFile(OutputPath); + + if (ChunkData.ErrorCode) + { + 
throw std::system_error(ChunkData.ErrorCode, fmt::format("Failed to read build output file '{}'", OutputPath)); + } + + uint64_t ChunkDataRawSize = 0; + IoHash ChunkDataHash; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(ChunkData.Flatten()), ChunkDataHash, ChunkDataRawSize); + + if (!AttachmentBuffer) + { + throw std::runtime_error("Invalid output encountered (not valid CompressedBuffer format)"); + } + + TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + TotalRawAttachmentBytes += ChunkDataRawSize; + + CbAttachment Attachment(std::move(AttachmentBuffer), ChunkDataHash); + OutputPackage.AddAttachment(Attachment); + }); + + OutputPackage.SetObject(Output); + + ZEN_DEBUG("Action completed with {} attachments ({} compressed, {} uncompressed)", + OutputPackage.GetAttachments().size(), + NiceBytes(TotalAttachmentBytes), + NiceBytes(TotalRawAttachmentBytes)); + + return OutputPackage; +} + +void +LocalProcessRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("LocalProcessRunner_Monitor"); + + auto _ = MakeGuard([&] { ZEN_INFO("monitor thread exiting"); }); + + do + { + // On Windows it's possible to wait on process handles, so we wait for either a process to exit + // or for the monitor event to be signaled (which indicates we should check for cancellation + // or shutdown). This could be further improved by using a completion port and registering process + // handles with it, but this is a reasonable first implementation given that we shouldn't be dealing + // with an enormous number of concurrent processes. + // + // On other platforms we just wait on the monitor event and poll for process exits at intervals. 
+# if ZEN_PLATFORM_WINDOWS + auto WaitOnce = [&] { + HANDLE WaitHandles[MAXIMUM_WAIT_OBJECTS]; + + uint32_t NumHandles = 0; + + WaitHandles[NumHandles++] = m_MonitorThreadEvent.GetWindowsHandle(); + + m_RunningLock.WithSharedLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd && NumHandles < MAXIMUM_WAIT_OBJECTS; ++It) + { + Ref Action = It->second; + + WaitHandles[NumHandles++] = Action->ProcessHandle; + } + }); + + DWORD WaitResult = WaitForMultipleObjects(NumHandles, WaitHandles, FALSE, 1000); + + // return true if a handle was signaled + return (WaitResult <= NumHandles); + }; +# else + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(1000); }; +# endif + + while (!WaitOnce()) + { + if (m_MonitorThreadEnabled == false) + { + return; + } + + SweepRunningActions(); + SampleRunningProcessCpu(); + } + + // Signal received + + SweepRunningActions(); + SampleRunningProcessCpu(); + } while (m_MonitorThreadEnabled); +} + +void +LocalProcessRunner::CancelRunningActions() +{ + // Base class is not directly usable — platform subclasses override this +} + +void +LocalProcessRunner::SampleRunningProcessCpu() +{ + static constexpr uint64_t kSampleIntervalMs = 5'000; + + m_RunningLock.WithSharedLock([&] { + const uint64_t Now = GetHifreqTimerValue(); + for (auto& [Lsn, Running] : m_RunningMap) + { + const bool NeverSampled = Running->LastCpuSampleTicks == 0; + const bool IntervalElapsed = Stopwatch::GetElapsedTimeMs(Now - Running->LastCpuSampleTicks) >= kSampleIntervalMs; + if (NeverSampled || IntervalElapsed) + { + SampleProcessCpu(*Running); + } + } + }); +} + +void +LocalProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("LocalProcessRunner::SweepRunningActions"); +} + +void +LocalProcessRunner::ProcessCompletedActions(std::vector>& CompletedActions) +{ + ZEN_TRACE_CPU("LocalProcessRunner::ProcessCompletedActions"); + // Shared post-processing: gather outputs, set state, clean sandbox. 
+ // Note that this must be called without holding any local locks + // otherwise we may end up with deadlocks. + + for (Ref Running : CompletedActions) + { + const int ActionLsn = Running->Action->ActionLsn; + + if (Running->ExitCode == 0) + { + try + { + // Gather outputs + + CbPackage OutputPackage = GatherActionOutputs(Running->SandboxPath); + + Running->Action->SetResult(std::move(OutputPackage)); + Running->Action->SetActionState(RunnerAction::State::Completed); + + // Enqueue sandbox for deferred background deletion, giving + // file handles time to close before we attempt removal. + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + + // Success -- continue with next iteration of the loop + continue; + } + catch (std::exception& Ex) + { + ZEN_ERROR("Encountered failure while gathering outputs for action lsn {}, '{}'", ActionLsn, Ex.what()); + } + } + + // Failed - clean up the sandbox in the background. + + m_DeferredDeleter.Enqueue(ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h new file mode 100644 index 000000000..7493e980b --- /dev/null +++ b/src/zencompute/runners/localrunner.h @@ -0,0 +1,138 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include +# include +# include +# include +# include + +# include "deferreddeleter.h" + +# include + +# include +# include +# include +# include + +namespace zen { +class CbPackage; +} + +namespace zen::compute { + +/** Direct process spawner + + This runner simply sets up a directory structure for each job and + creates a process to perform the computation in it. It is not very + efficient and is intended mostly for testing. 
+ + */ + +class LocalProcessRunner : public FunctionRunner +{ + LocalProcessRunner(LocalProcessRunner&&) = delete; + LocalProcessRunner& operator=(LocalProcessRunner&&) = delete; + +public: + LocalProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + int32_t MaxConcurrentActions = 0); + ~LocalProcessRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; + [[nodiscard]] virtual bool IsHealthy() override { return true; } + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] virtual std::vector SubmitActions(const std::vector>& Actions) override; + +protected: + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + + struct RunningAction : public RefCounted + { + Ref Action; + void* ProcessHandle = nullptr; + int ExitCode = 0; + std::filesystem::path SandboxPath; + + // State for periodic CPU usage sampling + uint64_t LastCpuSampleTicks = 0; // hifreq timer value at last sample + uint64_t LastCpuOsTicks = 0; // OS CPU ticks (platform-specific units) at last sample + }; + + std::atomic_bool m_AcceptNewActions; + ChunkResolver& m_ChunkResolver; + RwLock m_WorkerLock; + std::filesystem::path m_WorkerPath; + std::atomic m_SandboxCounter = 0; + std::filesystem::path m_SandboxPath; + int32_t m_MaxRunningActions = 64; // arbitrary limit for testing + + // if used in conjuction with m_ResultsLock, this lock must be taken *after* + // m_ResultsLock to avoid deadlocks + RwLock m_RunningLock; + std::unordered_map> m_RunningMap; + + std::atomic m_SubmittingCount = 0; + DeferredDirectoryDeleter& m_DeferredDeleter; + WorkerThreadPool& m_WorkerPool; + + std::thread m_MonitorThread; + std::atomic m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void 
MonitorThreadFunction(); + virtual void SweepRunningActions(); + virtual void CancelRunningActions(); + + // Sample CPU usage for all currently running processes (throttled per-action). + void SampleRunningProcessCpu(); + + // Override in platform runners to sample one process. Called under a shared RunningLock. + virtual void SampleProcessCpu(RunningAction& /*Running*/) {} + + // Shared preamble for SubmitAction: capacity check, sandbox creation, + // worker manifesting, action writing, input manifesting. + struct PreparedAction + { + int32_t ActionLsn; + std::filesystem::path SandboxPath; + std::filesystem::path WorkerPath; + CbPackage WorkerPackage; + }; + std::optional PrepareActionSubmission(Ref Action); + + // Shared post-processing for SweepRunningActions: gather outputs, + // set state, clean sandbox. + void ProcessCompletedActions(std::vector>& CompletedActions); + + std::filesystem::path CreateNewSandbox(); + void ManifestWorker(const CbPackage& WorkerPackage, + const std::filesystem::path& SandboxPath, + std::function&& ChunkReferenceCallback); + std::filesystem::path ManifestWorker(const WorkerDesc& Worker); + CbPackage GatherActionOutputs(std::filesystem::path SandboxPath); + + void DecompressAttachmentToFile(const CbPackage& FromPackage, + CbObjectView FileEntry, + const std::filesystem::path& SandboxRootPath, + std::function& ChunkReferenceCallback); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.cpp b/src/zencompute/runners/macrunner.cpp new file mode 100644 index 000000000..5cec90699 --- /dev/null +++ b/src/zencompute/runners/macrunner.cpp @@ -0,0 +1,491 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "macrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include +# include +# include +# include + +namespace zen::compute { + +using namespace std::literals; + +namespace { + + // All helper functions in this namespace are async-signal-safe (safe to call + // between fork() and execve()). They use only raw syscalls and avoid any + // heap allocation, stdio, or other non-AS-safe operations. + + void WriteToFd(int Fd, const char* Buf, size_t Len) + { + while (Len > 0) + { + ssize_t Written = write(Fd, Buf, Len); + if (Written <= 0) + { + break; + } + Buf += Written; + Len -= static_cast(Written); + } + } + + [[noreturn]] void WriteErrorAndExit(int ErrorPipeFd, const char* Msg, int Errno) + { + // Write the message prefix + size_t MsgLen = 0; + for (const char* P = Msg; *P; ++P) + { + ++MsgLen; + } + WriteToFd(ErrorPipeFd, Msg, MsgLen); + + // Append ": " and the errno string if non-zero + if (Errno != 0) + { + WriteToFd(ErrorPipeFd, ": ", 2); + const char* ErrStr = strerror(Errno); + size_t ErrLen = 0; + for (const char* P = ErrStr; *P; ++P) + { + ++ErrLen; + } + WriteToFd(ErrorPipeFd, ErrStr, ErrLen); + } + + _exit(127); + } + + // Build a Seatbelt profile string that denies everything by default and + // allows only the minimum needed for the worker to execute: process ops, + // system library reads, worker directory (read-only), and sandbox directory + // (read-write). Network access is denied implicitly by the deny-default policy. 
+ std::string BuildSandboxProfile(const std::string& SandboxPath, const std::string& WorkerPath) + { + std::string Profile; + Profile.reserve(1024); + + Profile += "(version 1)\n"; + Profile += "(deny default)\n"; + Profile += "(allow process*)\n"; + Profile += "(allow sysctl-read)\n"; + Profile += "(allow file-read-metadata)\n"; + + // System library paths needed for dynamic linker and runtime + Profile += "(allow file-read* (subpath \"/usr\"))\n"; + Profile += "(allow file-read* (subpath \"/System\"))\n"; + Profile += "(allow file-read* (subpath \"/Library\"))\n"; + Profile += "(allow file-read* (subpath \"/dev\"))\n"; + Profile += "(allow file-read* (subpath \"/private/var/db/dyld\"))\n"; + Profile += "(allow file-read* (subpath \"/etc\"))\n"; + + // Worker directory: read-only + Profile += "(allow file-read* (subpath \""; + Profile += WorkerPath; + Profile += "\"))\n"; + + // Sandbox directory: read+write + Profile += "(allow file-read* file-write* (subpath \""; + Profile += SandboxPath; + Profile += "\"))\n"; + + return Profile; + } + +} // anonymous namespace + +MacProcessRunner::MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. 
+ struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); + + if (m_Sandboxed) + { + ZEN_INFO("Seatbelt sandboxing enabled for child processes"); + } +} + +SubmitResult +MacProcessRunner::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("MacProcessRunner::SubmitAction"); + std::optional Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string BuildArg = "-Build=build.action"; + + std::vector ArgV; + ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing: {} {} (sandboxed={})", ExePathStr, BuildArg, m_Sandboxed); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + std::string WorkerPathStr = Prepared->WorkerPath.string(); + + // Pre-fork: build sandbox profile and create error pipe + std::string SandboxProfile; + int ErrorPipe[2] = {-1, -1}; + + if (m_Sandboxed) + { + SandboxProfile = BuildSandboxProfile(SandboxPathStr, WorkerPathStr); + + if (pipe(ErrorPipe) != 0) + { + throw zen::runtime_error("pipe() for sandbox error pipe failed: {}", strerror(errno)); + } + fcntl(ErrorPipe[0], F_SETFD, FD_CLOEXEC); + fcntl(ErrorPipe[1], F_SETFD, FD_CLOEXEC); + } + + pid_t ChildPid = fork(); + + if (ChildPid < 
0) + { + int SavedErrno = errno; + if (m_Sandboxed) + { + close(ErrorPipe[0]); + close(ErrorPipe[1]); + } + throw zen::runtime_error("fork() failed: {}", strerror(SavedErrno)); + } + + if (ChildPid == 0) + { + // Child process + + if (m_Sandboxed) + { + // Close read end of error pipe — child only writes + close(ErrorPipe[0]); + + // Apply Seatbelt sandbox profile + char* ErrorBuf = nullptr; + if (sandbox_init(SandboxProfile.c_str(), 0, &ErrorBuf) != 0) + { + // sandbox_init failed — write error to pipe and exit + if (ErrorBuf) + { + WriteErrorAndExit(ErrorPipe[1], ErrorBuf, 0); + // WriteErrorAndExit does not return, but sandbox_free_error + // is not needed since we _exit + } + WriteErrorAndExit(ErrorPipe[1], "sandbox_init failed", errno); + } + if (ErrorBuf) + { + sandbox_free_error(ErrorBuf); + } + + if (chdir(SandboxPathStr.c_str()) != 0) + { + WriteErrorAndExit(ErrorPipe[1], "chdir to sandbox failed", errno); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + + WriteErrorAndExit(ErrorPipe[1], "execve failed", errno); + } + else + { + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(ExePathStr.c_str(), ArgV.data(), Envp.data()); + _exit(127); + } + } + + // Parent process + + if (m_Sandboxed) + { + // Close write end of error pipe — parent only reads + close(ErrorPipe[1]); + + // Read from error pipe. If execve succeeded, pipe was closed by O_CLOEXEC + // and read returns 0. If setup failed, child wrote an error message. 
+ char ErrBuf[512]; + ssize_t BytesRead = read(ErrorPipe[0], ErrBuf, sizeof(ErrBuf) - 1); + close(ErrorPipe[0]); + + if (BytesRead > 0) + { + // Sandbox setup or execve failed + ErrBuf[BytesRead] = '\0'; + + // Reap the child (it called _exit(127)) + waitpid(ChildPid, nullptr, 0); + + // Clean up the sandbox in the background + m_DeferredDeleter.Enqueue(Action->ActionLsn, std::move(Prepared->SandboxPath)); + + ZEN_ERROR("Sandbox setup failed for action {}: {}", Action->ActionLsn, ErrBuf); + + Action->SetActionState(RunnerAction::State::Failed); + return SubmitResult{.IsAccepted = false}; + } + } + + // Store child pid as void* (same convention as zencore/process.cpp) + + Ref NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast(static_cast(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +MacProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::SweepRunningActions"); + std::vector> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref Running = It->second; + + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void 
+MacProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +MacProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("MacProcessRunner::CancelAction"); + + // Hold the shared lock while sending the signal to prevent the sweep thread + // from reaping the PID (via waitpid) between our lookup and kill(). 
Without + // the lock held, the PID could be recycled by the kernel and we'd signal an + // unrelated process. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref Target = It->second; + if (!Target->ProcessHandle) + { + return; + } + + pid_t Pid = static_cast(reinterpret_cast(Target->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("CancelAction: kill(SIGTERM) for LSN {} (pid {}) failed: {}", ActionLsn, Pid, strerror(errno)); + return; + } + + ZEN_DEBUG("CancelAction: sent SIGTERM to LSN {} (pid {})", ActionLsn, Pid); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. + return Sent; +} + +void +MacProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + const pid_t Pid = static_cast(reinterpret_cast(Running.ProcessHandle)); + + struct proc_taskinfo Info; + if (proc_pidinfo(Pid, PROC_PIDTASKINFO, 0, &Info, sizeof(Info)) <= 0) + { + return; + } + + // pti_total_user and pti_total_system are in nanoseconds + const uint64_t CurrentOsTicks = Info.pti_total_user + Info.pti_total_system; + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): ns → seconds + Running.Action->CpuSeconds.store(static_cast(static_cast(CurrentOsTicks) / 1'000'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // ns → ms: divide by 1,000,000; then as percent of elapsed ms + const float CpuPct = static_cast(static_cast(DeltaOsTicks) / 1'000'000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; 
+ Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/macrunner.h b/src/zencompute/runners/macrunner.h new file mode 100644 index 000000000..d653b923a --- /dev/null +++ b/src/zencompute/runners/macrunner.h @@ -0,0 +1,43 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_MAC + +namespace zen::compute { + +/** Native macOS process runner for executing Mac worker executables directly. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using macOS Seatbelt + (sandbox_init): no network access and no filesystem access outside the + explicitly allowed sandbox and worker directories. This requires no elevation. + */ +class MacProcessRunner : public LocalProcessRunner +{ +public: + MacProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + + [[nodiscard]] SubmitResult SubmitAction(Ref Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + bool m_Sandboxed = false; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp new file mode 100644 index 000000000..672636d06 --- /dev/null +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -0,0 +1,618 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "remotehttprunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +////////////////////////////////////////////////////////////////////////// + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("http_exec")) +, m_ChunkResolver{InChunkResolver} +, m_WorkerPool{InWorkerPool} +, m_HostName{HostName} +, m_BaseUrl{fmt::format("{}/compute", HostName)} +, m_Http(m_BaseUrl) +, m_InstanceId(Oid::NewOid()) +{ + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; +} + +RemoteHttpRunner::~RemoteHttpRunner() +{ + Shutdown(); +} + +void +RemoteHttpRunner::Shutdown() +{ + // TODO: should cleanly drain/cancel pending work + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +void +RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + CbPackage WorkerDesc = WorkerPackage; + + std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); + + HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); + + if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) + { + HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); + + if (DescResponse.StatusCode == HttpResponseCode::NotFound) + { + CbPackage Pkg = WorkerDesc; + + // Build response package by sending only the attachments + // the other end needs. 
We start with the full package and + // remove the attachments which are not needed. + + { + std::unordered_set Needed; + + CbObject Response = DescResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + Needed.insert(NeedHash); + } + + std::unordered_set ToRemove; + + for (const CbAttachment& Attachment : Pkg.GetAttachments()) + { + const IoHash& Hash = Attachment.GetHash(); + + if (Needed.find(Hash) == Needed.end()) + { + ToRemove.insert(Hash); + } + } + + for (const IoHash& Hash : ToRemove) + { + int RemovedCount = Pkg.RemoveAttachment(Hash); + + ZEN_ASSERT(RemovedCount == 1); + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); + + if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + } + else if (!IsHttpSuccessCode(DescResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + else + { + ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); + } + } + else if (WorkerResponse.StatusCode == HttpResponseCode::OK) + { + // Already known from a previous run + } + else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", + WorkerId, + m_Http.GetBaseUri(), + WorkerUrl, + (int)WorkerResponse.StatusCode, + ToString(WorkerResponse.StatusCode)); + + // TODO: propagate error + } +} + +size_t +RemoteHttpRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + size_t RunningCount = m_RemoteRunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + +std::vector 
+RemoteHttpRunner::SubmitActions(const std::vector>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActions"); + + if (Actions.size() <= 1) + { + std::vector Results; + + for (const Ref& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; + } + + // For larger batches, submit HTTP requests in parallel via the shared worker pool + + std::vector> Futures; + Futures.reserve(Actions.size()); + + for (const Ref& Action : Actions) + { + std::packaged_task Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + +SubmitResult +RemoteHttpRunner::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitAction"); + + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + Action->ExecutionLocation = m_HostName; + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + // Determine the submission URL. If the action belongs to a queue, ensure a + // corresponding remote queue exists on the target node and submit via it. 
+ + std::string SubmitUrl = "/jobs"; + if (const int QueueId = Action->QueueId; QueueId != 0) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + if (Oid Token = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + // Enqueue job. If the remote returns FailedDependency (424), it means it + // cannot resolve the worker/function — re-register the worker and retry once. + + CbObject Result; + HttpClient::Response WorkResponse; + HttpResponseCode WorkResponseCode{}; + + for (int Attempt = 0; Attempt < 2; ++Attempt) + { + WorkResponse = m_Http.Post(SubmitUrl, ActionObj); + WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::FailedDependency && Attempt == 0) + { + ZEN_WARN("remote {} returned FailedDependency for action {} — re-registering worker and retrying", + m_Http.GetBaseUri(), + ActionId); + + RegisterWorker(Action->Worker.Descriptor); + } + else + { + break; + } + } + + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (WorkResponseCode == HttpResponseCode::NotFound) + { + // Not all attachments are present + + // Build response package including all required attachments + + CbPackage Pkg; + Pkg.SetObject(ActionObj); + + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + // No such attachment + + return {.IsAccepted = false, .Reason = 
fmt::format("missing attachment {}", NeedHash)}; + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(SubmitUrl, Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}{}", ActionId, m_Http.GetBaseUri(), SubmitUrl); + + // TODO: include more information about the failure in the response + + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (PayloadResponse.StatusCode == HttpResponseCode::OK) + { + Result = PayloadResponse.AsObject(); + } + else + { + // Unexpected response + + const int ResponseStatusCode = (int)PayloadResponse.StatusCode; + + ZEN_WARN("unable to register payloads for action {} at {}{} (error: {} {})", + ActionId, + m_Http.GetBaseUri(), + SubmitUrl, + ResponseStatusCode, + ToString(ResponseStatusCode)); + + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}{}", + ResponseStatusCode, + ToString(ResponseStatusCode), + m_Http.GetBaseUri(), + SubmitUrl)}; + } + } + + if (Result) + { + if (const int32_t LsnField = Result["lsn"].AsInt32(0)) + { + HttpRunningAction NewAction; + NewAction.Action = Action; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; + } + } + + return {}; +} + +Oid +RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) +{ + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + return It->second; + } + } + + // Build a stable idempotency key that uniquely identifies this (runner instance, local queue) + // pair. 
The server uses this to return the same remote queue token for concurrent or redundant + // requests, preventing orphaned remote queues when multiple threads race through here. + // Also send hostname so the server can associate the queue with its origin for diagnostics. + CbObjectWriter Body; + Body << "idempotency_key"sv << fmt::format("{}/{}", m_InstanceId, QueueId); + Body << "hostname"sv << GetMachineName(); + if (Metadata) + { + Body << "metadata"sv << Metadata; + } + if (Config) + { + Body << "config"sv << Config; + } + + HttpClient::Response Resp = m_Http.Post("/queues/remote", Body.Save()); + if (!Resp) + { + ZEN_WARN("failed to create remote queue for local queue {} on {}", QueueId, m_HostName); + return Oid::Zero; + } + + Oid Token = Oid::TryFromHexString(Resp.AsObject()["queue_token"sv].AsString()); + if (Token == Oid::Zero) + { + return Oid::Zero; + } + + ZEN_DEBUG("created remote queue '{}' for local queue {} on {}", Token, QueueId, m_HostName); + + RwLock::ExclusiveLockScope _(m_QueueTokenLock); + auto [It, Inserted] = m_RemoteQueueTokens.try_emplace(QueueId, Token); + return It->second; +} + +void +RemoteHttpRunner::CancelRemoteQueue(int QueueId) +{ + Oid Token; + { + RwLock::SharedLockScope _(m_QueueTokenLock); + if (auto It = m_RemoteQueueTokens.find(QueueId); It != m_RemoteQueueTokens.end()) + { + Token = It->second; + } + } + + if (Token == Oid::Zero) + { + return; + } + + HttpClient::Response Resp = m_Http.Delete(fmt::format("/queues/{}", Token)); + + if (Resp.StatusCode == HttpResponseCode::NoContent) + { + ZEN_DEBUG("cancelled remote queue '{}' (local queue {}) on {}", Token, QueueId, m_HostName); + } + else + { + ZEN_WARN("failed to cancel remote queue '{}' on {}: {}", Token, m_HostName, int(Resp.StatusCode)); + } +} + +bool +RemoteHttpRunner::IsHealthy() +{ + if (HttpClient::Response Ready = m_Http.Get("/ready")) + { + return true; + } + else + { + // TODO: use response to propagate context + return false; + } +} + +size_t 
+RemoteHttpRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RemoteRunningMap.size(); +} + +void +RemoteHttpRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("RemoteHttpRunner_Monitor"); + + do + { + const int NormalWaitingTime = 200; + int WaitTimeMs = NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { + const size_t RetiredCount = SweepRunningActions(); + + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) + { + WaitTimeMs = NormalWaitingTime / 4; + } + else + { + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } + } + }); + }; + + while (!WaitOnce()) + { + SweepOnce(); + } + + // Signal received - this may mean we should quit + + SweepOnce(); + } while (m_MonitorThreadEnabled); +} + +size_t +RemoteHttpRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SweepRunningActions"); + std::vector CompletedActions; + + // Poll remote for list of completed actions + + HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv); + + if (CbObject Completed = ResponseCompleted.AsObject()) + { + for (auto& FieldIt : Completed["completed"sv]) + { + CbObjectView EntryObj = FieldIt.AsObjectView(); + const int32_t CompleteLsn = EntryObj["lsn"sv].AsInt32(); + std::string_view StateName = EntryObj["state"sv].AsString(); + + RunnerAction::State RemoteState = RunnerAction::FromString(StateName); + + // Always fetch to drain the result from the remote's results map, + // but only keep the result package for successfully completed actions. 
+ HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn)); + + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.RemoteState = RemoteState; + + if (RemoteState == RunnerAction::State::Completed && ResponseJob) + { + CompletedAction.ActionResults = ResponseJob.AsPackage(); + } + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); + } + + if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) + { + // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0)) + if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0)) + { + const int32_t NewCap = zen::Max(4, CpuCount); + + if (m_MaxRunningActions > NewCap) + { + ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions); + + m_MaxRunningActions = NewCap; + } + } + } + } + + // Notify outer. Note that this has to be done without holding any local locks + // otherwise we may end up with deadlocks. 
+ + for (HttpRunningAction& HttpAction : CompletedActions) + { + const int ActionLsn = HttpAction.Action->ActionLsn; + + ZEN_DEBUG("action {} LSN {} (remote LSN {}) -> {}", + HttpAction.Action->ActionId, + ActionLsn, + HttpAction.RemoteActionLsn, + RunnerAction::ToString(HttpAction.RemoteState)); + + if (HttpAction.RemoteState == RunnerAction::State::Completed) + { + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); + } + + HttpAction.Action->SetActionState(HttpAction.RemoteState); + } + + return CompletedActions.size(); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h new file mode 100644 index 000000000..9119992a9 --- /dev/null +++ b/src/zencompute/runners/remotehttprunner.h @@ -0,0 +1,100 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencompute/computeservice.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "functionrunner.h" + +# include +# include +# include +# include +# include +# include + +# include +# include +# include +# include + +namespace zen { +class CidStore; +} + +namespace zen::compute { + +/** HTTP-based runner + + This implements a DDC remote compute execution strategy via REST API + + */ + +class RemoteHttpRunner : public FunctionRunner +{ + RemoteHttpRunner(RemoteHttpRunner&&) = delete; + RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; + +public: + RemoteHttpRunner(ChunkResolver& InChunkResolver, + const std::filesystem::path& BaseDir, + std::string_view HostName, + WorkerThreadPool& InWorkerPool); + ~RemoteHttpRunner(); + + virtual void Shutdown() override; + virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual SubmitResult SubmitAction(Ref Action) override; + [[nodiscard]] virtual bool IsHealthy() override; + [[nodiscard]] virtual size_t GetSubmittedActionCount() override; + [[nodiscard]] virtual size_t QueryCapacity() override; + [[nodiscard]] 
virtual std::vector SubmitActions(const std::vector>& Actions) override; + virtual void CancelRemoteQueue(int QueueId) override; + + std::string_view GetHostName() const { return m_HostName; } + +protected: + LoggerRef Log() { return m_Log; } + +private: + LoggerRef m_Log; + ChunkResolver& m_ChunkResolver; + WorkerThreadPool& m_WorkerPool; + std::string m_HostName; + std::string m_BaseUrl; + HttpClient m_Http; + + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + + struct HttpRunningAction + { + Ref Action; + int RemoteActionLsn = 0; // Remote LSN + RunnerAction::State RemoteState = RunnerAction::State::Failed; + CbPackage ActionResults; + }; + + RwLock m_RunningLock; + std::unordered_map m_RemoteRunningMap; // Note that this is keyed on the *REMOTE* lsn + + std::thread m_MonitorThread; + std::atomic m_MonitorThreadEnabled{true}; + Event m_MonitorThreadEvent; + void MonitorThreadFunction(); + size_t SweepRunningActions(); + + RwLock m_QueueTokenLock; + std::unordered_map m_RemoteQueueTokens; // local QueueId → remote queue token + + // Stable identity for this runner instance, used as part of the idempotency key when + // creating remote queues. Generated once at construction and never changes. + Oid m_InstanceId; + + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.cpp b/src/zencompute/runners/windowsrunner.cpp new file mode 100644 index 000000000..e9a1ae8b6 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.cpp @@ -0,0 +1,460 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "windowsrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +# include +# include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::compute { + +using namespace std::literals; + +WindowsProcessRunner::WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed, + int32_t MaxConcurrentActions) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool, MaxConcurrentActions) +, m_Sandboxed(Sandboxed) +{ + if (!m_Sandboxed) + { + return; + } + + // Build a unique profile name per process to avoid collisions + m_AppContainerName = L"zenserver-sandbox-" + std::to_wstring(GetCurrentProcessId()); + + // Clean up any stale profile from a previous crash + DeleteAppContainerProfile(m_AppContainerName.c_str()); + + PSID Sid = nullptr; + + HRESULT Hr = CreateAppContainerProfile(m_AppContainerName.c_str(), + m_AppContainerName.c_str(), // display name + m_AppContainerName.c_str(), // description + nullptr, // no capabilities + 0, // capability count + &Sid); + + if (FAILED(Hr)) + { + throw zen::runtime_error("CreateAppContainerProfile failed: HRESULT 0x{:08X}", static_cast(Hr)); + } + + m_AppContainerSid = Sid; + + ZEN_INFO("AppContainer sandboxing enabled for child processes (profile={})", WideToUtf8(m_AppContainerName)); +} + +WindowsProcessRunner::~WindowsProcessRunner() +{ + if (m_AppContainerSid) + { + FreeSid(m_AppContainerSid); + m_AppContainerSid = nullptr; + } + + if (!m_AppContainerName.empty()) + { + DeleteAppContainerProfile(m_AppContainerName.c_str()); + } +} + +void +WindowsProcessRunner::GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask) +{ + PACL ExistingDacl = nullptr; + PSECURITY_DESCRIPTOR SecurityDescriptor = nullptr; + + DWORD Result = 
GetNamedSecurityInfoW(Path.c_str(), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + &ExistingDacl, + nullptr, + &SecurityDescriptor); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("GetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $0 = MakeGuard([&] { LocalFree(SecurityDescriptor); }); + + EXPLICIT_ACCESSW Access{}; + Access.grfAccessPermissions = AccessMask; + Access.grfAccessMode = SET_ACCESS; + Access.grfInheritance = OBJECT_INHERIT_ACE | CONTAINER_INHERIT_ACE; + Access.Trustee.TrusteeForm = TRUSTEE_IS_SID; + Access.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP; + Access.Trustee.ptstrName = static_cast(m_AppContainerSid); + + PACL NewDacl = nullptr; + + Result = SetEntriesInAclW(1, &Access, ExistingDacl, &NewDacl); + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetEntriesInAclW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } + + auto $1 = MakeGuard([&] { LocalFree(NewDacl); }); + + Result = SetNamedSecurityInfoW(const_cast(Path.c_str()), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION, + nullptr, + nullptr, + NewDacl, + nullptr); + + if (Result != ERROR_SUCCESS) + { + throw zen::runtime_error("SetNamedSecurityInfoW failed for '{}': {}", Path.string(), GetSystemErrorAsString(Result)); + } +} + +SubmitResult +WindowsProcessRunner::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SubmitAction"); + std::optional Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Set up environment variables + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + StringBuilder<1024> EnvironmentBlock; + + for (auto& It : WorkerDescription["environment"sv]) + { + EnvironmentBlock.Append(It.AsString()); + EnvironmentBlock.Append('\0'); + } + EnvironmentBlock.Append('\0'); + EnvironmentBlock.Append('\0'); + + // Execute process - this spawns the 
child process immediately without waiting + // for completion + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath).make_preferred(); + + ExtendableWideStringBuilder<512> CommandLine; + CommandLine.Append(L'"'); + CommandLine.Append(ExePath.c_str()); + CommandLine.Append(L'"'); + CommandLine.Append(L" -Build=build.action"); + + LPSECURITY_ATTRIBUTES lpProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES lpThreadAttributes = nullptr; + BOOL bInheritHandles = FALSE; + DWORD dwCreationFlags = 0; + + ZEN_DEBUG("Executing: {} (sandboxed={})", WideToUtf8(CommandLine.c_str()), m_Sandboxed); + + CommandLine.EnsureNulTerminated(); + + PROCESS_INFORMATION ProcessInformation{}; + + if (m_Sandboxed) + { + // Grant AppContainer access to sandbox and worker directories + GrantAppContainerAccess(Prepared->SandboxPath, FILE_ALL_ACCESS); + GrantAppContainerAccess(Prepared->WorkerPath, FILE_GENERIC_READ | FILE_GENERIC_EXECUTE); + + // Set up extended startup info with AppContainer security capabilities + SECURITY_CAPABILITIES SecurityCapabilities{}; + SecurityCapabilities.AppContainerSid = m_AppContainerSid; + SecurityCapabilities.Capabilities = nullptr; + SecurityCapabilities.CapabilityCount = 0; + + SIZE_T AttrListSize = 0; + InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize); + + auto AttrList = static_cast(malloc(AttrListSize)); + auto $0 = MakeGuard([&] { free(AttrList); }); + + if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize)) + { + zen::ThrowLastError("InitializeProcThreadAttributeList failed"); + } + + auto $1 = MakeGuard([&] { DeleteProcThreadAttributeList(AttrList); }); + + if (!UpdateProcThreadAttribute(AttrList, + 0, + PROC_THREAD_ATTRIBUTE_SECURITY_CAPABILITIES, + &SecurityCapabilities, + sizeof(SecurityCapabilities), + nullptr, + nullptr)) + { + zen::ThrowLastError("UpdateProcThreadAttribute (SECURITY_CAPABILITIES) failed"); + } + 
+ STARTUPINFOEXW StartupInfoEx{}; + StartupInfoEx.StartupInfo.cb = sizeof(STARTUPINFOEXW); + StartupInfoEx.lpAttributeList = AttrList; + + dwCreationFlags |= EXTENDED_STARTUPINFO_PRESENT; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfoEx.StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch sandboxed process"); + } + } + else + { + STARTUPINFO StartupInfo{}; + StartupInfo.cb = sizeof StartupInfo; + + BOOL Success = CreateProcessW(nullptr, + CommandLine.Data(), + lpProcessAttributes, + lpThreadAttributes, + bInheritHandles, + dwCreationFlags, + (LPVOID)EnvironmentBlock.Data(), + Prepared->SandboxPath.c_str(), + &StartupInfo, + /* out */ &ProcessInformation); + + if (!Success) + { + zen::ThrowLastError("Unable to launch process"); + } + } + + CloseHandle(ProcessInformation.hThread); + + Ref NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = ProcessInformation.hProcess; + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WindowsProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::SweepRunningActions"); + std::vector> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref Running = It->second; + + DWORD ExitCode = 0; + BOOL IsSuccess = GetExitCodeProcess(Running->ProcessHandle, &ExitCode); + + if (IsSuccess && ExitCode != STILL_ACTIVE) + { + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + 
Running->ExitCode = ExitCode; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void +WindowsProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // For expedience we initiate the process termination for all known + // processes before attempting to wait for them to exit. + + // Initiate termination for all known processes before waiting for them to exit. + + for (const auto& Kv : RunningMap) + { + Ref Running = Kv.second; + + BOOL TermSuccess = TerminateProcess(Running->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("TerminateProcess for LSN {} not successful: {}", Running->Action->ActionLsn, GetSystemErrorAsString(LastError)); + } + } + } + + // Wait for all processes and clean up, regardless of whether TerminateProcess succeeded. 
+ + for (auto& [Lsn, Running] : RunningMap) + { + if (Running->ProcessHandle != INVALID_HANDLE_VALUE) + { + DWORD WaitResult = WaitForSingleObject(Running->ProcessHandle, 2000); + + if (WaitResult != WAIT_OBJECT_0) + { + ZEN_WARN("wait for LSN {}: process exit did not succeed, result = {}", Running->Action->ActionLsn, WaitResult); + } + else + { + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + } + + CloseHandle(Running->ProcessHandle); + Running->ProcessHandle = INVALID_HANDLE_VALUE; + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +bool +WindowsProcessRunner::CancelAction(int ActionLsn) +{ + ZEN_TRACE_CPU("WindowsProcessRunner::CancelAction"); + + // Hold the shared lock while terminating to prevent the sweep thread from + // closing the handle between our lookup and TerminateProcess call. + bool Sent = false; + + m_RunningLock.WithSharedLock([&] { + auto It = m_RunningMap.find(ActionLsn); + if (It == m_RunningMap.end()) + { + return; + } + + Ref Target = It->second; + if (Target->ProcessHandle == INVALID_HANDLE_VALUE) + { + return; + } + + BOOL TermSuccess = TerminateProcess(Target->ProcessHandle, 222); + + if (!TermSuccess) + { + DWORD LastError = GetLastError(); + + if (LastError != ERROR_ACCESS_DENIED) + { + ZEN_WARN("CancelAction: TerminateProcess for LSN {} not successful: {}", ActionLsn, GetSystemErrorAsString(LastError)); + } + + return; + } + + ZEN_DEBUG("CancelAction: initiated cancellation of LSN {}", ActionLsn); + Sent = true; + }); + + // The monitor thread will pick up the process exit and mark the action as Failed. 
+ return Sent; +} + +void +WindowsProcessRunner::SampleProcessCpu(RunningAction& Running) +{ + FILETIME CreationTime, ExitTime, KernelTime, UserTime; + if (!GetProcessTimes(Running.ProcessHandle, &CreationTime, &ExitTime, &KernelTime, &UserTime)) + { + return; + } + + auto FtToU64 = [](FILETIME Ft) -> uint64_t { return (static_cast(Ft.dwHighDateTime) << 32) | Ft.dwLowDateTime; }; + + // FILETIME values are in 100-nanosecond intervals + const uint64_t CurrentOsTicks = FtToU64(KernelTime) + FtToU64(UserTime); + const uint64_t NowTicks = GetHifreqTimerValue(); + + // Cumulative CPU seconds (absolute, available from first sample): 100ns → seconds + Running.Action->CpuSeconds.store(static_cast(static_cast(CurrentOsTicks) / 10'000'000.0), std::memory_order_relaxed); + + if (Running.LastCpuSampleTicks != 0 && Running.LastCpuOsTicks != 0) + { + const uint64_t ElapsedMs = Stopwatch::GetElapsedTimeMs(NowTicks - Running.LastCpuSampleTicks); + if (ElapsedMs > 0) + { + const uint64_t DeltaOsTicks = CurrentOsTicks - Running.LastCpuOsTicks; + // 100ns → ms: divide by 10000; then as percent of elapsed ms + const float CpuPct = static_cast(static_cast(DeltaOsTicks) / 10000.0 / ElapsedMs * 100.0); + Running.Action->CpuUsagePercent.store(CpuPct, std::memory_order_relaxed); + } + } + + Running.LastCpuSampleTicks = NowTicks; + Running.LastCpuOsTicks = CurrentOsTicks; +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/windowsrunner.h b/src/zencompute/runners/windowsrunner.h new file mode 100644 index 000000000..9f2385cc4 --- /dev/null +++ b/src/zencompute/runners/windowsrunner.h @@ -0,0 +1,53 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_WINDOWS + +# include + +# include + +namespace zen::compute { + +/** Windows process runner using CreateProcessW for executing worker executables. 
+ + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + + When Sandboxed is true, child processes are isolated using a Windows AppContainer: + no network access (AppContainer blocks network by default when no capabilities are + granted) and no filesystem access outside explicitly granted sandbox and worker + directories. This requires no elevation. + */ +class WindowsProcessRunner : public LocalProcessRunner +{ +public: + WindowsProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool, + bool Sandboxed = false, + int32_t MaxConcurrentActions = 0); + ~WindowsProcessRunner(); + + [[nodiscard]] SubmitResult SubmitAction(Ref Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + bool CancelAction(int ActionLsn) override; + void SampleProcessCpu(RunningAction& Running) override; + +private: + void GrantAppContainerAccess(const std::filesystem::path& Path, DWORD AccessMask); + + bool m_Sandboxed = false; + PSID m_AppContainerSid = nullptr; + std::wstring m_AppContainerName; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.cpp b/src/zencompute/runners/winerunner.cpp new file mode 100644 index 000000000..506bec73b --- /dev/null +++ b/src/zencompute/runners/winerunner.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "winerunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include + +namespace zen::compute { + +using namespace std::literals; + +WineProcessRunner::WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool) +: LocalProcessRunner(Resolver, BaseDir, Deleter, WorkerPool) +{ + // Restore SIGCHLD to default behavior so waitpid() can properly collect + // child exit status. zenserver/main.cpp sets SIGCHLD to SIG_IGN which + // causes the kernel to auto-reap children, making waitpid() return + // -1/ECHILD instead of the exit status we need. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_DFL; + sigaction(SIGCHLD, &Action, nullptr); +} + +SubmitResult +WineProcessRunner::SubmitAction(Ref Action) +{ + ZEN_TRACE_CPU("WineProcessRunner::SubmitAction"); + std::optional Prepared = PrepareActionSubmission(Action); + + if (!Prepared) + { + return SubmitResult{.IsAccepted = false}; + } + + // Build environment array from worker descriptor + + CbObject WorkerDescription = Prepared->WorkerPackage.GetObject(); + + std::vector EnvStrings; + for (auto& It : WorkerDescription["environment"sv]) + { + EnvStrings.emplace_back(It.AsString()); + } + + std::vector Envp; + Envp.reserve(EnvStrings.size() + 1); + for (auto& Str : EnvStrings) + { + Envp.push_back(Str.data()); + } + Envp.push_back(nullptr); + + // Build argv: wine -Build=build.action + + std::string_view ExecPath = WorkerDescription["path"sv].AsString(); + std::filesystem::path ExePath = Prepared->WorkerPath / std::filesystem::path(ExecPath); + std::string ExePathStr = ExePath.string(); + std::string WinePathStr = m_WinePath; + std::string BuildArg = "-Build=build.action"; + + std::vector ArgV; + ArgV.push_back(WinePathStr.data()); + 
ArgV.push_back(ExePathStr.data()); + ArgV.push_back(BuildArg.data()); + ArgV.push_back(nullptr); + + ZEN_DEBUG("Executing via Wine: {} {} {}", WinePathStr, ExePathStr, BuildArg); + + std::string SandboxPathStr = Prepared->SandboxPath.string(); + + pid_t ChildPid = fork(); + + if (ChildPid < 0) + { + throw std::runtime_error(fmt::format("fork() failed: {}", strerror(errno))); + } + + if (ChildPid == 0) + { + // Child process + if (chdir(SandboxPathStr.c_str()) != 0) + { + _exit(127); + } + + execve(WinePathStr.c_str(), ArgV.data(), Envp.data()); + + // execve only returns on failure + _exit(127); + } + + // Parent: store child pid as void* (same convention as zencore/process.cpp) + + Ref NewAction{new RunningAction()}; + NewAction->Action = Action; + NewAction->ProcessHandle = reinterpret_cast(static_cast(ChildPid)); + NewAction->SandboxPath = std::move(Prepared->SandboxPath); + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RunningMap[Prepared->ActionLsn] = std::move(NewAction); + } + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; +} + +void +WineProcessRunner::SweepRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::SweepRunningActions"); + std::vector> CompletedActions; + + m_RunningLock.WithExclusiveLock([&] { + for (auto It = begin(m_RunningMap), ItEnd = end(m_RunningMap); It != ItEnd;) + { + Ref Running = It->second; + + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + int Status = 0; + + pid_t Result = waitpid(Pid, &Status, WNOHANG); + + if (Result == Pid) + { + if (WIFEXITED(Status)) + { + Running->ExitCode = WEXITSTATUS(Status); + } + else if (WIFSIGNALED(Status)) + { + Running->ExitCode = 128 + WTERMSIG(Status); + } + else + { + Running->ExitCode = 1; + } + + Running->ProcessHandle = nullptr; + + CompletedActions.push_back(std::move(Running)); + It = m_RunningMap.erase(It); + } + else + { + ++It; + } + } + }); + + ProcessCompletedActions(CompletedActions); +} + +void 
+WineProcessRunner::CancelRunningActions() +{ + ZEN_TRACE_CPU("WineProcessRunner::CancelRunningActions"); + Stopwatch Timer; + std::unordered_map> RunningMap; + + m_RunningLock.WithExclusiveLock([&] { std::swap(RunningMap, m_RunningMap); }); + + if (RunningMap.empty()) + { + return; + } + + ZEN_INFO("cancelling all running actions"); + + // Send SIGTERM to all running processes first + + for (const auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + if (kill(Pid, SIGTERM) != 0) + { + ZEN_WARN("kill(SIGTERM) for LSN {} (pid {}) failed: {}", Running->Action->ActionLsn, Pid, strerror(errno)); + } + } + + // Wait for all processes, regardless of whether SIGTERM succeeded, then clean up. + + for (auto& [Lsn, Running] : RunningMap) + { + pid_t Pid = static_cast(reinterpret_cast(Running->ProcessHandle)); + + // Poll for up to 2 seconds + bool Exited = false; + for (int i = 0; i < 20; ++i) + { + int Status = 0; + pid_t WaitResult = waitpid(Pid, &Status, WNOHANG); + if (WaitResult == Pid) + { + Exited = true; + ZEN_DEBUG("LSN {}: process exit OK", Running->Action->ActionLsn); + break; + } + usleep(100000); // 100ms + } + + if (!Exited) + { + ZEN_WARN("LSN {}: process did not exit after SIGTERM, sending SIGKILL", Running->Action->ActionLsn); + kill(Pid, SIGKILL); + waitpid(Pid, nullptr, 0); + } + + m_DeferredDeleter.Enqueue(Running->Action->ActionLsn, std::move(Running->SandboxPath)); + Running->Action->SetActionState(RunnerAction::State::Failed); + } + + ZEN_INFO("DONE - cancelled {} running processes (took {})", RunningMap.size(), NiceTimeSpanMs(Timer.GetElapsedTimeMs())); +} + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/runners/winerunner.h b/src/zencompute/runners/winerunner.h new file mode 100644 index 000000000..7df62e7c0 --- /dev/null +++ b/src/zencompute/runners/winerunner.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "localrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES && ZEN_PLATFORM_LINUX + +# include + +namespace zen::compute { + +/** Wine-based process runner for executing Windows worker executables on Linux. + + Subclasses LocalProcessRunner, reusing sandbox management, worker manifesting, + input/output handling, and monitor thread infrastructure. Overrides only the + platform-specific methods: process spawning, sweep, and cancellation. + */ +class WineProcessRunner : public LocalProcessRunner +{ +public: + WineProcessRunner(ChunkResolver& Resolver, + const std::filesystem::path& BaseDir, + DeferredDirectoryDeleter& Deleter, + WorkerThreadPool& WorkerPool); + + [[nodiscard]] SubmitResult SubmitAction(Ref Action) override; + void SweepRunningActions() override; + void CancelRunningActions() override; + +private: + std::string m_WinePath = "wine"; +}; + +} // namespace zen::compute + +#endif diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp new file mode 100644 index 000000000..dd09312df --- /dev/null +++ b/src/zencompute/testing/mockimds.cpp @@ -0,0 +1,205 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include + +#include + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + 
Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) + if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); 
+} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zencompute/timeline/workertimeline.cpp b/src/zencompute/timeline/workertimeline.cpp new file mode 100644 index 000000000..88ef5b62d --- /dev/null +++ b/src/zencompute/timeline/workertimeline.cpp @@ -0,0 +1,430 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "workertimeline.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include + +# include + +namespace zen::compute { + +WorkerTimeline::WorkerTimeline(std::string_view WorkerId) : m_WorkerId(WorkerId) +{ +} + +WorkerTimeline::~WorkerTimeline() +{ +} + +void +WorkerTimeline::RecordProvisioned() +{ + AppendEvent({ + .Type = EventType::WorkerProvisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordDeprovisioned() +{ + AppendEvent({ + .Type = EventType::WorkerDeprovisioned, + .Timestamp = DateTime::Now(), + }); +} + +void +WorkerTimeline::RecordActionAccepted(int ActionLsn, const IoHash& ActionId) +{ + AppendEvent({ + .Type = EventType::ActionAccepted, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + }); +} + +void +WorkerTimeline::RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason) +{ + AppendEvent({ + .Type = EventType::ActionRejected, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .Reason = std::string(Reason), + }); +} + +void +WorkerTimeline::RecordActionStateChanged(int ActionLsn, + const IoHash& ActionId, + RunnerAction::State PreviousState, + RunnerAction::State NewState) +{ + AppendEvent({ + .Type = EventType::ActionStateChanged, + .Timestamp = DateTime::Now(), + .ActionLsn = ActionLsn, + .ActionId = ActionId, + .ActionState = NewState, + .PreviousState = PreviousState, + }); +} + +std::vector +WorkerTimeline::QueryTimeline(DateTime StartTime, DateTime EndTime) const +{ + std::vector Result; + + m_EventsLock.WithSharedLock([&] { + for (const auto& Evt : m_Events) + { + if (Evt.Timestamp >= StartTime && Evt.Timestamp <= EndTime) + { + Result.push_back(Evt); + } + } + }); + + return Result; +} + +std::vector +WorkerTimeline::QueryRecent(int Limit) const +{ + std::vector Result; + + m_EventsLock.WithSharedLock([&] { + const int Count = std::min(Limit, gsl::narrow(m_Events.size())); + auto It = 
m_Events.end() - Count; + Result.assign(It, m_Events.end()); + }); + + return Result; +} + +size_t +WorkerTimeline::GetEventCount() const +{ + size_t Count = 0; + m_EventsLock.WithSharedLock([&] { Count = m_Events.size(); }); + return Count; +} + +WorkerTimeline::TimeRange +WorkerTimeline::GetTimeRange() const +{ + TimeRange Range; + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Range.First = m_Events.front().Timestamp; + Range.Last = m_Events.back().Timestamp; + } + }); + return Range; +} + +void +WorkerTimeline::AppendEvent(Event&& Evt) +{ + m_EventsLock.WithExclusiveLock([&] { + while (m_Events.size() >= m_MaxEvents) + { + m_Events.pop_front(); + } + + m_Events.push_back(std::move(Evt)); + }); +} + +const char* +WorkerTimeline::ToString(EventType Type) +{ + switch (Type) + { + case EventType::WorkerProvisioned: + return "provisioned"; + case EventType::WorkerDeprovisioned: + return "deprovisioned"; + case EventType::ActionAccepted: + return "accepted"; + case EventType::ActionRejected: + return "rejected"; + case EventType::ActionStateChanged: + return "state_changed"; + default: + return "unknown"; + } +} + +static WorkerTimeline::EventType +EventTypeFromString(std::string_view Str) +{ + if (Str == "provisioned") + return WorkerTimeline::EventType::WorkerProvisioned; + if (Str == "deprovisioned") + return WorkerTimeline::EventType::WorkerDeprovisioned; + if (Str == "accepted") + return WorkerTimeline::EventType::ActionAccepted; + if (Str == "rejected") + return WorkerTimeline::EventType::ActionRejected; + if (Str == "state_changed") + return WorkerTimeline::EventType::ActionStateChanged; + return WorkerTimeline::EventType::WorkerProvisioned; +} + +void +WorkerTimeline::WriteTo(const std::filesystem::path& Path) const +{ + CbObjectWriter Cbo; + Cbo << "worker_id" << m_WorkerId; + + m_EventsLock.WithSharedLock([&] { + if (!m_Events.empty()) + { + Cbo.AddDateTime("time_first", m_Events.front().Timestamp); + Cbo.AddDateTime("time_last", 
m_Events.back().Timestamp); + } + + Cbo.BeginArray("events"); + for (const auto& Evt : m_Events) + { + Cbo.BeginObject(); + Cbo << "type" << ToString(Evt.Type); + Cbo.AddDateTime("ts", Evt.Timestamp); + + if (Evt.ActionLsn != 0) + { + Cbo << "lsn" << Evt.ActionLsn; + Cbo << "action_id" << Evt.ActionId; + } + + if (Evt.Type == EventType::ActionStateChanged) + { + Cbo << "prev_state" << static_cast(Evt.PreviousState); + Cbo << "state" << static_cast(Evt.ActionState); + } + + if (!Evt.Reason.empty()) + { + Cbo << "reason" << std::string_view(Evt.Reason); + } + + Cbo.EndObject(); + } + Cbo.EndArray(); + }); + + CbObject Obj = Cbo.Save(); + + BasicFile File(Path, BasicFile::Mode::kTruncate); + File.Write(Obj.GetBuffer().GetView(), 0); +} + +void +WorkerTimeline::ReadFrom(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + CbObject Root = std::move(Loaded.Object); + + if (!Root) + { + return; + } + + std::deque LoadedEvents; + + for (CbFieldView Field : Root["events"].AsArrayView()) + { + CbObjectView EventObj = Field.AsObjectView(); + + Event Evt; + Evt.Type = EventTypeFromString(EventObj["type"].AsString()); + Evt.Timestamp = EventObj["ts"].AsDateTime(); + + Evt.ActionLsn = EventObj["lsn"].AsInt32(); + Evt.ActionId = EventObj["action_id"].AsHash(); + + if (Evt.Type == EventType::ActionStateChanged) + { + Evt.PreviousState = static_cast(EventObj["prev_state"].AsInt32()); + Evt.ActionState = static_cast(EventObj["state"].AsInt32()); + } + + std::string_view Reason = EventObj["reason"].AsString(); + if (!Reason.empty()) + { + Evt.Reason = std::string(Reason); + } + + LoadedEvents.push_back(std::move(Evt)); + } + + m_EventsLock.WithExclusiveLock([&] { m_Events = std::move(LoadedEvents); }); +} + +WorkerTimeline::TimeRange +WorkerTimeline::ReadTimeRange(const std::filesystem::path& Path) +{ + CbObjectFromFile Loaded = LoadCompactBinaryObject(Path); + + if (!Loaded.Object) + { + return {}; + } + + return { + .First = 
Loaded.Object["time_first"].AsDateTime(), + .Last = Loaded.Object["time_last"].AsDateTime(), + }; +} + +// WorkerTimelineStore + +static constexpr std::string_view kTimelineExtension = ".ztimeline"; + +WorkerTimelineStore::WorkerTimelineStore(std::filesystem::path PersistenceDir) : m_PersistenceDir(std::move(PersistenceDir)) +{ + std::error_code Ec; + std::filesystem::create_directories(m_PersistenceDir, Ec); +} + +Ref +WorkerTimelineStore::GetOrCreate(std::string_view WorkerId) +{ + // Fast path: check if it already exists in memory + { + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + } + + // Slow path: create under exclusive lock, loading from disk if available + RwLock::ExclusiveLockScope _(m_Lock); + + auto& Entry = m_Timelines[std::string(WorkerId)]; + if (!Entry) + { + Entry = Ref(new WorkerTimeline(WorkerId)); + + std::filesystem::path Path = TimelinePath(WorkerId); + std::error_code Ec; + if (std::filesystem::is_regular_file(Path, Ec)) + { + Entry->ReadFrom(Path); + } + } + return Entry; +} + +Ref +WorkerTimelineStore::Find(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + return It->second; + } + return {}; +} + +std::vector +WorkerTimelineStore::GetActiveWorkerIds() const +{ + std::vector Result; + + RwLock::SharedLockScope $(m_Lock); + Result.reserve(m_Timelines.size()); + for (const auto& [Id, _] : m_Timelines) + { + Result.push_back(Id); + } + + return Result; +} + +std::vector +WorkerTimelineStore::GetAllWorkerInfo() const +{ + std::unordered_map InfoMap; + + { + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + InfoMap[Id] = Timeline->GetTimeRange(); + } + } + + std::error_code Ec; + for (const auto& Entry : std::filesystem::directory_iterator(m_PersistenceDir, Ec)) + { + if (!Entry.is_regular_file()) + 
{ + continue; + } + + const auto& Path = Entry.path(); + if (Path.extension().string() != kTimelineExtension) + { + continue; + } + + std::string Id = Path.stem().string(); + if (InfoMap.find(Id) == InfoMap.end()) + { + InfoMap[Id] = WorkerTimeline::ReadTimeRange(Path); + } + } + + std::vector Result; + Result.reserve(InfoMap.size()); + for (auto& [Id, Range] : InfoMap) + { + Result.push_back({.WorkerId = std::move(Id), .Range = Range}); + } + return Result; +} + +void +WorkerTimelineStore::Save(std::string_view WorkerId) +{ + RwLock::SharedLockScope _(m_Lock); + auto It = m_Timelines.find(std::string(WorkerId)); + if (It != m_Timelines.end()) + { + It->second->WriteTo(TimelinePath(WorkerId)); + } +} + +void +WorkerTimelineStore::SaveAll() +{ + RwLock::SharedLockScope _(m_Lock); + for (const auto& [Id, Timeline] : m_Timelines) + { + Timeline->WriteTo(TimelinePath(Id)); + } +} + +std::filesystem::path +WorkerTimelineStore::TimelinePath(std::string_view WorkerId) const +{ + return m_PersistenceDir / (std::string(WorkerId) + std::string(kTimelineExtension)); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/timeline/workertimeline.h b/src/zencompute/timeline/workertimeline.h new file mode 100644 index 000000000..87e19bc28 --- /dev/null +++ b/src/zencompute/timeline/workertimeline.h @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../runners/functionrunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include + +# include +# include +# include +# include +# include +# include + +namespace zen::compute { + +struct RunnerAction; + +/** Worker activity timeline for tracking and visualizing worker activity over time. + * + * Records worker lifecycle events (provisioning/deprovisioning) and action lifecycle + * events (accept, reject, state changes) with timestamps, enabling time-range queries + * for dashboard visualization. 
+ */ +class WorkerTimeline : public RefCounted +{ +public: + explicit WorkerTimeline(std::string_view WorkerId); + ~WorkerTimeline() override; + + struct TimeRange + { + DateTime First = DateTime(0); + DateTime Last = DateTime(0); + + explicit operator bool() const { return First.GetTicks() != 0; } + }; + + enum class EventType + { + WorkerProvisioned, + WorkerDeprovisioned, + ActionAccepted, + ActionRejected, + ActionStateChanged + }; + + static const char* ToString(EventType Type); + + struct Event + { + EventType Type; + DateTime Timestamp = DateTime(0); + + // Action context (only set for action events) + int ActionLsn = 0; + IoHash ActionId; + RunnerAction::State ActionState = RunnerAction::State::New; + RunnerAction::State PreviousState = RunnerAction::State::New; + + // Optional reason (e.g. rejection reason) + std::string Reason; + }; + + /** Record that this worker has been provisioned and is available for work. */ + void RecordProvisioned(); + + /** Record that this worker has been deprovisioned and is no longer available. */ + void RecordDeprovisioned(); + + /** Record that an action was accepted by this worker. */ + void RecordActionAccepted(int ActionLsn, const IoHash& ActionId); + + /** Record that an action was rejected by this worker. */ + void RecordActionRejected(int ActionLsn, const IoHash& ActionId, std::string_view Reason); + + /** Record an action state transition on this worker. */ + void RecordActionStateChanged(int ActionLsn, const IoHash& ActionId, RunnerAction::State PreviousState, RunnerAction::State NewState); + + /** Query events within a time range (inclusive). Returns events ordered by timestamp. */ + [[nodiscard]] std::vector QueryTimeline(DateTime StartTime, DateTime EndTime) const; + + /** Query the most recent N events. */ + [[nodiscard]] std::vector QueryRecent(int Limit = 100) const; + + /** Return the total number of recorded events. 
*/ + [[nodiscard]] size_t GetEventCount() const; + + /** Return the time range covered by the events in this timeline. */ + [[nodiscard]] TimeRange GetTimeRange() const; + + [[nodiscard]] const std::string& GetWorkerId() const { return m_WorkerId; } + + /** Write the timeline to a file at the given path. */ + void WriteTo(const std::filesystem::path& Path) const; + + /** Read the timeline from a file at the given path. Replaces current in-memory events. */ + void ReadFrom(const std::filesystem::path& Path); + + /** Read only the time range from a persisted timeline file, without loading events. */ + [[nodiscard]] static TimeRange ReadTimeRange(const std::filesystem::path& Path); + +private: + void AppendEvent(Event&& Evt); + + std::string m_WorkerId; + mutable RwLock m_EventsLock; + std::deque m_Events; + size_t m_MaxEvents = 10'000; +}; + +/** Manages a set of WorkerTimeline instances, keyed by worker ID. + * + * Provides thread-safe lookup and on-demand creation of timelines, backed by + * a persistence directory. Each timeline is stored as a separate file named + * {WorkerId}.ztimeline within the directory. + */ +class WorkerTimelineStore +{ +public: + explicit WorkerTimelineStore(std::filesystem::path PersistenceDir); + ~WorkerTimelineStore() = default; + + WorkerTimelineStore(const WorkerTimelineStore&) = delete; + WorkerTimelineStore& operator=(const WorkerTimelineStore&) = delete; + + /** Get the timeline for a worker, creating one if it does not exist. + * If a persisted file exists on disk it will be loaded on first access. */ + Ref GetOrCreate(std::string_view WorkerId); + + /** Get the timeline for a worker, or null ref if it does not exist in memory. */ + [[nodiscard]] Ref Find(std::string_view WorkerId); + + /** Return the worker IDs of currently loaded (in-memory) timelines. 
*/ + [[nodiscard]] std::vector GetActiveWorkerIds() const; + + struct WorkerTimelineInfo + { + std::string WorkerId; + WorkerTimeline::TimeRange Range; + }; + + /** Return info for all known timelines (in-memory and on-disk), including time range. */ + [[nodiscard]] std::vector GetAllWorkerInfo() const; + + /** Persist a single worker's timeline to disk. */ + void Save(std::string_view WorkerId); + + /** Persist all in-memory timelines to disk. */ + void SaveAll(); + +private: + [[nodiscard]] std::filesystem::path TimelinePath(std::string_view WorkerId) const; + + std::filesystem::path m_PersistenceDir; + mutable RwLock m_Lock; + std::unordered_map> m_Timelines; +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencompute/xmake.lua b/src/zencompute/xmake.lua index 50877508c..ed0af66a5 100644 --- a/src/zencompute/xmake.lua +++ b/src/zencompute/xmake.lua @@ -6,4 +6,14 @@ target('zencompute') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) + add_includedirs(".", {private=true}) add_deps("zencore", "zenstore", "zenutil", "zennet", "zenhttp") + add_packages("json11") + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end + + if is_plat("windows") then + add_syslinks("Userenv") + end diff --git a/src/zencompute/zencompute.cpp b/src/zencompute/zencompute.cpp index 633250f4e..1f3f6d3f9 100644 --- a/src/zencompute/zencompute.cpp +++ b/src/zencompute/zencompute.cpp @@ -2,11 +2,20 @@ #include "zencompute/zencompute.h" +#if ZEN_WITH_TESTS +# include "runners/deferreddeleter.h" +# include +#endif + namespace zen { void zencompute_forcelinktests() { +#if ZEN_WITH_TESTS + compute::cloudmetadata_forcelink(); + compute::deferreddeleter_forcelink(); +#endif } } // namespace zen diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index bf3c15d3d..fecbe2dbe 100644 --- a/src/zencore/include/zencore/system.h +++ 
b/src/zencore/include/zencore/system.h @@ -4,6 +4,8 @@ #include +#include +#include #include namespace zen { @@ -12,6 +14,7 @@ class CbWriter; std::string GetMachineName(); std::string_view GetOperatingSystemName(); +std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos" std::string_view GetCpuName(); struct SystemMetrics @@ -25,7 +28,13 @@ struct SystemMetrics uint64_t AvailVirtualMemoryMiB = 0; uint64_t PageFileMiB = 0; uint64_t AvailPageFileMiB = 0; - float CpuUsagePercent = 0.0f; +}; + +/// Extended metrics that include CPU usage percentage, which requires +/// stateful delta tracking via SystemMetricsTracker. +struct ExtendedSystemMetrics : SystemMetrics +{ + float CpuUsagePercent = 0.0f; }; SystemMetrics GetSystemMetrics(); @@ -33,6 +42,31 @@ SystemMetrics GetSystemMetrics(); void SetCpuCountForReporting(int FakeCpuCount); SystemMetrics GetSystemMetricsForReporting(); +ExtendedSystemMetrics ApplyReportingOverrides(ExtendedSystemMetrics Metrics); + void Describe(const SystemMetrics& Metrics, CbWriter& Writer); +void Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer); + +/// Stateful tracker that computes CPU usage as a delta between consecutive +/// Query() calls. The first call returns CpuUsagePercent = 0 (no previous +/// sample). Thread-safe: concurrent calls are serialised internally. +/// CPU sampling is rate-limited to MinInterval (default 1 s); calls that +/// arrive sooner return the previously cached value. +class SystemMetricsTracker +{ +public: + explicit SystemMetricsTracker(std::chrono::milliseconds MinInterval = std::chrono::seconds(1)); + ~SystemMetricsTracker(); + + SystemMetricsTracker(const SystemMetricsTracker&) = delete; + SystemMetricsTracker& operator=(const SystemMetricsTracker&) = delete; + + /// Collect current metrics. CPU usage is computed as delta since last Query(). 
+ ExtendedSystemMetrics Query(); + +private: + struct Impl; + std::unique_ptr m_Impl; +}; } // namespace zen diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index 267c87e12..833d3c04b 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -7,6 +7,8 @@ #include #include +#include + #if ZEN_PLATFORM_WINDOWS # include @@ -133,33 +135,6 @@ GetSystemMetrics() Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; } - // Query CPU usage using PDH - // - // TODO: This should be changed to not require a Sleep, perhaps by using some - // background metrics gathering mechanism. - - { - PDH_HQUERY QueryHandle = nullptr; - PDH_HCOUNTER CounterHandle = nullptr; - - if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS) - { - if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) == ERROR_SUCCESS) - { - PdhCollectQueryData(QueryHandle); - Sleep(100); - PdhCollectQueryData(QueryHandle); - - PDH_FMT_COUNTERVALUE CounterValue; - if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS) - { - Metrics.CpuUsagePercent = static_cast(CounterValue.doubleValue); - } - } - PdhCloseQuery(QueryHandle); - } - } - return Metrics; } #elif ZEN_PLATFORM_LINUX @@ -235,39 +210,6 @@ GetSystemMetrics() } } - // Query CPU usage - Metrics.CpuUsagePercent = 0.0f; - if (FILE* Stat = fopen("/proc/stat", "r")) - { - char Line[256]; - unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq; - static unsigned long PrevUser = 0, PrevNice = 0, PrevSystem = 0, PrevIdle = 0, PrevIoWait = 0, PrevIrq = 0, PrevSoftIrq = 0; - - if (fgets(Line, sizeof(Line), Stat)) - { - if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7) - { - unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) - - (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq); - unsigned long 
IdleDelta = Idle - PrevIdle; - - if (TotalDelta > 0) - { - Metrics.CpuUsagePercent = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; - } - - PrevUser = User; - PrevNice = Nice; - PrevSystem = System; - PrevIdle = Idle; - PrevIoWait = IoWait; - PrevIrq = Irq; - PrevSoftIrq = SoftIrq; - } - } - fclose(Stat); - } - // Get memory information long Pages = sysconf(_SC_PHYS_PAGES); long PageSize = sysconf(_SC_PAGE_SIZE); @@ -348,25 +290,6 @@ GetSystemMetrics() sysctlbyname("hw.packages", &Packages, &Size, nullptr, 0); Metrics.CpuCount = Packages > 0 ? Packages : 1; - // Query CPU usage using host_statistics64 - Metrics.CpuUsagePercent = 0.0f; - host_cpu_load_info_data_t CpuLoad; - mach_msg_type_number_t CpuCount = sizeof(CpuLoad) / sizeof(natural_t); - if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &CpuCount) == KERN_SUCCESS) - { - unsigned long TotalTicks = 0; - for (int i = 0; i < CPU_STATE_MAX; ++i) - { - TotalTicks += CpuLoad.cpu_ticks[i]; - } - - if (TotalTicks > 0) - { - unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE]; - Metrics.CpuUsagePercent = 100.0f * (TotalTicks - IdleTicks) / TotalTicks; - } - } - // Get memory information uint64_t MemSize = 0; Size = sizeof(MemSize); @@ -401,6 +324,17 @@ GetSystemMetrics() # error "Unknown platform" #endif +ExtendedSystemMetrics +ApplyReportingOverrides(ExtendedSystemMetrics Metrics) +{ + if (g_FakeCpuCount) + { + Metrics.CoreCount = g_FakeCpuCount; + Metrics.LogicalProcessorCount = g_FakeCpuCount; + } + return Metrics; +} + SystemMetrics GetSystemMetricsForReporting() { @@ -415,12 +349,249 @@ GetSystemMetricsForReporting() return Sm; } +/////////////////////////////////////////////////////////////////////////// +// SystemMetricsTracker +/////////////////////////////////////////////////////////////////////////// + +// Per-platform CPU sampling helper. Called with m_Mutex held. + +#if ZEN_PLATFORM_WINDOWS || ZEN_PLATFORM_LINUX + +// Samples CPU usage by reading /proc/stat. 
Used natively on Linux and as a +// Wine fallback on Windows (where /proc/stat is accessible via the Z: drive). +struct ProcStatCpuSampler +{ + const char* Path = "/proc/stat"; + unsigned long PrevUser = 0; + unsigned long PrevNice = 0; + unsigned long PrevSystem = 0; + unsigned long PrevIdle = 0; + unsigned long PrevIoWait = 0; + unsigned long PrevIrq = 0; + unsigned long PrevSoftIrq = 0; + + explicit ProcStatCpuSampler(const char* InPath = "/proc/stat") : Path(InPath) {} + + float Sample() + { + float CpuUsage = 0.0f; + + if (FILE* Stat = fopen(Path, "r")) + { + char Line[256]; + unsigned long User, Nice, System, Idle, IoWait, Irq, SoftIrq; + + if (fgets(Line, sizeof(Line), Stat)) + { + if (sscanf(Line, "cpu %lu %lu %lu %lu %lu %lu %lu", &User, &Nice, &System, &Idle, &IoWait, &Irq, &SoftIrq) == 7) + { + unsigned long TotalDelta = (User + Nice + System + Idle + IoWait + Irq + SoftIrq) - + (PrevUser + PrevNice + PrevSystem + PrevIdle + PrevIoWait + PrevIrq + PrevSoftIrq); + unsigned long IdleDelta = Idle - PrevIdle; + + if (TotalDelta > 0) + { + CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; + } + + PrevUser = User; + PrevNice = Nice; + PrevSystem = System; + PrevIdle = Idle; + PrevIoWait = IoWait; + PrevIrq = Irq; + PrevSoftIrq = SoftIrq; + } + } + fclose(Stat); + } + + return CpuUsage; + } +}; + +#endif + +#if ZEN_PLATFORM_WINDOWS + +struct CpuSampler +{ + PDH_HQUERY QueryHandle = nullptr; + PDH_HCOUNTER CounterHandle = nullptr; + bool HasPreviousSample = false; + bool IsWine = false; + ProcStatCpuSampler ProcStat{"Z:\\proc\\stat"}; + + CpuSampler() + { + IsWine = zen::windows::IsRunningOnWine(); + + if (!IsWine) + { + if (PdhOpenQueryW(nullptr, 0, &QueryHandle) == ERROR_SUCCESS) + { + if (PdhAddEnglishCounterW(QueryHandle, L"\\Processor(_Total)\\% Processor Time", 0, &CounterHandle) != ERROR_SUCCESS) + { + CounterHandle = nullptr; + } + } + } + } + + ~CpuSampler() + { + if (QueryHandle) + { + PdhCloseQuery(QueryHandle); + } + } + + float Sample() + { 
+ if (IsWine) + { + return ProcStat.Sample(); + } + + if (!QueryHandle || !CounterHandle) + { + return 0.0f; + } + + PdhCollectQueryData(QueryHandle); + + if (!HasPreviousSample) + { + HasPreviousSample = true; + return 0.0f; + } + + PDH_FMT_COUNTERVALUE CounterValue; + if (PdhGetFormattedCounterValue(CounterHandle, PDH_FMT_DOUBLE, nullptr, &CounterValue) == ERROR_SUCCESS) + { + return static_cast(CounterValue.doubleValue); + } + + return 0.0f; + } +}; + +#elif ZEN_PLATFORM_LINUX + +struct CpuSampler +{ + ProcStatCpuSampler ProcStat; + + float Sample() { return ProcStat.Sample(); } +}; + +#elif ZEN_PLATFORM_MAC + +struct CpuSampler +{ + unsigned long PrevTotalTicks = 0; + unsigned long PrevIdleTicks = 0; + + float Sample() + { + float CpuUsage = 0.0f; + + host_cpu_load_info_data_t CpuLoad; + mach_msg_type_number_t Count = sizeof(CpuLoad) / sizeof(natural_t); + if (host_statistics(mach_host_self(), HOST_CPU_LOAD_INFO, (host_info_t)&CpuLoad, &Count) == KERN_SUCCESS) + { + unsigned long TotalTicks = 0; + for (int i = 0; i < CPU_STATE_MAX; ++i) + { + TotalTicks += CpuLoad.cpu_ticks[i]; + } + unsigned long IdleTicks = CpuLoad.cpu_ticks[CPU_STATE_IDLE]; + + unsigned long TotalDelta = TotalTicks - PrevTotalTicks; + unsigned long IdleDelta = IdleTicks - PrevIdleTicks; + + if (TotalDelta > 0 && PrevTotalTicks > 0) + { + CpuUsage = 100.0f * (TotalDelta - IdleDelta) / TotalDelta; + } + + PrevTotalTicks = TotalTicks; + PrevIdleTicks = IdleTicks; + } + + return CpuUsage; + } +}; + +#endif + +struct SystemMetricsTracker::Impl +{ + using Clock = std::chrono::steady_clock; + + std::mutex Mutex; + CpuSampler Sampler; + float CachedCpuPercent = 0.0f; + Clock::time_point NextSampleTime = Clock::now(); + std::chrono::milliseconds MinInterval; + + explicit Impl(std::chrono::milliseconds InMinInterval) : MinInterval(InMinInterval) {} + + float SampleCpu() + { + const auto Now = Clock::now(); + if (Now >= NextSampleTime) + { + CachedCpuPercent = Sampler.Sample(); + NextSampleTime = Now + 
MinInterval; + } + return CachedCpuPercent; + } +}; + +SystemMetricsTracker::SystemMetricsTracker(std::chrono::milliseconds MinInterval) : m_Impl(std::make_unique(MinInterval)) +{ +} + +SystemMetricsTracker::~SystemMetricsTracker() = default; + +ExtendedSystemMetrics +SystemMetricsTracker::Query() +{ + ExtendedSystemMetrics Metrics; + static_cast(Metrics) = GetSystemMetrics(); + + std::lock_guard Lock(m_Impl->Mutex); + Metrics.CpuUsagePercent = m_Impl->SampleCpu(); + return Metrics; +} + +/////////////////////////////////////////////////////////////////////////// + std::string_view GetOperatingSystemName() { return ZEN_PLATFORM_NAME; } +std::string_view +GetRuntimePlatformName() +{ +#if ZEN_PLATFORM_WINDOWS + if (zen::windows::IsRunningOnWine()) + { + return "wine"sv; + } + return "windows"sv; +#elif ZEN_PLATFORM_LINUX + return "linux"sv; +#elif ZEN_PLATFORM_MAC + return "macos"sv; +#else + return "unknown"sv; +#endif +} + std::string_view GetCpuName() { @@ -440,4 +611,11 @@ Describe(const SystemMetrics& Metrics, CbWriter& Writer) << "avail_pagefile_mb" << Metrics.AvailPageFileMiB; } +void +Describe(const ExtendedSystemMetrics& Metrics, CbWriter& Writer) +{ + Describe(static_cast(Metrics), Writer); + Writer << "cpu_usage_percent" << Metrics.CpuUsagePercent; +} + } // namespace zen diff --git a/src/zenhorde/hordeagent.cpp b/src/zenhorde/hordeagent.cpp new file mode 100644 index 000000000..819b2d0cb --- /dev/null +++ b/src/zenhorde/hordeagent.cpp @@ -0,0 +1,297 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "hordeagent.h" +#include "hordetransportaes.h" + +#include +#include +#include +#include + +#include +#include + +namespace zen::horde { + +HordeAgent::HordeAgent(const MachineInfo& Info) : m_Log(zen::logging::Get("horde.agent")), m_MachineInfo(Info) +{ + ZEN_TRACE_CPU("HordeAgent::Connect"); + + auto Transport = std::make_unique(Info); + if (!Transport->IsValid()) + { + ZEN_WARN("failed to create TCP transport to '{}:{}'", Info.GetConnectionAddress(), Info.GetConnectionPort()); + return; + } + + // The 64-byte nonce is always sent unencrypted as the first thing on the wire. + // The Horde agent uses this to identify which lease this connection belongs to. + Transport->Send(Info.Nonce, sizeof(Info.Nonce)); + + std::unique_ptr FinalTransport = std::move(Transport); + if (Info.EncryptionMode == Encryption::AES) + { + FinalTransport = std::make_unique(Info.Key, std::move(FinalTransport)); + if (!FinalTransport->IsValid()) + { + ZEN_WARN("failed to create AES transport"); + return; + } + } + + // Create multiplexed socket and channels + m_Socket = std::make_unique(std::move(FinalTransport)); + + // Channel 0 is the agent control channel (handles Attach/Fork handshake). + // Channel 100 is the child I/O channel (handles file upload and remote execution). 
+ Ref AgentComputeChannel = m_Socket->CreateChannel(0); + Ref ChildComputeChannel = m_Socket->CreateChannel(100); + + if (!AgentComputeChannel || !ChildComputeChannel) + { + ZEN_WARN("failed to create compute channels"); + return; + } + + m_AgentChannel = std::make_unique(std::move(AgentComputeChannel)); + m_ChildChannel = std::make_unique(std::move(ChildComputeChannel)); + + m_IsValid = true; +} + +HordeAgent::~HordeAgent() +{ + CloseConnection(); +} + +bool +HordeAgent::BeginCommunication() +{ + ZEN_TRACE_CPU("HordeAgent::BeginCommunication"); + + if (!m_IsValid) + { + return false; + } + + // Start the send/recv pump threads + m_Socket->StartCommunication(); + + // Wait for Attach on agent channel + AgentMessageType Type = m_AgentChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on agent channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on agent channel, got 0x{:02x}", static_cast(Type)); + return false; + } + + // Fork tells the remote agent to create child channel 100 with a 4MB buffer. + // After this, the agent will send an Attach on the child channel. 
+ m_AgentChannel->Fork(100, 4 * 1024 * 1024); + + // Wait for Attach on child channel + Type = m_ChildChannel->ReadResponse(5000); + if (Type == AgentMessageType::None) + { + ZEN_WARN("timed out waiting for Attach on child channel"); + return false; + } + if (Type != AgentMessageType::Attach) + { + ZEN_WARN("expected Attach on child channel, got 0x{:02x}", static_cast(Type)); + return false; + } + + return true; +} + +bool +HordeAgent::UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator) +{ + ZEN_TRACE_CPU("HordeAgent::UploadBinaries"); + + m_ChildChannel->UploadFiles("", BundleLocator.c_str()); + + std::unordered_map> BlobFiles; + + auto FindOrOpenBlob = [&](std::string_view Locator) -> BasicFile* { + std::string Key(Locator); + + if (auto It = BlobFiles.find(Key); It != BlobFiles.end()) + { + return It->second.get(); + } + + const std::filesystem::path Path = BundleDir / (Key + ".blob"); + std::error_code Ec; + auto File = std::make_unique(); + File->Open(Path, BasicFile::Mode::kRead, Ec); + + if (Ec) + { + ZEN_ERROR("cannot read blob file: '{}'", Path); + return nullptr; + } + + BasicFile* Ptr = File.get(); + BlobFiles.emplace(std::move(Key), std::move(File)); + return Ptr; + }; + + // The upload protocol is request-driven: we send WriteFiles, then the remote agent + // sends ReadBlob requests for each blob it needs. We respond with Blob data until + // the agent sends WriteFilesResponse indicating the upload is complete. 
+ constexpr int32_t ReadResponseTimeoutMs = 1000; + + for (;;) + { + bool TimedOut = false; + + if (AgentMessageType Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs, &TimedOut); Type != AgentMessageType::ReadBlob) + { + if (TimedOut) + { + continue; + } + // End of stream - check if it was a successful upload + if (Type == AgentMessageType::WriteFilesResponse) + { + return true; + } + else if (Type == AgentMessageType::Exception) + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("upload exception: {} - {}", Ex.Message, Ex.Description); + } + else + { + ZEN_ERROR("unexpected message type 0x{:02x} during upload", static_cast(Type)); + } + return false; + } + + BlobRequest Req; + m_ChildChannel->ReadBlobRequest(Req); + + BasicFile* File = FindOrOpenBlob(Req.Locator); + if (!File) + { + return false; + } + + // Read from offset to end of file + const uint64_t TotalSize = File->FileSize(); + const uint64_t Offset = static_cast(Req.Offset); + if (Offset >= TotalSize) + { + ZEN_ERROR("upload got request for data beyond end of file: offset={}, length={}, total_size={}", Offset, Req.Length, TotalSize); + m_ChildChannel->Blob(nullptr, 0); + continue; + } + + const IoBuffer Data = File->ReadRange(Offset, Min(Req.Length, TotalSize - Offset)); + m_ChildChannel->Blob(static_cast(Data.GetData()), Data.GetSize()); + } +} + +void +HordeAgent::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + bool UseWine) +{ + ZEN_TRACE_CPU("HordeAgent::Execute"); + m_ChildChannel + ->Execute(Exe, Args, NumArgs, WorkingDir, EnvVars, NumEnvVars, UseWine ? 
ExecuteProcessFlags::UseWine : ExecuteProcessFlags::None); +} + +bool +HordeAgent::Poll(bool LogOutput) +{ + constexpr int32_t ReadResponseTimeoutMs = 100; + AgentMessageType Type; + + while ((Type = m_ChildChannel->ReadResponse(ReadResponseTimeoutMs)) != AgentMessageType::None) + { + switch (Type) + { + case AgentMessageType::ExecuteOutput: + { + if (LogOutput && m_ChildChannel->GetResponseSize() > 0) + { + const char* ResponseData = static_cast(m_ChildChannel->GetResponseData()); + size_t ResponseSize = m_ChildChannel->GetResponseSize(); + + // Trim trailing newlines + while (ResponseSize > 0 && (ResponseData[ResponseSize - 1] == '\n' || ResponseData[ResponseSize - 1] == '\r')) + { + --ResponseSize; + } + + if (ResponseSize > 0) + { + const std::string_view Output(ResponseData, ResponseSize); + ZEN_INFO("[remote] {}", Output); + } + } + break; + } + + case AgentMessageType::ExecuteResult: + { + if (m_ChildChannel->GetResponseSize() == sizeof(int32_t)) + { + int32_t ExitCode; + memcpy(&ExitCode, m_ChildChannel->GetResponseData(), sizeof(int32_t)); + ZEN_INFO("remote process exited with code {}", ExitCode); + } + m_IsValid = false; + return false; + } + + case AgentMessageType::Exception: + { + ExceptionInfo Ex; + m_ChildChannel->ReadException(Ex); + ZEN_ERROR("exception: {} - {}", Ex.Message, Ex.Description); + m_HasErrors = true; + break; + } + + default: + break; + } + } + + return m_IsValid && !m_HasErrors; +} + +void +HordeAgent::CloseConnection() +{ + if (m_ChildChannel) + { + m_ChildChannel->Close(); + } + if (m_AgentChannel) + { + m_AgentChannel->Close(); + } +} + +bool +HordeAgent::IsValid() const +{ + return m_IsValid && !m_HasErrors; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagent.h b/src/zenhorde/hordeagent.h new file mode 100644 index 000000000..e0ae89ead --- /dev/null +++ b/src/zenhorde/hordeagent.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include "hordeagentmessage.h" +#include "hordecomputesocket.h" + +#include + +#include + +#include +#include +#include + +namespace zen::horde { + +/** Manages the lifecycle of a single Horde compute agent. + * + * Handles the full connection sequence for one provisioned machine: + * 1. Connect via TCP transport (with optional AES encryption wrapping) + * 2. Create a multiplexed ComputeSocket with agent (channel 0) and child (channel 100) + * 3. Perform the Attach/Fork handshake to establish the child channel + * 4. Upload zenserver binary via the WriteFiles/ReadBlob protocol + * 5. Execute zenserver remotely via ExecuteV2 + * 6. Poll for ExecuteOutput (stdout) and ExecuteResult (exit code) + */ +class HordeAgent +{ +public: + explicit HordeAgent(const MachineInfo& Info); + ~HordeAgent(); + + HordeAgent(const HordeAgent&) = delete; + HordeAgent& operator=(const HordeAgent&) = delete; + + /** Perform the channel setup handshake (Attach on agent channel, Fork, Attach on child channel). + * Returns false if the handshake times out or receives an unexpected message. */ + bool BeginCommunication(); + + /** Upload binary files to the remote agent. + * @param BundleDir Directory containing .blob files. + * @param BundleLocator Locator string identifying the bundle (from CreateBundle). */ + bool UploadBinaries(const std::filesystem::path& BundleDir, const std::string& BundleLocator); + + /** Execute a command on the remote machine. */ + void Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir = nullptr, + const char* const* EnvVars = nullptr, + size_t NumEnvVars = 0, + bool UseWine = false); + + /** Poll for output and results. Returns true if the agent is still running. + * When LogOutput is true, remote stdout is logged via ZEN_INFO. 
*/ + bool Poll(bool LogOutput = true); + + void CloseConnection(); + bool IsValid() const; + + const MachineInfo& GetMachineInfo() const { return m_MachineInfo; } + +private: + LoggerRef Log() { return m_Log; } + + std::unique_ptr m_Socket; + std::unique_ptr m_AgentChannel; ///< Channel 0: agent control + std::unique_ptr m_ChildChannel; ///< Channel 100: child I/O + + LoggerRef m_Log; + bool m_IsValid = false; + bool m_HasErrors = false; + MachineInfo m_MachineInfo; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp new file mode 100644 index 000000000..998134a96 --- /dev/null +++ b/src/zenhorde/hordeagentmessage.cpp @@ -0,0 +1,340 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordeagentmessage.h" + +#include + +#include +#include + +namespace zen::horde { + +AgentMessageChannel::AgentMessageChannel(Ref Channel) : m_Channel(std::move(Channel)) +{ +} + +AgentMessageChannel::~AgentMessageChannel() = default; + +void +AgentMessageChannel::Close() +{ + CreateMessage(AgentMessageType::None, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Ping() +{ + CreateMessage(AgentMessageType::Ping, 0); + FlushMessage(); +} + +void +AgentMessageChannel::Fork(int ChannelId, int BufferSize) +{ + CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int)); + WriteInt32(ChannelId); + WriteInt32(BufferSize); + FlushMessage(); +} + +void +AgentMessageChannel::Attach() +{ + CreateMessage(AgentMessageType::Attach, 0); + FlushMessage(); +} + +void +AgentMessageChannel::UploadFiles(const char* Path, const char* Locator) +{ + CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20); + WriteString(Path); + WriteString(Locator); + FlushMessage(); +} + +void +AgentMessageChannel::Execute(const char* Exe, + const char* const* Args, + size_t NumArgs, + const char* WorkingDir, + const char* const* EnvVars, + size_t NumEnvVars, + ExecuteProcessFlags Flags) +{ + size_t 
RequiredSize = 50 + strlen(Exe); + for (size_t i = 0; i < NumArgs; ++i) + { + RequiredSize += strlen(Args[i]) + 10; + } + if (WorkingDir) + { + RequiredSize += strlen(WorkingDir) + 10; + } + for (size_t i = 0; i < NumEnvVars; ++i) + { + RequiredSize += strlen(EnvVars[i]) + 20; + } + + CreateMessage(AgentMessageType::ExecuteV2, RequiredSize); + WriteString(Exe); + + WriteUnsignedVarInt(NumArgs); + for (size_t i = 0; i < NumArgs; ++i) + { + WriteString(Args[i]); + } + + WriteOptionalString(WorkingDir); + + // ExecuteV2 protocol requires env vars as separate key/value pairs. + // Callers pass "KEY=VALUE" strings; we split on the first '=' here. + WriteUnsignedVarInt(NumEnvVars); + for (size_t i = 0; i < NumEnvVars; ++i) + { + const char* Eq = strchr(EnvVars[i], '='); + assert(Eq != nullptr); + + WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i])); + if (*(Eq + 1) == '\0') + { + WriteOptionalString(nullptr); + } + else + { + WriteOptionalString(Eq + 1); + } + } + + WriteInt32(static_cast(Flags)); + FlushMessage(); +} + +void +AgentMessageChannel::Blob(const uint8_t* Data, size_t Length) +{ + // Blob responses are chunked to fit within the compute buffer's chunk size. + // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields). + const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength; + for (size_t ChunkOffset = 0; ChunkOffset < Length;) + { + const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize); + + CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128); + WriteInt32(static_cast(ChunkOffset)); + WriteInt32(static_cast(Length)); + WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength); + FlushMessage(); + + ChunkOffset += ChunkLength; + } +} + +AgentMessageType +AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut) +{ + // Deferred advance: the previous response's buffer is only released when the next + // ReadResponse is called. 
This allows callers to read response data between calls + // without copying, since the pointer comes directly from the ring buffer. + if (m_ResponseData) + { + m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength); + m_ResponseData = nullptr; + m_ResponseLength = 0; + } + + const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + uint32_t Length; + memcpy(&Length, Header + 1, sizeof(uint32_t)); + + Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut); + if (!Header) + { + return AgentMessageType::None; + } + + m_ResponseType = static_cast(Header[0]); + m_ResponseData = Header + MessageHeaderLength; + m_ResponseLength = Length; + + return m_ResponseType; +} + +void +AgentMessageChannel::ReadException(ExceptionInfo& Ex) +{ + assert(m_ResponseType == AgentMessageType::Exception); + const uint8_t* Pos = m_ResponseData; + Ex.Message = ReadString(&Pos); + Ex.Description = ReadString(&Pos); +} + +int +AgentMessageChannel::ReadExecuteResult() +{ + assert(m_ResponseType == AgentMessageType::ExecuteResult); + const uint8_t* Pos = m_ResponseData; + return ReadInt32(&Pos); +} + +void +AgentMessageChannel::ReadBlobRequest(BlobRequest& Req) +{ + assert(m_ResponseType == AgentMessageType::ReadBlob); + const uint8_t* Pos = m_ResponseData; + Req.Locator = ReadString(&Pos); + Req.Offset = ReadUnsignedVarInt(&Pos); + Req.Length = ReadUnsignedVarInt(&Pos); +} + +void +AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength) +{ + m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength); + m_RequestData[0] = static_cast(Type); + m_MaxRequestSize = MaxLength; + m_RequestSize = 0; +} + +void +AgentMessageChannel::FlushMessage() +{ + const uint32_t Size = static_cast(m_RequestSize); + memcpy(&m_RequestData[1], &Size, sizeof(uint32_t)); + 
m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
+    m_RequestSize    = 0;
+    m_MaxRequestSize = 0;
+    m_RequestData    = nullptr;
+}
+
+// Appends the bytes of Value to the pending request exactly as they sit in memory.
+void
+AgentMessageChannel::WriteInt32(int Value)
+{
+    WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+}
+
+// Reads sizeof(int) bytes at *Pos into an int and advances *Pos past them.
+int
+AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+{
+    int Value;
+    memcpy(&Value, *Pos, sizeof(int));
+    *Pos += sizeof(int);
+    return Value;
+}
+
+// Copies Length bytes behind the message header at the current request offset.
+// Capacity is only checked by the assert; the caller must have sized the message.
+void
+AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+{
+    assert(m_RequestSize + Length <= m_MaxRequestSize);
+    memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
+    m_RequestSize += Length;
+}
+
+// Returns the current read position and advances *Pos by Length. No bounds check is performed.
+const uint8_t*
+AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+{
+    const uint8_t* Data = *Pos;
+    *Pos += Length;
+    return Data;
+}
+
+// Number of bytes WriteUnsignedVarInt emits for Value: one byte per started group of
+// 7 significant bits, capped at 9 (a 0xFF prefix byte followed by 8 value bytes).
+size_t
+AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+    if (Value == 0)
+    {
+        return 1;
+    }
+
+    // Without the cap a value with bit 63 set would measure as 10 bytes, which makes
+    // WriteUnsignedVarInt shift by (9 - 10) and produce an encoding the reader cannot parse.
+    const size_t ExtraBytes = static_cast<size_t>(FloorLog2_64(static_cast<uint64_t>(Value)) / 7);
+    return (ExtraBytes < 8 ? ExtraBytes : 8) + 1;
+}
+
+// Appends Value as a variable-length integer: the count of leading 1 bits in the first
+// byte is the number of bytes that follow, and the value is stored big-endian.
+void
+AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
+{
+    const size_t ByteCount = MeasureUnsignedVarInt(Value);
+    assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+
+    uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
+    // Fill the trailing bytes from least significant to most significant.
+    for (size_t i = 1; i < ByteCount; ++i)
+    {
+        Output[ByteCount - i] = static_cast<uint8_t>(Value);
+        Value >>= 8;
+    }
+    // First byte: (ByteCount - 1) leading 1 bits, then the remaining high bits of the value.
+    Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+
+    m_RequestSize += ByteCount;
+}
+
+// Decodes a variable-length integer at *Pos and advances *Pos past it.
+size_t
+AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+{
+    const uint8_t* Data      = *Pos;
+    const uint8_t  FirstByte = Data[0];
+    // Leading 1 bits of the first byte plus one. The "- 24" presumes CountLeadingZeros
+    // works on a 32-bit value - TODO confirm against its declaration.
+    const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<uint32_t>(FirstByte))) + 1 - 24;
+
+    size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+    for (size_t i = 1; i < NumBytes; ++i)
+    {
+        Value <<= 8;
+        Value |= Data[i];
+    }
+
+    *Pos += NumBytes;
+    return Value;
+}
+
+// Encoded size of Text as written by WriteString: varint length prefix plus the characters.
+size_t
+AgentMessageChannel::MeasureString(const char* Text) const
+{
+    const size_t Length = strlen(Text);
+    return MeasureUnsignedVarInt(Length) + Length;
+}
+
+// Writes a NUL-terminated string as [varint length][bytes]; the terminator is not written.
+void
+AgentMessageChannel::WriteString(const char* Text)
+{
+    const size_t Length = strlen(Text);
+    WriteUnsignedVarInt(Length);
+    WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+}
+
+// Writes a string view as [varint length][bytes].
+void
+AgentMessageChannel::WriteString(std::string_view Text)
+{
+    WriteUnsignedVarInt(Text.size());
+    WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+}
+
+// Reads a [varint length][bytes] string. The returned view points into the buffer
+// that *Pos refers to, so it is only valid while that buffer is.
+std::string_view
+AgentMessageChannel::ReadString(const uint8_t** Pos)
+{
+    const size_t Length = ReadUnsignedVarInt(Pos);
+    const char*  Start  = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
+    return std::string_view(Start, Length);
+}
+
+void
+AgentMessageChannel::WriteOptionalString(const char* Text)
+{
+    // Optional strings use length+1 encoding: 0 means null/absent,
+    // N>0 means a string of length N-1 follows. This matches the UE
+    // FAgentMessageChannel serialization convention.
+    if (!Text)
+    {
+        WriteUnsignedVarInt(0);
+    }
+    else
+    {
+        const size_t Length = strlen(Text);
+        WriteUnsignedVarInt(Length + 1);
+        WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+    }
+}
+
+}  // namespace zen::horde
diff --git a/src/zenhorde/hordeagentmessage.h b/src/zenhorde/hordeagentmessage.h
new file mode 100644
index 000000000..38c4375fd
--- /dev/null
+++ b/src/zenhorde/hordeagentmessage.h
@@ -0,0 +1,161 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include
+
+#include "hordecomputechannel.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace zen::horde {
+
+/** Agent message types matching the UE EAgentMessageType byte values.
+ * These are the message opcodes exchanged over the agent/child channels.
+ */
+enum class AgentMessageType : uint8_t
+{
+    None               = 0x00,
+    Ping               = 0x01,
+    Exception          = 0x02,
+    Fork               = 0x03,
+    Attach             = 0x04,
+    WriteFiles         = 0x10,
+    WriteFilesResponse = 0x11,
+    DeleteFiles        = 0x12,
+    ExecuteV2          = 0x22,
+    ExecuteOutput      = 0x17,
+    ExecuteResult      = 0x18,
+    ReadBlob           = 0x20,
+    ReadBlobResponse   = 0x21,
+};
+
+/** Flags for the ExecuteV2 message. */
+enum class ExecuteProcessFlags : uint8_t
+{
+    None    = 0,
+    UseWine = 1,  ///< Run the executable under Wine on Linux agents
+};
+
+/** Parsed exception information from an Exception message. */
+struct ExceptionInfo
+{
+    std::string_view Message;
+    std::string_view Description;
+};
+
+/** Parsed blob read request from a ReadBlob message. */
+struct BlobRequest
+{
+    std::string_view Locator;
+    size_t           Offset = 0;
+    size_t           Length = 0;
+};
+
+/** Channel for sending and receiving agent messages over a ComputeChannel.
+ *
+ * Implements the Horde agent message protocol, matching the UE
+ * FAgentMessageChannel serialization format exactly. Messages are framed as
+ * [type (1B)][payload length (4B)][payload]. Strings use length-prefixed UTF-8;
+ * integers use variable-length encoding.
+ *
+ * The protocol has two directions:
+ * - Requests (initiator -> remote): Close, Ping, Fork, Attach, UploadFiles, Execute, Blob
+ * - Responses (remote -> initiator): ReadResponse returns the type, then call the
+ *   appropriate Read* method to parse the payload.
+ */
+class AgentMessageChannel
+{
+public:
+    explicit AgentMessageChannel(Ref<ComputeChannel> Channel);
+    ~AgentMessageChannel();
+
+    AgentMessageChannel(const AgentMessageChannel&) = delete;
+    AgentMessageChannel& operator=(const AgentMessageChannel&) = delete;
+
+    // --- Requests (Initiator -> Remote) ---
+
+    /** Close the channel. */
+    void Close();
+
+    /** Send a keepalive ping. */
+    void Ping();
+
+    /** Fork communication to a new channel with the given ID and buffer size. */
+    void Fork(int ChannelId, int BufferSize);
+
+    /** Send an attach request (used during channel setup handshake). */
+    void Attach();
+
+    /** Request the remote agent to write files from the given bundle locator. */
+    void UploadFiles(const char* Path, const char* Locator);
+
+    /** Execute a process on the remote machine. */
+    void Execute(const char*         Exe,
+                 const char* const*  Args,
+                 size_t              NumArgs,
+                 const char*         WorkingDir,
+                 const char* const*  EnvVars,
+                 size_t              NumEnvVars,
+                 ExecuteProcessFlags Flags = ExecuteProcessFlags::None);
+
+    /** Send blob data in response to a ReadBlob request. */
+    void Blob(const uint8_t* Data, size_t Length);
+
+    // --- Responses (Remote -> Initiator) ---
+
+    /** Read the next response message. Returns the message type, or None on timeout.
+     * After this returns, use GetResponseData()/GetResponseSize() or the typed
+     * Read* methods to access the payload. */
+    AgentMessageType ReadResponse(int32_t TimeoutMs = -1, bool* OutTimedOut = nullptr);
+
+    const void* GetResponseData() const { return m_ResponseData; }
+    size_t      GetResponseSize() const { return m_ResponseLength; }
+
+    /** Parse an Exception response payload. */
+    void ReadException(ExceptionInfo& Ex);
+
+    /** Parse an ExecuteResult response payload. Returns the exit code. */
+    int ReadExecuteResult();
+
+    /** Parse a ReadBlob response payload into a BlobRequest. */
+    void ReadBlobRequest(BlobRequest& Req);
+
+private:
+    static constexpr size_t MessageHeaderLength = 5;  ///< [type(1B)][length(4B)]
+
+    Ref<ComputeChannel> m_Channel;
+
+    uint8_t* m_RequestData    = nullptr;
+    size_t   m_RequestSize    = 0;
+    size_t   m_MaxRequestSize = 0;
+
+    AgentMessageType m_ResponseType   = AgentMessageType::None;
+    const uint8_t*   m_ResponseData   = nullptr;
+    size_t           m_ResponseLength = 0;
+
+    void CreateMessage(AgentMessageType Type, size_t MaxLength);
+    void FlushMessage();
+
+    void       WriteInt32(int Value);
+    static int ReadInt32(const uint8_t** Pos);
+
+    void                  WriteFixedLengthBytes(const uint8_t* Data, size_t Length);
+    static const uint8_t* ReadFixedLengthBytes(const uint8_t** Pos, size_t Length);
+
+    static size_t MeasureUnsignedVarInt(size_t Value);
+    void          WriteUnsignedVarInt(size_t Value);
+    static size_t ReadUnsignedVarInt(const uint8_t** Pos);
+
+    size_t                  MeasureString(const char* Text) const;
+    void                    WriteString(const char* Text);
+    void                    WriteString(std::string_view Text);
+    static std::string_view ReadString(const uint8_t** Pos);
+
+    void WriteOptionalString(const char* Text);
+};
+
+}  // namespace zen::horde
diff --git a/src/zenhorde/hordebundle.cpp b/src/zenhorde/hordebundle.cpp
new file mode 100644
index 000000000..d3974bc28
--- /dev/null
+++ b/src/zenhorde/hordebundle.cpp
@@ -0,0 +1,619 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordebundle.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace zen::horde {
+
+// Logger used by the ZEN_* logging macros in this file.
+static LoggerRef
+Log()
+{
+    static auto s_Logger = zen::logging::Get("horde.bundle");
+    return s_Logger;
+}
+
+static constexpr uint8_t  PacketSignature[3]   = {'U', 'B', 'N'};
+static constexpr uint8_t  PacketVersion        = 5;
+static constexpr int32_t  CurrentPacketBaseIdx = -2;
+static constexpr int      ImportBias           = 3;
+static constexpr uint32_t ChunkSize            = 64 * 1024;   // 64KB fixed chunks
+static constexpr uint32_t LargeFileThreshold   = 128 * 1024;  // 128KB
+
+// BlobType: 20 bytes each = FGuid (16 bytes, 4x uint32 LE) + Version (int32 LE)
+// Values from UE SDK: GUIDs stored as 4 uint32 LE values.
+
+// ChunkLeaf v1: {0xB27AFB68, 0x4A4B9E20, 0x8A78D8A4, 0x39D49840}
+static constexpr uint8_t BlobType_ChunkLeafV1[20] = {0x68, 0xFB, 0x7A, 0xB2, 0x20, 0x9E, 0x4B, 0x4A, 0xA4, 0xD8,
+                                                     0x78, 0x8A, 0x40, 0x98, 0xD4, 0x39, 0x01, 0x00, 0x00, 0x00};  // version 1
+
+// ChunkInterior v2: {0xF4DEDDBC, 0x4C7A70CB, 0x11F04783, 0xB9CDCCAF}
+static constexpr uint8_t BlobType_ChunkInteriorV2[20] = {0xBC, 0xDD, 0xDE, 0xF4, 0xCB, 0x70, 0x7A, 0x4C, 0x83, 0x47,
+                                                         0xF0, 0x11, 0xAF, 0xCC, 0xCD, 0xB9, 0x02, 0x00, 0x00, 0x00};  // version 2
+
+// Directory v1: {0x0714EC11, 0x4D07291A, 0x8AE77F86, 0x799980D6}
+static constexpr uint8_t BlobType_DirectoryV1[20] = {0x11, 0xEC, 0x14, 0x07, 0x1A, 0x29, 0x07, 0x4D, 0x86, 0x7F,
+                                                     0xE7, 0x8A, 0xD6, 0x80, 0x99, 0x79, 0x01, 0x00, 0x00, 0x00};  // version 1
+
+static constexpr size_t BlobTypeSize = 20;
+
+// ─── VarInt helpers (UE format) ─────────────────────────────────────────────
+
+// Number of bytes WriteVarInt emits for Value: one byte per started group of
+// 7 significant bits, capped at 9 (a 0xFF prefix byte followed by 8 value bytes).
+// Counting groups directly keeps the result correct for the full 64-bit range and
+// guarantees WriteVarInt never shifts by a negative amount.
+static size_t
+MeasureVarInt(size_t Value)
+{
+    size_t ByteCount = 1;
+    for (uint64_t Rest = static_cast<uint64_t>(Value) >> 7; Rest != 0 && ByteCount < 9; Rest >>= 7)
+    {
+        ++ByteCount;
+    }
+    return ByteCount;
+}
+
+// Appends Value to Buffer as a variable-length integer: the count of leading 1 bits
+// in the first byte is the number of bytes that follow; the value is stored big-endian.
+static void
+WriteVarInt(std::vector<uint8_t>& Buffer, size_t Value)
+{
+    const size_t ByteCount = MeasureVarInt(Value);
+    const size_t Offset    = Buffer.size();
+    Buffer.resize(Offset + ByteCount);
+
+    uint8_t* Output = Buffer.data() + Offset;
+    // Fill the trailing bytes from least significant to most significant.
+    for (size_t i = 1; i < ByteCount; ++i)
+    {
+        Output[ByteCount - i] = static_cast<uint8_t>(Value);
+        Value >>= 8;
+    }
+    // First byte: (ByteCount - 1) leading 1 bits, then the remaining high bits of the value.
+    Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+}
+
+// ─── Binary helpers ─────────────────────────────────────────────────────────
+
+// Appends Value as 4 little-endian bytes, independent of host byte order.
+static void
+WriteLE32(std::vector<uint8_t>& Buffer, int32_t Value)
+{
+    const uint32_t Bits     = static_cast<uint32_t>(Value);
+    const uint8_t  Bytes[4] = {static_cast<uint8_t>(Bits),
+                               static_cast<uint8_t>(Bits >> 8),
+                               static_cast<uint8_t>(Bits >> 16),
+                               static_cast<uint8_t>(Bits >> 24)};
+    Buffer.insert(Buffer.end(), Bytes, Bytes + 4);
+}
+
+// Appends a single byte.
+static void
+WriteByte(std::vector<uint8_t>& Buffer, uint8_t Value)
+{
+    Buffer.push_back(Value);
+}
+
+// Appends Size raw bytes starting at Data.
+static void
+WriteBytes(std::vector<uint8_t>& Buffer, const void* Data, size_t Size)
+{
+    auto* Ptr = static_cast<const uint8_t*>(Data);
+    Buffer.insert(Buffer.end(), Ptr, Ptr + Size);
+}
+
+// Appends Str as [varint length][bytes].
+static void
+WriteString(std::vector<uint8_t>& Buffer, std::string_view Str)
+{
+    WriteVarInt(Buffer, Str.size());
+    WriteBytes(Buffer, Str.data(), Str.size());
+}
+
+// Pads Buffer with zero bytes until its size is a multiple of 4.
+static void
+AlignTo4(std::vector<uint8_t>& Buffer)
+{
+    while (Buffer.size() % 4 != 0)
+    {
+        Buffer.push_back(0);
+    }
+}
+
+// Overwrites the 4 bytes at Offset with Value in little-endian order.
+// The caller must ensure Offset + 4 <= Buffer.size().
+static void
+PatchLE32(std::vector<uint8_t>& Buffer, size_t Offset, int32_t Value)
+{
+    const uint32_t Bits = static_cast<uint32_t>(Value);
+    for (size_t i = 0; i < 4; ++i)
+    {
+        Buffer[Offset + i] = static_cast<uint8_t>(Bits >> (8 * i));
+    }
+}
+
+// ─── Packet builder ─────────────────────────────────────────────────────────
+
+// Builds a single uncompressed Horde V2 packet. Layout:
+//   [Signature(3) + Version(1) + PacketLength(4)]                        8 bytes (header)
+//   [TypeTableOffset(4) + ImportTableOffset(4) + ExportTableOffset(4)]   12 bytes
+//   [Export data...]
+//   [Type table: count(4) + count * 20 bytes]
+//   [Import table: count(4) + (count+1) offset entries(4 each) + import data]
+//   [Export table: count(4) + (count+1) offset entries(4 each)]
+//
+// ALL offsets are absolute from byte 0 of the full packet (including the 8-byte header).
+// PacketLength in the header = total packet size including the 8-byte header.
+
+// Accumulates exports, types and imports for one packet and serializes them in Finish().
+struct PacketBuilder
+{
+    std::vector<uint8_t> Data;           // Packet bytes written so far
+    std::vector<int32_t> ExportOffsets;  // Absolute byte offset of each export from byte 0
+
+    // Type table: pointers to the unique 20-byte BlobType entries seen so far
+    std::vector<const uint8_t*> Types;
+
+    // One import table row: (baseIdx, fragment)
+    struct ImportEntry
+    {
+        int32_t     BaseIdx;
+        std::string Fragment;
+    };
+    std::vector<ImportEntry> Imports;
+
+    // Absolute offset (from byte 0) where the export currently being written begins
+    size_t CurrentExportStart = 0;
+
+    PacketBuilder()
+    {
+        // 8-byte packet header plus three 4-byte table offsets, zero-filled for now.
+        Data.resize(20, 0);
+
+        // Signature and version are known immediately; PacketLength and the three
+        // table offsets stay zero until Finish() patches them.
+        for (size_t i = 0; i < sizeof(PacketSignature); ++i)
+        {
+            Data[i] = PacketSignature[i];
+        }
+        Data[sizeof(PacketSignature)] = PacketVersion;
+    }
+
+    // Returns the type-table index of BlobType, registering it on first use.
+    int AddType(const uint8_t* BlobType)
+    {
+        int Index = 0;
+        for (const uint8_t* Known : Types)
+        {
+            if (memcmp(Known, BlobType, BlobTypeSize) == 0)
+            {
+                return Index;
+            }
+            ++Index;
+        }
+        // Not seen before: Index now equals the slot the new entry will occupy.
+        Types.push_back(BlobType);
+        return Index;
+    }
+
+    // Appends an import row and returns its index.
+    int AddImport(int32_t BaseIdx, std::string Fragment)
+    {
+        const int NewIndex = static_cast<int>(Imports.size());
+        Imports.push_back({BaseIdx, std::move(Fragment)});
+        return NewIndex;
+    }
+
+    // Starts a new 4-byte-aligned export and reserves its payload length slot.
+    void BeginExport()
+    {
+        AlignTo4(Data);
+        CurrentExportStart = Data.size();
+        WriteLE32(Data, 0);  // length placeholder, patched by CompleteExport()
+    }
+
+    // Appends raw payload bytes to the export started by BeginExport().
+    void WritePayload(const void* Payload, size_t Size)
+    {
+        WriteBytes(Data, Payload, Size);
+    }
+
+    // Closes the current export: patches the payload length, then appends the
+    // type index and the import list. Returns the export's index.
+    int CompleteExport(const uint8_t* BlobType, const std::vector<int>& ImportIndices)
+    {
+        // The payload is everything written after the 4-byte length slot.
+        const size_t LengthSlot = CurrentExportStart;
+        PatchLE32(Data, LengthSlot, static_cast<int32_t>(Data.size() - (LengthSlot + 4)));
+
+        // Type index (varint)
+        WriteVarInt(Data, static_cast<size_t>(AddType(BlobType)));
+
+        // Import count followed by each import index (varints)
+        WriteVarInt(Data, ImportIndices.size());
+        for (const int ImportIdx : ImportIndices)
+        {
+            WriteVarInt(Data, static_cast<size_t>(ImportIdx));
+        }
+
+        ExportOffsets.push_back(static_cast<int32_t>(LengthSlot));
+        return static_cast<int>(ExportOffsets.size()) - 1;
+    }
+
+    // Appends the type, import and export tables, patches the header and
+    // hands the finished bytes to the caller (Data is moved out).
+    std::vector<uint8_t> Finish()
+    {
+        const auto Position = [this]() { return static_cast<int32_t>(Data.size()); };
+
+        AlignTo4(Data);
+
+        // Type table: count(int32) + count * BlobTypeSize bytes
+        const int32_t TypeTableOffset = Position();
+        WriteLE32(Data, static_cast<int32_t>(Types.size()));
+        for (const uint8_t* TypeEntry : Types)
+        {
+            WriteBytes(Data, TypeEntry, BlobTypeSize);
+        }
+
+        // Import table: count(int32) + (count+1) offsets(int32 each) + import data
+        const int32_t ImportTableOffset = Position();
+        const int32_t ImportCount       = static_cast<int32_t>(Imports.size());
+        WriteLE32(Data, ImportCount);
+
+        // Zeroed offset slots (one per import plus a sentinel), filled in below.
+        const size_t SlotTable = Data.size();
+        for (int32_t Slot = 0; Slot <= ImportCount; ++Slot)
+        {
+            WriteLE32(Data, 0);
+        }
+
+        // Stores the current end of Data into the given offset slot.
+        const auto PatchSlot = [&](int32_t Slot) { PatchLE32(Data, SlotTable + static_cast<size_t>(Slot) * 4, Position()); };
+
+        for (int32_t Slot = 0; Slot < ImportCount; ++Slot)
+        {
+            PatchSlot(Slot);
+
+            const ImportEntry& Entry = Imports[static_cast<size_t>(Slot)];
+            // BaseIdx is stored as an unsigned varint after adding ImportBias.
+            WriteVarInt(Data, static_cast<size_t>(static_cast<int64_t>(Entry.BaseIdx) + ImportBias));
+            // Fragment bytes follow with no length prefix; the offset table delimits them.
+            WriteBytes(Data, Entry.Fragment.data(), Entry.Fragment.size());
+        }
+
+        // Sentinel slot: one past the last import's data.
+        PatchSlot(ImportCount);
+
+        // Export table: count(int32) + (count+1) offsets(int32 each)
+        const int32_t ExportTableOffset = Position();
+        WriteLE32(Data, static_cast<int32_t>(ExportOffsets.size()));
+        for (const int32_t ExportOffset : ExportOffsets)
+        {
+            WriteLE32(Data, ExportOffset);
+        }
+        // Sentinel: the type table begins where the export data region ends.
+        WriteLE32(Data, TypeTableOffset);
+
+        // Header: total packet length (8-byte header included) and the three table offsets.
+        PatchLE32(Data, 4, Position());
+        PatchLE32(Data, 8, TypeTableOffset);
+        PatchLE32(Data, 12, ImportTableOffset);
+        PatchLE32(Data, 16, ExportTableOffset);
+
+        return std::move(Data);
+    }
+};
+
+// ─── Encoded packet wrapper ─────────────────────────────────────────────────
+
+// Wraps an uncompressed packet with the encoded header:
+//   [Signature(3) + Version(1) + HeaderLength(4)]   8 bytes
+//   [DecompressedLength(4)]                         4 bytes
+//   [CompressionFormat(1): 0=None]                  1 byte
+//   [PacketData...]
+//
+// HeaderLength = total encoded packet size INCLUDING the 8-byte outer header.
+
+// Prefixes UncompressedPacket with the 13-byte encoded header (compression: none).
+// The caller must ensure the packet plus header fits in an int32 length field.
+static std::vector<uint8_t>
+EncodePacket(std::vector<uint8_t> UncompressedPacket)
+{
+    const int32_t DecompressedLen = static_cast<int32_t>(UncompressedPacket.size());
+    // HeaderLength includes the 8-byte outer signature header itself
+    const int32_t HeaderLength = 8 + 4 + 1 + DecompressedLen;
+
+    std::vector<uint8_t> Encoded;
+    Encoded.reserve(static_cast<size_t>(HeaderLength));
+
+    // Outer signature: 'U','B','N', version=5, HeaderLength (LE int32)
+    WriteByte(Encoded, PacketSignature[0]);  // 'U'
+    WriteByte(Encoded, PacketSignature[1]);  // 'B'
+    WriteByte(Encoded, PacketSignature[2]);  // 'N'
+    WriteByte(Encoded, PacketVersion);       // 5
+    WriteLE32(Encoded, HeaderLength);
+
+    // Decompressed length + compression format
+    WriteLE32(Encoded, DecompressedLen);
+    WriteByte(Encoded, 0);  // CompressionFormat::None
+
+    // Packet data
+    WriteBytes(Encoded, UncompressedPacket.data(), UncompressedPacket.size());
+
+    return Encoded;
+}
+
+// ─── Bundle blob name generation ────────────────────────────────────────────
+
+// Builds a blob base name of the form "<pid>_<steady-clock ms>_<counter>".
+static std::string
+GenerateBlobName()
+{
+    // NOTE(review): the counter's element type was lost in extraction; uint32_t is assumed.
+    static std::atomic<uint32_t> s_Counter{0};
+
+    const int Pid = GetCurrentProcessId();
+
+    auto Now = std::chrono::steady_clock::now().time_since_epoch();
+    auto Ms  = std::chrono::duration_cast<std::chrono::milliseconds>(Now).count();
+
+    ExtendableStringBuilder<64> Name;
+    Name << Pid << "_" << Ms << "_" << s_Counter.fetch_add(1);
+    return std::string(Name.ToView());
+}
+
+// ─── File info for bundling ─────────────────────────────────────────────────
+
+struct FileInfo
+{
+    std::filesystem::path Path;
+    std::string           Name;                            // Filename only (for directory entry)
+    uint64_t              FileSize = 0;
+    IoHash                ContentHash;                     // IoHash of file content
+    BLAKE3                StreamHash;                      // Full BLAKE3 for stream hash
+    int                   DirectoryExportImportIndex = 0;  // Import index referencing this file's root export
+    IoHash                RootExportHash;                  // IoHash of the root export for this file
+};
+
+// ─── CreateBundle implementation ────────────────────────────────────────────
+
+// Packs Files into one packet: a chunk-leaf export per small file, 64KB chunk leaves
+// plus an interior node per large file, and a directory export listing every file.
+// Writes "<blob name>.blob" and "<first file name>.Bundle.ref" into OutputDir and
+// fills OutResult. Returns false (after logging) on any failure, including a bundle
+// too large for the packet's 32-bit offsets.
+bool
+BundleCreator::CreateBundle(const std::vector<BundleFile>& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult)
+{
+    ZEN_TRACE_CPU("BundleCreator::CreateBundle");
+
+    std::error_code Ec;
+
+    // Collect files that exist
+    std::vector<FileInfo> ValidFiles;
+    for (const BundleFile& F : Files)
+    {
+        if (!std::filesystem::exists(F.Path, Ec))
+        {
+            if (F.Optional)
+            {
+                continue;
+            }
+            ZEN_ERROR("required bundle file does not exist: {}", F.Path.string());
+            return false;
+        }
+        FileInfo Info;
+        Info.Path     = F.Path;
+        Info.Name     = F.Path.filename().string();
+        Info.FileSize = std::filesystem::file_size(F.Path, Ec);
+        if (Ec)
+        {
+            ZEN_ERROR("failed to get file size: {}", F.Path.string());
+            return false;
+        }
+        ValidFiles.push_back(std::move(Info));
+    }
+
+    if (ValidFiles.empty())
+    {
+        ZEN_ERROR("no valid files to bundle");
+        return false;
+    }
+
+    std::filesystem::create_directories(OutputDir, Ec);
+    if (Ec)
+    {
+        ZEN_ERROR("failed to create output directory: {}", OutputDir.string());
+        return false;
+    }
+
+    const std::string BlobName = GenerateBlobName();
+    PacketBuilder     Packet;
+
+    // Process each file: create chunk exports
+    for (FileInfo& Info : ValidFiles)
+    {
+        BasicFile File;
+        File.Open(Info.Path, BasicFile::Mode::kRead, Ec);
+        if (Ec)
+        {
+            ZEN_ERROR("failed to open file: {}", Info.Path.string());
+            return false;
+        }
+
+        // Compute stream hash (full BLAKE3) and content hash (IoHash) while reading
+        BLAKE3Stream StreamHasher;
+        IoHashStream ContentHasher;
+
+        if (Info.FileSize <= LargeFileThreshold)
+        {
+            // Small file: single chunk leaf export
+            IoBuffer     Content = File.ReadAll();
+            const auto*  Data    = static_cast<const uint8_t*>(Content.GetData());
+            const size_t Size    = Content.GetSize();
+
+            StreamHasher.Append(Data, Size);
+            ContentHasher.Append(Data, Size);
+
+            Packet.BeginExport();
+            Packet.WritePayload(Data, Size);
+
+            const IoHash ChunkHash   = IoHash::HashBuffer(Data, Size);
+            const int    ExportIndex = Packet.CompleteExport(BlobType_ChunkLeafV1, {});
+            Info.RootExportHash      = ChunkHash;
+            Info.ContentHash         = ContentHasher.GetHash();
+            Info.StreamHash          = StreamHasher.GetHash();
+
+            // Add import for this file's root export (references export within same packet)
+            ExtendableStringBuilder<32> Fragment;
+            Fragment << "exp=" << ExportIndex;
+            Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
+        }
+        else
+        {
+            // Large file: split into fixed 64KB chunks, then create interior node
+            std::vector<int>    ChunkExportIndices;
+            std::vector<IoHash> ChunkHashes;
+
+            uint64_t Remaining = Info.FileSize;
+            uint64_t Offset    = 0;
+
+            while (Remaining > 0)
+            {
+                const uint64_t ReadSize = std::min(static_cast<uint64_t>(ChunkSize), Remaining);
+                IoBuffer       Chunk    = File.ReadRange(Offset, ReadSize);
+                const auto*    Data     = static_cast<const uint8_t*>(Chunk.GetData());
+                const size_t   Size     = Chunk.GetSize();
+
+                StreamHasher.Append(Data, Size);
+                ContentHasher.Append(Data, Size);
+
+                Packet.BeginExport();
+                Packet.WritePayload(Data, Size);
+
+                const IoHash ChunkHash = IoHash::HashBuffer(Data, Size);
+                const int    ExpIdx    = Packet.CompleteExport(BlobType_ChunkLeafV1, {});
+
+                ChunkExportIndices.push_back(ExpIdx);
+                ChunkHashes.push_back(ChunkHash);
+
+                Offset += ReadSize;
+                Remaining -= ReadSize;
+            }
+
+            Info.ContentHash = ContentHasher.GetHash();
+            Info.StreamHash  = StreamHasher.GetHash();
+
+            // Create interior node referencing all chunk leaves
+            // Interior payload: for each child: [IoHash(20)][node_type=1(1)] + imports
+            std::vector<int> InteriorImports;
+            for (size_t i = 0; i < ChunkExportIndices.size(); ++i)
+            {
+                ExtendableStringBuilder<32> Fragment;
+                Fragment << "exp=" << ChunkExportIndices[i];
+                const int ImportIdx = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
+                InteriorImports.push_back(ImportIdx);
+            }
+
+            Packet.BeginExport();
+
+            // Write interior payload: [hash(20)][type(1)] per child
+            for (size_t i = 0; i < ChunkHashes.size(); ++i)
+            {
+                Packet.WritePayload(ChunkHashes[i].Hash, sizeof(IoHash));
+                const uint8_t NodeType = 1;  // ChunkNode type
+                Packet.WritePayload(&NodeType, 1);
+            }
+
+            // Hash the interior payload to get the interior node hash
+            const IoHash InteriorHash = IoHash::HashBuffer(Packet.Data.data() + (Packet.CurrentExportStart + 4),
+                                                           Packet.Data.size() - (Packet.CurrentExportStart + 4));
+
+            const int InteriorExportIndex = Packet.CompleteExport(BlobType_ChunkInteriorV2, InteriorImports);
+
+            Info.RootExportHash = InteriorHash;
+
+            // Add import for directory to reference this interior node
+            ExtendableStringBuilder<32> Fragment;
+            Fragment << "exp=" << InteriorExportIndex;
+            Info.DirectoryExportImportIndex = Packet.AddImport(CurrentPacketBaseIdx, std::string(Fragment.ToView()));
+        }
+    }
+
+    // Create directory node export
+    // Payload: [flags(varint=0)] [file_count(varint)] [file_entries...] [dir_count(varint=0)]
+    // FileEntry: [IoHash(20)] [name(string)] [flags(varint)] [length(varint)] [IoHash_stream(20)]
+    // (each entry's import reference comes from the export's import list, not from the payload)
+
+    Packet.BeginExport();
+
+    // Build directory payload into a temporary buffer, then write it
+    std::vector<uint8_t> DirPayload;
+    WriteVarInt(DirPayload, 0);                  // flags
+    WriteVarInt(DirPayload, ValidFiles.size());  // file_count
+
+    std::vector<int> DirImports;
+    for (size_t i = 0; i < ValidFiles.size(); ++i)
+    {
+        FileInfo& Info = ValidFiles[i];
+        DirImports.push_back(Info.DirectoryExportImportIndex);
+
+        // IoHash of target (20 bytes) — import is consumed sequentially from the
+        // export's import list by ReadBlobRef, not encoded in the payload
+        WriteBytes(DirPayload, Info.RootExportHash.Hash, sizeof(IoHash));
+        // name (string)
+        WriteString(DirPayload, Info.Name);
+        // flags (varint): 1 = Executable
+        WriteVarInt(DirPayload, 1);
+        // length (varint)
+        WriteVarInt(DirPayload, static_cast<size_t>(Info.FileSize));
+        // stream hash: IoHash from full BLAKE3, truncated to 20 bytes
+        const IoHash StreamIoHash = IoHash::FromBLAKE3(Info.StreamHash);
+        WriteBytes(DirPayload, StreamIoHash.Hash, sizeof(IoHash));
+    }
+
+    WriteVarInt(DirPayload, 0);  // dir_count
+
+    Packet.WritePayload(DirPayload.data(), DirPayload.size());
+    const int DirExportIndex = Packet.CompleteExport(BlobType_DirectoryV1, DirImports);
+
+    // Finalize packet and encode
+    std::vector<uint8_t> UncompressedPacket = Packet.Finish();
+
+    // Every offset and length in the packet is serialized as a signed 32-bit value, and
+    // none exceeds the final packet size. If the packet plus the 13-byte encoded header
+    // does not fit, those fields have wrapped and the bundle would be written corrupt.
+    constexpr size_t EncodedHeaderSize = 8 + 4 + 1;
+    constexpr size_t MaxPacketSize     = size_t(0x7FFFFFFF) - EncodedHeaderSize;
+    if (UncompressedPacket.size() > MaxPacketSize)
+    {
+        ZEN_ERROR("bundle too large: {} bytes exceeds the 2 GiB packet limit", UncompressedPacket.size());
+        return false;
+    }
+
+    std::vector<uint8_t> EncodedPacket = EncodePacket(std::move(UncompressedPacket));
+
+    // Write .blob file
+    const std::filesystem::path BlobFilePath = OutputDir / (BlobName + ".blob");
+    {
+        BasicFile BlobFile(BlobFilePath, BasicFile::Mode::kTruncate, Ec);
+        if (Ec)
+        {
+            ZEN_ERROR("failed to create blob file: {}", BlobFilePath.string());
+            return false;
+        }
+        BlobFile.Write(EncodedPacket.data(), EncodedPacket.size(), 0);
+    }
+
+    // Build locator: {BlobName}#pkt=0,{encoded packet size}&exp={directory export index}
+    ExtendableStringBuilder<256> Locator;
+    Locator << BlobName << "#pkt=0," << uint64_t(EncodedPacket.size()) << "&exp=" << DirExportIndex;
+    const std::string LocatorStr(Locator.ToView());
+
+    // Write .ref file (use first file's name as the ref base)
+    const std::filesystem::path RefFilePath = OutputDir / (ValidFiles[0].Name + ".Bundle.ref");
+    {
+        BasicFile RefFile(RefFilePath, BasicFile::Mode::kTruncate, Ec);
+        if (Ec)
+        {
+            ZEN_ERROR("failed to create ref file: {}", RefFilePath.string());
+            return false;
+        }
+        RefFile.Write(LocatorStr.data(), LocatorStr.size(), 0);
+    }
+
+    OutResult.Locator   = LocatorStr;
+    OutResult.BundleDir = OutputDir;
+
+    ZEN_INFO("created V2 bundle: blob={}.blob locator={} files={}", BlobName, LocatorStr, ValidFiles.size());
+    return true;
+}
+
+// Reads the whole .ref file into OutLocator and trims trailing whitespace.
+// Returns false if the file cannot be opened or the trimmed locator is empty.
+bool
+BundleCreator::ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator)
+{
+    BasicFile       File;
+    std::error_code Ec;
+    File.Open(RefFile, BasicFile::Mode::kRead, Ec);
+    if (Ec)
+    {
+        return false;
+    }
+
+    IoBuffer Content = File.ReadAll();
+    OutLocator.assign(static_cast<const char*>(Content.GetData()), Content.GetSize());
+
+    // Strip trailing whitespace, newlines and NUL bytes. Spaces and tabs are included
+    // so a hand-edited .ref file does not yield a locator with trailing blanks.
+    const auto IsTrailingJunk = [](char C) { return C == '\n' || C == '\r' || C == '\0' || C == ' ' || C == '\t'; };
+    while (!OutLocator.empty() && IsTrailingJunk(OutLocator.back()))
+    {
+        OutLocator.pop_back();
+    }
+
+    return !OutLocator.empty();
+}
+
+}  // namespace zen::horde
diff
--git a/src/zenhorde/hordebundle.h b/src/zenhorde/hordebundle.h new file mode 100644 index 000000000..052f60435 --- /dev/null +++ b/src/zenhorde/hordebundle.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include +#include + +namespace zen::horde { + +/** Describes a file to include in a Horde bundle. */ +struct BundleFile +{ + std::filesystem::path Path; ///< Local file path + bool Optional; ///< If true, skip without error if missing +}; + +/** Result of a successful bundle creation. */ +struct BundleResult +{ + std::string Locator; ///< Root directory locator for WriteFiles + std::filesystem::path BundleDir; ///< Directory containing .blob files +}; + +/** Creates Horde V2 bundles from local files for upload to remote agents. + * + * Produces a proper Horde storage V2 bundle containing: + * - Chunk leaf exports for file data (split into 64KB chunks for large files) + * - Optional interior chunk nodes referencing leaf chunks + * - A directory node listing all bundled files with metadata + * + * The bundle is written as a single .blob file with a corresponding .ref file + * containing the locator string. The locator format is: + * #pkt=0,&exp= + */ +struct BundleCreator +{ + /** Create a V2 bundle from one or more input files. + * @param Files Files to include in the bundle. + * @param OutputDir Directory where .blob and .ref files will be written. + * @param OutResult Receives the locator and output directory on success. + * @return True on success. */ + static bool CreateBundle(const std::vector& Files, const std::filesystem::path& OutputDir, BundleResult& OutResult); + + /** Read a locator string from a .ref file. Strips trailing whitespace/newlines. 
*/ + static bool ReadLocator(const std::filesystem::path& RefFile, std::string& OutLocator); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeclient.cpp b/src/zenhorde/hordeclient.cpp new file mode 100644 index 000000000..fb981f0ba --- /dev/null +++ b/src/zenhorde/hordeclient.cpp @@ -0,0 +1,382 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include +#include +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::horde { + +HordeClient::HordeClient(const HordeConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("horde.client")) +{ +} + +HordeClient::~HordeClient() = default; + +bool +HordeClient::Initialize() +{ + ZEN_TRACE_CPU("HordeClient::Initialize"); + + HttpClientSettings Settings; + Settings.LogCategory = "horde.http"; + Settings.ConnectTimeout = std::chrono::milliseconds{10000}; + Settings.Timeout = std::chrono::milliseconds{60000}; + Settings.RetryCount = 1; + Settings.ExpectedErrorCodes = {HttpResponseCode::ServiceUnavailable, HttpResponseCode::TooManyRequests}; + + if (!m_Config.AuthToken.empty()) + { + Settings.AccessTokenProvider = [token = m_Config.AuthToken]() -> HttpClientAccessToken { + HttpClientAccessToken Token; + Token.Value = token; + Token.ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours{24}; + return Token; + }; + } + + m_Http = std::make_unique(m_Config.ServerUrl, Settings); + + if (!m_Config.AuthToken.empty()) + { + if (!m_Http->Authenticate()) + { + ZEN_WARN("failed to authenticate with Horde server"); + return false; + } + } + + return true; +} + +std::string +HordeClient::BuildRequestBody() const +{ + json11::Json::object Requirements; + + if (m_Config.Mode == ConnectionMode::Direct && !m_Config.Pool.empty()) + { + Requirements["pool"] = m_Config.Pool; + } + + std::string Condition; +#if ZEN_PLATFORM_WINDOWS + ExtendableStringBuilder<256> CondBuf; + CondBuf << "(OSFamily == 'Windows' || WineEnabled 
== '" << (m_Config.AllowWine ? "true" : "false") << "')"; + Condition = std::string(CondBuf); +#elif ZEN_PLATFORM_MAC + Condition = "OSFamily == 'MacOS'"; +#else + Condition = "OSFamily == 'Linux'"; +#endif + + if (!m_Config.Condition.empty()) + { + Condition += " "; + Condition += m_Config.Condition; + } + + Requirements["condition"] = Condition; + Requirements["exclusive"] = true; + + json11::Json::object Connection; + Connection["modePreference"] = ToString(m_Config.Mode); + + if (m_Config.EncryptionMode != Encryption::None) + { + Connection["encryption"] = ToString(m_Config.EncryptionMode); + } + + // Request configured zen service port to be forwarded. The Horde agent will map this + // to a local port on the provisioned machine and report it back in the response. + json11::Json::object PortsObj; + PortsObj["ZenPort"] = json11::Json(m_Config.ZenServicePort); + Connection["ports"] = PortsObj; + + json11::Json::object Root; + Root["requirements"] = Requirements; + Root["connection"] = Connection; + + return json11::Json(Root).dump(); +} + +bool +HordeClient::ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster) +{ + ZEN_TRACE_CPU("HordeClient::ResolveCluster"); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON); + + const HttpClient::Response Response = m_Http->Post("api/v2/compute/_cluster", Payload); + + if (Response.Error) + { + ZEN_WARN("cluster resolution failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast(Response.StatusCode); + + if (StatusCode == 503 || StatusCode == 429) + { + ZEN_DEBUG("cluster resolution returned HTTP/{}: no resources", StatusCode); + return false; + } + + if (StatusCode == 401) + { + ZEN_WARN("cluster resolution returned HTTP/401: token expired"); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("cluster resolution failed with HTTP/{}", StatusCode); + return false; 
+ } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response for cluster resolution: {}", Err); + return false; + } + + const json11::Json ClusterIdVal = Json["clusterId"]; + if (!ClusterIdVal.is_string() || ClusterIdVal.string_value().empty()) + { + ZEN_WARN("missing 'clusterId' in cluster resolution response"); + return false; + } + + OutCluster.ClusterId = ClusterIdVal.string_value(); + return true; +} + +bool +HordeClient::ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize) +{ + if (Hex.size() != OutSize * 2) + { + return false; + } + + for (size_t i = 0; i < OutSize; ++i) + { + auto HexToByte = [](char c) -> int { + if (c >= '0' && c <= '9') + return c - '0'; + if (c >= 'a' && c <= 'f') + return c - 'a' + 10; + if (c >= 'A' && c <= 'F') + return c - 'A' + 10; + return -1; + }; + + const int Hi = HexToByte(Hex[i * 2]); + const int Lo = HexToByte(Hex[i * 2 + 1]); + if (Hi < 0 || Lo < 0) + { + return false; + } + Out[i] = static_cast((Hi << 4) | Lo); + } + + return true; +} + +bool +HordeClient::RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine) +{ + ZEN_TRACE_CPU("HordeClient::RequestMachine"); + + ZEN_INFO("requesting machine from Horde with cluster '{}'", ClusterId.empty() ? "default" : ClusterId.c_str()); + + ExtendableStringBuilder<128> ResourcePath; + ResourcePath << "api/v2/compute/" << (ClusterId.empty() ? 
"default" : ClusterId.c_str()); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{RequestBody.data(), RequestBody.size()}, ZenContentType::kJSON); + const HttpClient::Response Response = m_Http->Post(ResourcePath.ToView(), Payload); + + // Reset output to invalid state + OutMachine = {}; + OutMachine.Port = 0xFFFF; + + if (Response.Error) + { + ZEN_WARN("machine request failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast(Response.StatusCode); + + if (StatusCode == 404 || StatusCode == 503 || StatusCode == 429) + { + ZEN_DEBUG("machine request returned HTTP/{}: no resources", StatusCode); + return false; + } + + if (StatusCode == 401) + { + ZEN_WARN("machine request returned HTTP/401: token expired"); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("machine request failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response for machine request: {}", Err); + return false; + } + + // Required fields + const json11::Json NonceVal = Json["nonce"]; + const json11::Json IpVal = Json["ip"]; + const json11::Json PortVal = Json["port"]; + + if (!NonceVal.is_string() || !IpVal.is_string() || !PortVal.is_number()) + { + ZEN_WARN("missing 'nonce', 'ip', or 'port' in machine response"); + return false; + } + + OutMachine.Ip = IpVal.string_value(); + OutMachine.Port = static_cast(PortVal.int_value()); + + if (!ParseHexBytes(NonceVal.string_value(), OutMachine.Nonce, NonceSize)) + { + ZEN_WARN("invalid nonce hex string in machine response"); + return false; + } + + if (const json11::Json PortsVal = Json["ports"]; PortsVal.is_object()) + { + for (const auto& [Key, Val] : PortsVal.object_items()) + { + PortInfo Info; + if (Val["port"].is_number()) + { + Info.Port = static_cast(Val["port"].int_value()); + } + if 
(Val["agentPort"].is_number()) + { + Info.AgentPort = static_cast(Val["agentPort"].int_value()); + } + OutMachine.Ports[Key] = Info; + } + } + + if (const json11::Json ConnectionModeVal = Json["connectionMode"]; ConnectionModeVal.is_string()) + { + if (FromString(OutMachine.Mode, ConnectionModeVal.string_value())) + { + if (const json11::Json ConnectionAddressVal = Json["connectionAddress"]; ConnectionAddressVal.is_string()) + { + OutMachine.ConnectionAddress = ConnectionAddressVal.string_value(); + } + } + } + + // Properties are a flat string array of "Key=Value" pairs describing the machine. + // We extract OS family and core counts for sizing decisions. If neither core count + // is available, we fall back to 16 as a conservative default. + uint16_t LogicalCores = 0; + uint16_t PhysicalCores = 0; + + if (const json11::Json PropertiesVal = Json["properties"]; PropertiesVal.is_array()) + { + for (const json11::Json& PropVal : PropertiesVal.array_items()) + { + if (!PropVal.is_string()) + { + continue; + } + + const std::string Prop = PropVal.string_value(); + if (Prop.starts_with("OSFamily=")) + { + if (Prop.substr(9) == "Windows") + { + OutMachine.IsWindows = true; + } + } + else if (Prop.starts_with("LogicalCores=")) + { + LogicalCores = static_cast(std::atoi(Prop.c_str() + 13)); + } + else if (Prop.starts_with("PhysicalCores=")) + { + PhysicalCores = static_cast(std::atoi(Prop.c_str() + 14)); + } + } + } + + if (LogicalCores > 0) + { + OutMachine.LogicalCores = LogicalCores; + } + else if (PhysicalCores > 0) + { + OutMachine.LogicalCores = PhysicalCores * 2; + } + else + { + OutMachine.LogicalCores = 16; + } + + if (const json11::Json EncryptionVal = Json["encryption"]; EncryptionVal.is_string()) + { + if (FromString(OutMachine.EncryptionMode, EncryptionVal.string_value())) + { + if (OutMachine.EncryptionMode == Encryption::AES) + { + const json11::Json KeyVal = Json["key"]; + if (KeyVal.is_string() && !KeyVal.string_value().empty()) + { + if 
(!ParseHexBytes(KeyVal.string_value(), OutMachine.Key, KeySize)) + { + ZEN_WARN("invalid AES key in machine response"); + } + } + else + { + ZEN_WARN("AES encryption requested but no key provided"); + } + } + } + } + + if (const json11::Json LeaseIdVal = Json["leaseId"]; LeaseIdVal.is_string()) + { + OutMachine.LeaseId = LeaseIdVal.string_value(); + } + + ZEN_INFO("Horde machine assigned [{}:{}] cores={} lease={}", + OutMachine.GetConnectionAddress(), + OutMachine.GetConnectionPort(), + OutMachine.LogicalCores, + OutMachine.LeaseId); + + return true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp new file mode 100644 index 000000000..0d032b5d5 --- /dev/null +++ b/src/zenhorde/hordecomputebuffer.cpp @@ -0,0 +1,454 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputebuffer.h" + +#include +#include +#include +#include +#include + +namespace zen::horde { + +// Simplified ring buffer implementation for in-process use only. +// Uses a single contiguous buffer with write/read cursors and +// mutex+condvar for synchronization. This is simpler than the UE version +// which uses lock-free atomics and shared memory, but sufficient for our +// use case where we're the initiator side of the compute protocol. 
+ +struct ComputeBuffer::Detail : TRefCounted +{ + std::vector Data; + size_t NumChunks = 0; + size_t ChunkLength = 0; + + // Current write state + size_t WriteChunkIdx = 0; + size_t WriteOffset = 0; + bool WriteComplete = false; + + // Current read state + size_t ReadChunkIdx = 0; + size_t ReadOffset = 0; + bool Detached = false; + + // Per-chunk written length + std::vector ChunkWrittenLength; + std::vector ChunkFinished; // Writer moved to next chunk + + std::mutex Mutex; + std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes + std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space + + bool HasWriter = false; + bool HasReader = false; + + uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } + const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } +}; + +// ComputeBuffer + +ComputeBuffer::ComputeBuffer() +{ +} +ComputeBuffer::~ComputeBuffer() +{ +} + +bool +ComputeBuffer::CreateNew(const Params& InParams) +{ + auto* NewDetail = new Detail(); + NewDetail->NumChunks = InParams.NumChunks; + NewDetail->ChunkLength = InParams.ChunkLength; + NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); + NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); + NewDetail->ChunkFinished.resize(InParams.NumChunks, false); + + m_Detail = NewDetail; + return true; +} + +void +ComputeBuffer::Close() +{ + m_Detail = nullptr; +} + +bool +ComputeBuffer::IsValid() const +{ + return static_cast(m_Detail); +} + +ComputeBufferReader +ComputeBuffer::CreateReader() +{ + assert(m_Detail); + m_Detail->HasReader = true; + return ComputeBufferReader(m_Detail); +} + +ComputeBufferWriter +ComputeBuffer::CreateWriter() +{ + assert(m_Detail); + m_Detail->HasWriter = true; + return ComputeBufferWriter(m_Detail); +} + +// ComputeBufferReader + +ComputeBufferReader::ComputeBufferReader() +{ +} 
+ComputeBufferReader::~ComputeBufferReader() +{ +} + +ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; +ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; +ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; +ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; + +ComputeBufferReader::ComputeBufferReader(Ref InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferReader::Close() +{ + m_Detail = nullptr; +} + +void +ComputeBufferReader::Detach() +{ + if (m_Detail) + { + std::lock_guard Lock(m_Detail->Mutex); + m_Detail->Detached = true; + m_Detail->ReadCV.notify_all(); + } +} + +bool +ComputeBufferReader::IsValid() const +{ + return static_cast(m_Detail); +} + +bool +ComputeBufferReader::IsComplete() const +{ + if (!m_Detail) + { + return true; + } + std::lock_guard Lock(m_Detail->Mutex); + if (m_Detail->Detached) + { + return true; + } + return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && + m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; +} + +void +ComputeBufferReader::AdvanceReadPosition(size_t Size) +{ + if (!m_Detail) + { + return; + } + + std::lock_guard Lock(m_Detail->Mutex); + + m_Detail->ReadOffset += Size; + + // Check if we need to move to next chunk + const size_t ReadChunk = m_Detail->ReadChunkIdx; + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + } + + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferReader::GetMaxReadSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard Lock(m_Detail->Mutex); + const size_t ReadChunk = m_Detail->ReadChunkIdx; + return 
m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; +} + +const uint8_t* +ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock Lock(m_Detail->Mutex); + + auto Predicate = [&]() -> bool { + if (m_Detail->Detached) + { + return true; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available >= MinSize) + { + return true; + } + + // If chunk is finished and we've read everything, try to move to next + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + // Move to next chunk + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + return false; // Re-check with new chunk + } + + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + + return false; + }; + + if (TimeoutMs < 0) + { + m_Detail->ReadCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + if (OutTimedOut) + { + *OutTimedOut = true; + } + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available < MinSize) + { + return nullptr; // End of stream + } + + return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; +} + +size_t +ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) +{ + const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); + if (!Data) + { + return 0; + } + + const size_t Available = GetMaxReadSize(); + const size_t ToCopy = 
std::min(Available, MaxSize); + memcpy(Buffer, Data, ToCopy); + AdvanceReadPosition(ToCopy); + return ToCopy; +} + +// ComputeBufferWriter + +ComputeBufferWriter::ComputeBufferWriter() = default; +ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; +ComputeBufferWriter::~ComputeBufferWriter() = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; + +ComputeBufferWriter::ComputeBufferWriter(Ref InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferWriter::Close() +{ + if (m_Detail) + { + { + std::lock_guard Lock(m_Detail->Mutex); + if (!m_Detail->WriteComplete) + { + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } + } + m_Detail = nullptr; + } +} + +bool +ComputeBufferWriter::IsValid() const +{ + return static_cast(m_Detail); +} + +void +ComputeBufferWriter::MarkComplete() +{ + if (m_Detail) + { + std::lock_guard Lock(m_Detail->Mutex); + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } +} + +void +ComputeBufferWriter::AdvanceWritePosition(size_t Size) +{ + if (!m_Detail || Size == 0) + { + return; + } + + std::lock_guard Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + m_Detail->ChunkWrittenLength[WriteChunk] += Size; + m_Detail->WriteOffset += Size; + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferWriter::GetMaxWriteSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; +} + +size_t +ComputeBufferWriter::GetChunkMaxLength() const +{ + if (!m_Detail) + { + return 0; + } + return m_Detail->ChunkLength; +} + +size_t 
+ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) +{ + uint8_t* Dest = WaitToWrite(1, TimeoutMs); + if (!Dest) + { + return 0; + } + + const size_t Available = GetMaxWriteSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Dest, Buffer, ToCopy); + AdvanceWritePosition(ToCopy); + return ToCopy; +} + +uint8_t* +ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock Lock(m_Detail->Mutex); + + if (m_Detail->WriteComplete) + { + return nullptr; + } + + const size_t WriteChunk = m_Detail->WriteChunkIdx; + const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; + + // If current chunk has enough space, return pointer + if (Available >= MinSize) + { + return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; + } + + // Current chunk is full - mark it as finished and move to next. + // The writer cannot advance until the reader has fully consumed the next chunk, + // preventing the writer from overwriting data the reader hasn't processed yet. 
+ m_Detail->ChunkFinished[WriteChunk] = true; + m_Detail->ReadCV.notify_all(); + + const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; + + // Wait until reader has consumed the next chunk + auto Predicate = [&]() -> bool { + // Check if read has moved past this chunk + return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; + }; + + if (TimeoutMs < 0) + { + m_Detail->WriteCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + // Reset next chunk + m_Detail->ChunkWrittenLength[NextChunk] = 0; + m_Detail->ChunkFinished[NextChunk] = false; + m_Detail->WriteChunkIdx = NextChunk; + m_Detail->WriteOffset = 0; + + return m_Detail->ChunkPtr(NextChunk); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputebuffer.h b/src/zenhorde/hordecomputebuffer.h new file mode 100644 index 000000000..64ef91b7a --- /dev/null +++ b/src/zenhorde/hordecomputebuffer.h @@ -0,0 +1,136 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include +#include + +namespace zen::horde { + +class ComputeBufferReader; +class ComputeBufferWriter; + +/** Simplified in-process ring buffer for the Horde compute protocol. + * + * Unlike the UE FComputeBuffer which supports shared-memory and memory-mapped files, + * this implementation uses plain heap-allocated memory since we only need in-process + * communication between channel and transport threads. The buffer is divided into + * fixed-size chunks; readers and writers block when no space is available. 
+ */ +class ComputeBuffer +{ +public: + struct Params + { + size_t NumChunks = 2; + size_t ChunkLength = 512 * 1024; + }; + + ComputeBuffer(); + ~ComputeBuffer(); + + ComputeBuffer(const ComputeBuffer&) = delete; + ComputeBuffer& operator=(const ComputeBuffer&) = delete; + + bool CreateNew(const Params& InParams); + void Close(); + + bool IsValid() const; + + ComputeBufferReader CreateReader(); + ComputeBufferWriter CreateWriter(); + +private: + struct Detail; + Ref m_Detail; + + friend class ComputeBufferReader; + friend class ComputeBufferWriter; +}; + +/** Read endpoint for a ComputeBuffer. + * + * Provides blocking reads from the ring buffer. WaitToRead() returns a pointer + * directly into the buffer memory (zero-copy); the caller must call + * AdvanceReadPosition() after consuming the data. + */ +class ComputeBufferReader +{ +public: + ComputeBufferReader(); + ComputeBufferReader(const ComputeBufferReader&); + ComputeBufferReader(ComputeBufferReader&&) noexcept; + ~ComputeBufferReader(); + + ComputeBufferReader& operator=(const ComputeBufferReader&); + ComputeBufferReader& operator=(ComputeBufferReader&&) noexcept; + + void Close(); + void Detach(); + bool IsValid() const; + bool IsComplete() const; + + void AdvanceReadPosition(size_t Size); + size_t GetMaxReadSize() const; + + /** Copy up to MaxSize bytes from the buffer into Buffer. Blocks until data is available. */ + size_t Read(void* Buffer, size_t MaxSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + + /** Wait until at least MinSize bytes are available and return a direct pointer. + * Returns nullptr on timeout or if the writer has completed. */ + const uint8_t* WaitToRead(size_t MinSize, int TimeoutMs = -1, bool* OutTimedOut = nullptr); + +private: + friend class ComputeBuffer; + explicit ComputeBufferReader(Ref InDetail); + + Ref m_Detail; +}; + +/** Write endpoint for a ComputeBuffer. + * + * Provides blocking writes into the ring buffer. 
WaitToWrite() returns a pointer + * directly into the buffer memory (zero-copy); the caller must call + * AdvanceWritePosition() after filling the data. Call MarkComplete() to signal + * that no more data will be written. + */ +class ComputeBufferWriter +{ +public: + ComputeBufferWriter(); + ComputeBufferWriter(const ComputeBufferWriter&); + ComputeBufferWriter(ComputeBufferWriter&&) noexcept; + ~ComputeBufferWriter(); + + ComputeBufferWriter& operator=(const ComputeBufferWriter&); + ComputeBufferWriter& operator=(ComputeBufferWriter&&) noexcept; + + void Close(); + bool IsValid() const; + + /** Signal that no more data will be written. Unblocks any waiting readers. */ + void MarkComplete(); + + void AdvanceWritePosition(size_t Size); + size_t GetMaxWriteSize() const; + size_t GetChunkMaxLength() const; + + /** Copy up to MaxSize bytes from Buffer into the ring buffer. Blocks until space is available. */ + size_t Write(const void* Buffer, size_t MaxSize, int TimeoutMs = -1); + + /** Wait until at least MinSize bytes of write space are available and return a direct pointer. + * Returns nullptr on timeout. */ + uint8_t* WaitToWrite(size_t MinSize, int TimeoutMs = -1); + +private: + friend class ComputeBuffer; + explicit ComputeBufferWriter(Ref InDetail); + + Ref m_Detail; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.cpp b/src/zenhorde/hordecomputechannel.cpp new file mode 100644 index 000000000..ee2a6f327 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.cpp @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include "hordecomputechannel.h" + +namespace zen::horde { + +ComputeChannel::ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter) +: Reader(std::move(InReader)) +, Writer(std::move(InWriter)) +{ +} + +bool +ComputeChannel::IsValid() const +{ + return Reader.IsValid() && Writer.IsValid(); +} + +size_t +ComputeChannel::Send(const void* Data, size_t Size, int TimeoutMs) +{ + return Writer.Write(Data, Size, TimeoutMs); +} + +size_t +ComputeChannel::Recv(void* Data, size_t Size, int TimeoutMs) +{ + return Reader.Read(Data, Size, TimeoutMs); +} + +void +ComputeChannel::MarkComplete() +{ + Writer.MarkComplete(); +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputechannel.h b/src/zenhorde/hordecomputechannel.h new file mode 100644 index 000000000..c1dff20e4 --- /dev/null +++ b/src/zenhorde/hordecomputechannel.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" + +namespace zen::horde { + +/** Bidirectional communication channel using a pair of compute buffers. + * + * Pairs a ComputeBufferReader (for receiving data) with a ComputeBufferWriter + * (for sending data). Used by ComputeSocket to represent one logical channel + * within a multiplexed connection. + */ +class ComputeChannel : public TRefCounted +{ +public: + ComputeBufferReader Reader; + ComputeBufferWriter Writer; + + ComputeChannel(ComputeBufferReader InReader, ComputeBufferWriter InWriter); + + bool IsValid() const; + + size_t Send(const void* Data, size_t Size, int TimeoutMs = -1); + size_t Recv(void* Data, size_t Size, int TimeoutMs = -1); + + /** Signal that no more data will be sent on this channel. */ + void MarkComplete(); +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp new file mode 100644 index 000000000..6ef67760c --- /dev/null +++ b/src/zenhorde/hordecomputesocket.cpp @@ -0,0 +1,204 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#include "hordecomputesocket.h" + +#include + +namespace zen::horde { + +ComputeSocket::ComputeSocket(std::unique_ptr Transport) +: m_Log(zen::logging::Get("horde.socket")) +, m_Transport(std::move(Transport)) +{ +} + +ComputeSocket::~ComputeSocket() +{ + // Shutdown order matters: first stop the ping thread, then unblock send threads + // by detaching readers, then join send threads, and finally close the transport + // to unblock the recv thread (which is blocked on RecvMessage). + { + std::lock_guard Lock(m_PingMutex); + m_PingShouldStop = true; + m_PingCV.notify_all(); + } + + for (auto& Reader : m_Readers) + { + Reader.Detach(); + } + + for (auto& [Id, Thread] : m_SendThreads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Transport->Close(); + + if (m_RecvThread.joinable()) + { + m_RecvThread.join(); + } + if (m_PingThread.joinable()) + { + m_PingThread.join(); + } +} + +Ref +ComputeSocket::CreateChannel(int ChannelId) +{ + ComputeBuffer::Params Params; + + ComputeBuffer RecvBuffer; + if (!RecvBuffer.CreateNew(Params)) + { + return {}; + } + + ComputeBuffer SendBuffer; + if (!SendBuffer.CreateNew(Params)) + { + return {}; + } + + Ref Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); + + // Attach recv buffer writer (transport recv thread writes into this) + { + std::lock_guard Lock(m_WritersMutex); + m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); + } + + // Attach send buffer reader (send thread reads from this) + { + ComputeBufferReader Reader = SendBuffer.CreateReader(); + m_Readers.push_back(Reader); + m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); + } + + return Channel; +} + +void +ComputeSocket::StartCommunication() +{ + m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); + m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); +} + +void +ComputeSocket::PingThreadProc() +{ + 
while (true) + { + { + std::unique_lock Lock(m_PingMutex); + if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + { + break; + } + } + + std::lock_guard Lock(m_SendMutex); + FrameHeader Header; + Header.Channel = 0; + Header.Size = ControlPing; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +void +ComputeSocket::RecvThreadProc() +{ + // Writers are cached locally to avoid taking m_WritersMutex on every frame. + // The shared m_Writers map is only accessed when a channel is seen for the first time. + std::unordered_map CachedWriters; + + FrameHeader Header; + while (m_Transport->RecvMessage(&Header, sizeof(Header))) + { + if (Header.Size >= 0) + { + // Data frame + auto It = CachedWriters.find(Header.Channel); + if (It == CachedWriters.end()) + { + std::lock_guard Lock(m_WritersMutex); + auto WIt = m_Writers.find(Header.Channel); + if (WIt == m_Writers.end()) + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + // Skip the data + std::vector Discard(Header.Size); + m_Transport->RecvMessage(Discard.data(), Header.Size); + continue; + } + It = CachedWriters.emplace(Header.Channel, WIt->second).first; + } + + ComputeBufferWriter& Writer = It->second; + uint8_t* Dest = Writer.WaitToWrite(Header.Size); + if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) + { + ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); + return; + } + Writer.AdvanceWritePosition(Header.Size); + } + else if (Header.Size == ControlDetach) + { + // Detach the recv buffer for this channel + CachedWriters.erase(Header.Channel); + + std::lock_guard Lock(m_WritersMutex); + auto It = m_Writers.find(Header.Channel); + if (It != m_Writers.end()) + { + It->second.MarkComplete(); + m_Writers.erase(It); + } + } + else if (Header.Size == ControlPing) + { + // Ping response - ignore + } + else + { + ZEN_WARN("invalid frame header size: {}", Header.Size); + return; + } + } +} + +void 
+ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +{ + // Each channel has its own send thread. All send threads share m_SendMutex + // to serialize writes to the transport, since TCP requires atomic frame writes. + FrameHeader Header; + Header.Channel = Channel; + + const uint8_t* Data; + while ((Data = Reader.WaitToRead(1)) != nullptr) + { + std::lock_guard Lock(m_SendMutex); + + Header.Size = static_cast(Reader.GetMaxReadSize()); + m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->SendMessage(Data, Header.Size); + Reader.AdvanceReadPosition(Header.Size); + } + + if (Reader.IsComplete()) + { + std::lock_guard Lock(m_SendMutex); + Header.Size = ControlDetach; + m_Transport->SendMessage(&Header, sizeof(Header)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordecomputesocket.h b/src/zenhorde/hordecomputesocket.h new file mode 100644 index 000000000..0c3cb4195 --- /dev/null +++ b/src/zenhorde/hordecomputesocket.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordecomputebuffer.h" +#include "hordecomputechannel.h" +#include "hordetransport.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace zen::horde { + +/** Multiplexed socket that routes data between multiple channels over a single transport. + * + * Each channel is identified by an integer ID and backed by a pair of ComputeBuffers. + * A recv thread demultiplexes incoming frames to channel-specific buffers, while + * per-channel send threads multiplex outgoing data onto the shared transport. + * + * Wire format per frame: [channelId (4B)][size (4B)][data] + * Control messages use negative sizes: -2 = detach (channel closed), -3 = ping. 
+ */ +class ComputeSocket +{ +public: + explicit ComputeSocket(std::unique_ptr Transport); + ~ComputeSocket(); + + ComputeSocket(const ComputeSocket&) = delete; + ComputeSocket& operator=(const ComputeSocket&) = delete; + + /** Create a channel with the given ID. + * Allocates anonymous in-process buffers and spawns a send thread for the channel. */ + Ref CreateChannel(int ChannelId); + + /** Start the recv pump and ping threads. Must be called after all channels are created. */ + void StartCommunication(); + +private: + struct FrameHeader + { + int32_t Channel = 0; + int32_t Size = 0; + }; + + static constexpr int32_t ControlDetach = -2; + static constexpr int32_t ControlPing = -3; + + LoggerRef Log() { return m_Log; } + + void RecvThreadProc(); + void SendThreadProc(int Channel, ComputeBufferReader Reader); + void PingThreadProc(); + + LoggerRef m_Log; + std::unique_ptr m_Transport; + std::mutex m_SendMutex; ///< Serializes writes to the transport + + std::mutex m_WritersMutex; + std::unordered_map m_Writers; ///< Recv-side: writers keyed by channel ID + + std::vector m_Readers; ///< Send-side: readers for join on destruction + std::unordered_map m_SendThreads; ///< One send thread per channel + + std::thread m_RecvThread; + std::thread m_PingThread; + + bool m_PingShouldStop = false; + std::mutex m_PingMutex; + std::condition_variable m_PingCV; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordeconfig.cpp b/src/zenhorde/hordeconfig.cpp new file mode 100644 index 000000000..2dca228d9 --- /dev/null +++ b/src/zenhorde/hordeconfig.cpp @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include + +namespace zen::horde { + +bool +HordeConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + // Relay mode implies AES encryption + if (Mode == ConnectionMode::Relay && EncryptionMode != Encryption::AES) + { + return false; + } + + return true; +} + +const char* +ToString(ConnectionMode Mode) +{ + switch (Mode) + { + case ConnectionMode::Direct: + return "direct"; + case ConnectionMode::Tunnel: + return "tunnel"; + case ConnectionMode::Relay: + return "relay"; + } + return "direct"; +} + +const char* +ToString(Encryption Enc) +{ + switch (Enc) + { + case Encryption::None: + return "none"; + case Encryption::AES: + return "aes"; + } + return "none"; +} + +bool +FromString(ConnectionMode& OutMode, std::string_view Str) +{ + if (Str == "direct") + { + OutMode = ConnectionMode::Direct; + return true; + } + if (Str == "tunnel") + { + OutMode = ConnectionMode::Tunnel; + return true; + } + if (Str == "relay") + { + OutMode = ConnectionMode::Relay; + return true; + } + return false; +} + +bool +FromString(Encryption& OutEnc, std::string_view Str) +{ + if (Str == "none") + { + OutEnc = Encryption::None; + return true; + } + if (Str == "aes") + { + OutEnc = Encryption::AES; + return true; + } + return false; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordeprovisioner.cpp b/src/zenhorde/hordeprovisioner.cpp new file mode 100644 index 000000000..f88c95da2 --- /dev/null +++ b/src/zenhorde/hordeprovisioner.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include +#include + +#include "hordeagent.h" +#include "hordebundle.h" + +#include +#include +#include +#include +#include + +#include +#include + +namespace zen::horde { + +struct HordeProvisioner::AgentWrapper +{ + std::thread Thread; + std::atomic ShouldExit{false}; +}; + +HordeProvisioner::HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_BinariesPath(BinariesPath) +, m_WorkingDir(WorkingDir) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_Log(zen::logging::Get("horde.provisioner")) +{ +} + +HordeProvisioner::~HordeProvisioner() +{ + std::lock_guard Lock(m_AgentsLock); + for (auto& Agent : m_Agents) + { + Agent->ShouldExit.store(true); + } + for (auto& Agent : m_Agents) + { + if (Agent->Thread.joinable()) + { + Agent->Thread.join(); + } + } +} + +void +HordeProvisioner::SetTargetCoreCount(uint32_t Count) +{ + ZEN_TRACE_CPU("HordeProvisioner::SetTargetCoreCount"); + + m_TargetCoreCount.store(std::min(Count, static_cast(m_Config.MaxCores))); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + if (!m_AskForAgents.load()) + { + return; + } + RequestAgent(); + } + + // Clean up finished agent threads + std::lock_guard Lock(m_AgentsLock); + for (auto It = m_Agents.begin(); It != m_Agents.end();) + { + if ((*It)->ShouldExit.load()) + { + if ((*It)->Thread.joinable()) + { + (*It)->Thread.join(); + } + It = m_Agents.erase(It); + } + else + { + ++It; + } + } +} + +ProvisioningStats +HordeProvisioner::GetStats() const +{ + ProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.ActiveCoreCount = m_ActiveCoreCount.load(); + Stats.AgentsActive = m_AgentsActive.load(); + Stats.AgentsRequesting = m_AgentsRequesting.load(); + return Stats; +} + +uint32_t +HordeProvisioner::GetAgentCount() const +{ + std::lock_guard 
Lock(m_AgentsLock); + return static_cast(m_Agents.size()); +} + +void +HordeProvisioner::RequestAgent() +{ + m_EstimatedCoreCount.fetch_add(EstimatedCoresPerAgent); + + std::lock_guard Lock(m_AgentsLock); + + auto Wrapper = std::make_unique(); + AgentWrapper& Ref = *Wrapper; + Wrapper->Thread = std::thread([this, &Ref] { ThreadAgent(Ref); }); + + m_Agents.push_back(std::move(Wrapper)); +} + +void +HordeProvisioner::ThreadAgent(AgentWrapper& Wrapper) +{ + ZEN_TRACE_CPU("HordeProvisioner::ThreadAgent"); + + static std::atomic ThreadIndex{0}; + const uint32_t CurrentIndex = ThreadIndex.fetch_add(1); + + zen::SetCurrentThreadName(fmt::format("horde_agent_{}", CurrentIndex)); + + std::unique_ptr Agent; + uint32_t MachineCoreCount = 0; + + auto _ = MakeGuard([&] { + if (Agent) + { + Agent->CloseConnection(); + } + Wrapper.ShouldExit.store(true); + }); + + { + // EstimatedCoreCount is incremented speculatively when the agent is requested + // (in RequestAgent) so that SetTargetCoreCount doesn't over-provision. 
+ auto $ = MakeGuard([&] { m_EstimatedCoreCount.fetch_sub(EstimatedCoresPerAgent); }); + + { + ZEN_TRACE_CPU("HordeProvisioner::CreateBundles"); + + std::lock_guard BundleLock(m_BundleLock); + + if (!m_BundlesCreated) + { + const std::filesystem::path OutputDir = m_WorkingDir / "horde_bundles"; + + std::vector Files; + +#if ZEN_PLATFORM_WINDOWS + Files.emplace_back(m_BinariesPath / "zenserver.exe", false); +#elif ZEN_PLATFORM_LINUX + Files.emplace_back(m_BinariesPath / "zenserver", false); + Files.emplace_back(m_BinariesPath / "zenserver.debug", true); +#elif ZEN_PLATFORM_MAC + Files.emplace_back(m_BinariesPath / "zenserver", false); +#endif + + BundleResult Result; + if (!BundleCreator::CreateBundle(Files, OutputDir, Result)) + { + ZEN_WARN("failed to create bundle, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + + m_Bundles.emplace_back(Result.Locator, Result.BundleDir); + m_BundlesCreated = true; + } + + if (!m_HordeClient) + { + m_HordeClient = std::make_unique(m_Config); + if (!m_HordeClient->Initialize()) + { + ZEN_WARN("failed to initialize Horde HTTP client, cannot provision any agents!"); + m_AskForAgents.store(false); + return; + } + } + } + + if (!m_AskForAgents.load()) + { + return; + } + + m_AgentsRequesting.fetch_add(1); + auto ReqGuard = MakeGuard([this] { m_AgentsRequesting.fetch_sub(1); }); + + // Simple backoff: if the last machine request failed, wait up to 5 seconds + // before trying again. + // + // Note however that it's possible that multiple threads enter this code at + // the same time if multiple agents are requested at once, and they will all + // see the same last failure time and back off accordingly. We might want to + // use a semaphore or similar to limit the number of concurrent requests. 
+ + if (const uint64_t LastFail = m_LastRequestFailTime.load(); LastFail != 0) + { + auto Now = static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); + const uint64_t ElapsedNs = Now - LastFail; + const uint64_t ElapsedMs = ElapsedNs / 1'000'000; + if (ElapsedMs < 5000) + { + const uint64_t WaitMs = 5000 - ElapsedMs; + for (uint64_t Waited = 0; Waited < WaitMs && !Wrapper.ShouldExit.load(); Waited += 100) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + if (Wrapper.ShouldExit.load()) + { + return; + } + } + } + + if (m_ActiveCoreCount.load() >= m_TargetCoreCount.load()) + { + return; + } + + std::string RequestBody = m_HordeClient->BuildRequestBody(); + + // Resolve cluster if needed + std::string ClusterId = m_Config.Cluster; + if (ClusterId == HordeConfig::ClusterAuto) + { + ClusterInfo Cluster; + if (!m_HordeClient->ResolveCluster(RequestBody, Cluster)) + { + ZEN_WARN("failed to resolve cluster"); + m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + ClusterId = Cluster.ClusterId; + } + + MachineInfo Machine; + if (!m_HordeClient->RequestMachine(RequestBody, ClusterId, /* out */ Machine) || !Machine.IsValid()) + { + m_LastRequestFailTime.store(static_cast(std::chrono::steady_clock::now().time_since_epoch().count())); + return; + } + + m_LastRequestFailTime.store(0); + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Connect to agent and perform handshake + Agent = std::make_unique(Machine); + if (!Agent->IsValid()) + { + ZEN_WARN("agent creation failed for {}:{}", Machine.GetConnectionAddress(), Machine.GetConnectionPort()); + return; + } + + if (!Agent->BeginCommunication()) + { + ZEN_WARN("BeginCommunication failed"); + return; + } + + for (auto& [Locator, BundleDir] : m_Bundles) + { + if (Wrapper.ShouldExit.load()) + { + return; + } + + if (!Agent->UploadBinaries(BundleDir, Locator)) + { + ZEN_WARN("UploadBinaries failed"); + return; + } + 
} + + if (Wrapper.ShouldExit.load()) + { + return; + } + + // Build command line for remote zenserver + std::vector ArgStrings; + ArgStrings.push_back("compute"); + ArgStrings.push_back("--http=asio"); + + // TEMP HACK - these should be made fully dynamic + // these are currently here to allow spawning the compute agent locally + // for debugging purposes (i.e with a local Horde Server+Agent setup) + ArgStrings.push_back(fmt::format("--port={}", m_Config.ZenServicePort)); + ArgStrings.push_back("--data-dir=c:\\temp\\123"); + + if (!m_OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << m_OrchestratorEndpoint; + ArgStrings.emplace_back(CoordArg.ToView()); + } + + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=horde-" << Machine.LeaseId; + ArgStrings.emplace_back(IdArg.ToView()); + } + + std::vector Args; + Args.reserve(ArgStrings.size()); + for (const std::string& Arg : ArgStrings) + { + Args.push_back(Arg.c_str()); + } + +#if ZEN_PLATFORM_WINDOWS + const bool UseWine = !Machine.IsWindows; + const char* AppName = "zenserver.exe"; +#else + const bool UseWine = false; + const char* AppName = "zenserver"; +#endif + + Agent->Execute(AppName, Args.data(), Args.size(), nullptr, nullptr, 0, UseWine); + + ZEN_INFO("remote execution started on [{}:{}] lease={}", + Machine.GetConnectionAddress(), + Machine.GetConnectionPort(), + Machine.LeaseId); + + MachineCoreCount = Machine.LogicalCores; + m_EstimatedCoreCount.fetch_add(MachineCoreCount); + m_ActiveCoreCount.fetch_add(MachineCoreCount); + m_AgentsActive.fetch_add(1); + } + + // Agent poll loop + + auto ActiveGuard = MakeGuard([&]() { + m_EstimatedCoreCount.fetch_sub(MachineCoreCount); + m_ActiveCoreCount.fetch_sub(MachineCoreCount); + m_AgentsActive.fetch_sub(1); + }); + + while (Agent->IsValid() && !Wrapper.ShouldExit.load()) + { + const bool LogOutput = false; + if (!Agent->Poll(LogOutput)) + { + break; + } + 
std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp new file mode 100644 index 000000000..69766e73e --- /dev/null +++ b/src/zenhorde/hordetransport.cpp @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransport.h" + +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +// ComputeTransport base + +bool +ComputeTransport::SendMessage(const void* Data, size_t Size) +{ + const uint8_t* Ptr = static_cast(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Sent = Send(Ptr, Remaining); + if (Sent == 0) + { + return false; + } + Ptr += Sent; + Remaining -= Sent; + } + + return true; +} + +bool +ComputeTransport::RecvMessage(void* Data, size_t Size) +{ + uint8_t* Ptr = static_cast(Data); + size_t Remaining = Size; + + while (Remaining > 0) + { + const size_t Received = Recv(Ptr, Remaining); + if (Received == 0) + { + return false; + } + Ptr += Received; + Remaining -= Received; + } + + return true; +} + +// TcpComputeTransport - ASIO pimpl + +struct TcpComputeTransport::Impl +{ + asio::io_context IoContext; + asio::ip::tcp::socket Socket; + + Impl() : Socket(IoContext) {} +}; + +// Uses ASIO in synchronous mode only — no async operations or io_context::run(). +// The io_context is only needed because ASIO sockets require one to be constructed. 
+TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) +: m_Impl(std::make_unique()) +, m_Log(zen::logging::Get("horde.transport")) +{ + ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + + asio::error_code Ec; + + const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec); + if (Ec) + { + ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); + m_HasErrors = true; + return; + } + + const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); + + m_Impl->Socket.connect(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); + m_HasErrors = true; + return; + } + + // Disable Nagle's algorithm for lower latency + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); +} + +TcpComputeTransport::~TcpComputeTransport() +{ + Close(); +} + +bool +TcpComputeTransport::IsValid() const +{ + return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; +} + +size_t +TcpComputeTransport::Send(const void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + m_HasErrors = true; + return 0; + } + + return Sent; +} + +size_t +TcpComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + asio::error_code Ec; + const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); + + if (Ec) + { + return 0; + } + + return Received; +} + +void +TcpComputeTransport::MarkComplete() +{ +} + +void +TcpComputeTransport::Close() +{ + if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) + { + asio::error_code Ec; + m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + m_Impl->Socket.close(Ec); + } + m_IsClosed = true; +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransport.h 
b/src/zenhorde/hordetransport.h new file mode 100644 index 000000000..1b178dc0f --- /dev/null +++ b/src/zenhorde/hordetransport.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# undef SendMessage +#endif + +namespace zen::horde { + +/** Abstract base interface for compute transports. + * + * Matches the UE FComputeTransport pattern. Concrete implementations handle + * the underlying I/O (TCP, AES-wrapped, etc.) while this interface provides + * blocking message helpers on top. + */ +class ComputeTransport +{ +public: + virtual ~ComputeTransport() = default; + + virtual bool IsValid() const = 0; + virtual size_t Send(const void* Data, size_t Size) = 0; + virtual size_t Recv(void* Data, size_t Size) = 0; + virtual void MarkComplete() = 0; + virtual void Close() = 0; + + /** Blocking send that loops until all bytes are transferred. Returns false on error. */ + bool SendMessage(const void* Data, size_t Size); + + /** Blocking receive that loops until all bytes are transferred. Returns false on error. */ + bool RecvMessage(void* Data, size_t Size); +}; + +/** TCP socket transport using ASIO. + * + * Connects to the Horde compute endpoint specified by MachineInfo and provides + * raw TCP send/receive. ASIO internals are hidden behind a pimpl to keep the + * header clean. 
+ */ +class TcpComputeTransport final : public ComputeTransport +{ +public: + explicit TcpComputeTransport(const MachineInfo& Info); + ~TcpComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + LoggerRef Log() { return m_Log; } + + struct Impl; + std::unique_ptr m_Impl; + LoggerRef m_Log; + bool m_IsClosed = false; + bool m_HasErrors = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp new file mode 100644 index 000000000..986dd3705 --- /dev/null +++ b/src/zenhorde/hordetransportaes.cpp @@ -0,0 +1,425 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordetransportaes.h" + +#include +#include + +#include +#include +#include + +#if ZEN_PLATFORM_WINDOWS +# include +# include +# pragma comment(lib, "Bcrypt.lib") +#else +ZEN_THIRD_PARTY_INCLUDES_START +# include +# include +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +namespace zen::horde { + +struct AesComputeTransport::CryptoContext +{ + uint8_t Key[KeySize] = {}; + uint8_t EncryptNonce[NonceBytes] = {}; + uint8_t DecryptNonce[NonceBytes] = {}; + bool HasErrors = false; + +#if !ZEN_PLATFORM_WINDOWS + EVP_CIPHER_CTX* EncCtx = nullptr; + EVP_CIPHER_CTX* DecCtx = nullptr; +#endif + + CryptoContext(const uint8_t (&InKey)[KeySize]) + { + memcpy(Key, InKey, KeySize); + + // The encrypt nonce is randomly initialized and then deterministically mutated + // per message via UpdateNonce(). The decrypt nonce is not used — it comes from + // the wire (each received message carries its own nonce in the header). 
+ std::random_device Rd; + std::mt19937 Gen(Rd()); + std::uniform_int_distribution Dist(0, 255); + for (auto& Byte : EncryptNonce) + { + Byte = static_cast(Dist(Gen)); + } + +#if !ZEN_PLATFORM_WINDOWS + // Drain any stale OpenSSL errors + while (ERR_get_error() != 0) + { + } + + EncCtx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(EncCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); + + DecCtx = EVP_CIPHER_CTX_new(); + EVP_DecryptInit_ex(DecCtx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr); +#endif + } + + ~CryptoContext() + { +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Key, sizeof(Key)); + SecureZeroMemory(EncryptNonce, sizeof(EncryptNonce)); + SecureZeroMemory(DecryptNonce, sizeof(DecryptNonce)); +#else + OPENSSL_cleanse(Key, sizeof(Key)); + OPENSSL_cleanse(EncryptNonce, sizeof(EncryptNonce)); + OPENSSL_cleanse(DecryptNonce, sizeof(DecryptNonce)); + + if (EncCtx) + { + EVP_CIPHER_CTX_free(EncCtx); + } + if (DecCtx) + { + EVP_CIPHER_CTX_free(DecCtx); + } +#endif + } + + void UpdateNonce() + { + uint32_t* N32 = reinterpret_cast(EncryptNonce); + N32[0]++; + N32[1]--; + N32[2] = N32[0] ^ N32[1]; + } + + // Returns total encrypted message size, or 0 on failure + // Output format: [length(4B)][nonce(12B)][ciphertext][tag(16B)] + int32_t EncryptMessage(uint8_t* Out, const void* In, int32_t InLength) + { + UpdateNonce(); + + // On Windows, BCrypt algorithm/key handles are created per call. This is simpler than + // caching but has some overhead. For our use case (relatively large, infrequent messages) + // this is acceptable. 
+#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = EncryptNonce; + AuthInfo.cbNonce = NonceBytes; + uint8_t Tag[TagBytes] = {}; + AuthInfo.pbTag = Tag; + AuthInfo.cbTag = TagBytes; + + ULONG CipherLen = 0; + NTSTATUS Status = + BCryptEncrypt(hKey, (PUCHAR)In, (ULONG)InLength, &AuthInfo, nullptr, 0, Out + 4 + NonceBytes, (ULONG)InLength, &CipherLen, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + return 0; + } + + // Write header: length + nonce + memcpy(Out, &InLength, 4); + memcpy(Out + 4, EncryptNonce, NonceBytes); + // Write tag after ciphertext + memcpy(Out + 4 + NonceBytes + CipherLen, Tag, TagBytes); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + return 4 + NonceBytes + static_cast(CipherLen) + TagBytes; +#else + if (EVP_EncryptInit_ex(EncCtx, nullptr, nullptr, Key, EncryptNonce) != 1) + { + HasErrors = true; + return 0; + } + + int32_t Offset = 0; + // Write length + memcpy(Out + Offset, &InLength, 4); + Offset += 4; + // Write nonce + memcpy(Out + Offset, EncryptNonce, NonceBytes); + Offset += NonceBytes; + + // Encrypt + int OutLen = 0; + if (EVP_EncryptUpdate(EncCtx, Out + Offset, &OutLen, static_cast(In), InLength) != 1) + { + HasErrors = true; + return 0; + } + Offset += OutLen; + + // Finalize + int FinalLen = 0; + if (EVP_EncryptFinal_ex(EncCtx, Out + Offset, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + Offset += FinalLen; + + // Get tag + if (EVP_CIPHER_CTX_ctrl(EncCtx, EVP_CTRL_GCM_GET_TAG, TagBytes, Out + 
Offset) != 1) + { + HasErrors = true; + return 0; + } + Offset += TagBytes; + + return Offset; +#endif + } + + // Decrypt a message. Returns decrypted data length, or 0 on failure. + // Input must be [ciphertext][tag], with nonce provided separately. + int32_t DecryptMessage(void* Out, const uint8_t* Nonce, const uint8_t* CipherAndTag, int32_t DataLength) + { +#if ZEN_PLATFORM_WINDOWS + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_KEY_HANDLE hKey = nullptr; + + BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0); + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0); + BCryptGenerateSymmetricKey(hAlg, &hKey, nullptr, 0, (PUCHAR)Key, KeySize, 0); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = const_cast(Nonce); + AuthInfo.cbNonce = NonceBytes; + AuthInfo.pbTag = const_cast(CipherAndTag + DataLength); + AuthInfo.cbTag = TagBytes; + + ULONG PlainLen = 0; + NTSTATUS Status = BCryptDecrypt(hKey, + (PUCHAR)CipherAndTag, + (ULONG)DataLength, + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out, + (ULONG)DataLength, + &PlainLen, + 0); + + BCryptDestroyKey(hKey); + BCryptCloseAlgorithmProvider(hAlg, 0); + + if (!BCRYPT_SUCCESS(Status)) + { + HasErrors = true; + return 0; + } + + return static_cast(PlainLen); +#else + if (EVP_DecryptInit_ex(DecCtx, nullptr, nullptr, Key, Nonce) != 1) + { + HasErrors = true; + return 0; + } + + int OutLen = 0; + if (EVP_DecryptUpdate(DecCtx, static_cast(Out), &OutLen, CipherAndTag, DataLength) != 1) + { + HasErrors = true; + return 0; + } + + // Set the tag for verification + if (EVP_CIPHER_CTX_ctrl(DecCtx, EVP_CTRL_GCM_SET_TAG, TagBytes, const_cast(CipherAndTag + DataLength)) != 1) + { + HasErrors = true; + return 0; + } + + int FinalLen = 0; + if (EVP_DecryptFinal_ex(DecCtx, static_cast(Out) + OutLen, &FinalLen) != 1) + { + HasErrors = true; + return 0; + } + + return OutLen + FinalLen; +#endif + } +}; + 
+AesComputeTransport::AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr InnerTransport) +: m_Crypto(std::make_unique(Key)) +, m_Inner(std::move(InnerTransport)) +{ +} + +AesComputeTransport::~AesComputeTransport() +{ + Close(); +} + +bool +AesComputeTransport::IsValid() const +{ + return m_Inner && m_Inner->IsValid() && m_Crypto && !m_Crypto->HasErrors && !m_IsClosed; +} + +size_t +AesComputeTransport::Send(const void* Data, size_t Size) +{ + ZEN_TRACE_CPU("AesComputeTransport::Send"); + + if (!IsValid()) + { + return 0; + } + + std::lock_guard Lock(m_Lock); + + const int32_t DataLength = static_cast(Size); + const size_t MessageLength = 4 + NonceBytes + Size + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + const int32_t EncryptedLen = m_Crypto->EncryptMessage(m_EncryptBuffer.data(), Data, DataLength); + if (EncryptedLen == 0) + { + return 0; + } + + if (!m_Inner->SendMessage(m_EncryptBuffer.data(), static_cast(EncryptedLen))) + { + return 0; + } + + return Size; +} + +size_t +AesComputeTransport::Recv(void* Data, size_t Size) +{ + if (!IsValid()) + { + return 0; + } + + // AES-GCM decrypts entire messages at once, but the caller may request fewer bytes + // than the decrypted message contains. Excess bytes are buffered in m_RemainingData + // and returned on subsequent Recv calls without another decryption round-trip. 
+ ZEN_TRACE_CPU("AesComputeTransport::Recv"); + + std::lock_guard Lock(m_Lock); + + if (!m_RemainingData.empty()) + { + const size_t Available = m_RemainingData.size() - m_RemainingOffset; + const size_t ToCopy = std::min(Available, Size); + + memcpy(Data, m_RemainingData.data() + m_RemainingOffset, ToCopy); + m_RemainingOffset += ToCopy; + + if (m_RemainingOffset >= m_RemainingData.size()) + { + m_RemainingData.clear(); + m_RemainingOffset = 0; + } + + return ToCopy; + } + + // Receive packet header: [length(4B)][nonce(12B)] + struct PacketHeader + { + int32_t DataLength = 0; + uint8_t Nonce[NonceBytes] = {}; + } Header; + + if (!m_Inner->RecvMessage(&Header, sizeof(Header))) + { + return 0; + } + + // Validate DataLength to prevent OOM from malicious/corrupt peers + static constexpr int32_t MaxDataLength = 64 * 1024 * 1024; // 64 MiB + + if (Header.DataLength <= 0 || Header.DataLength > MaxDataLength) + { + ZEN_WARN("AES recv: invalid DataLength {} from peer", Header.DataLength); + return 0; + } + + // Receive ciphertext + tag + const size_t MessageLength = static_cast(Header.DataLength) + TagBytes; + + if (m_EncryptBuffer.size() < MessageLength) + { + m_EncryptBuffer.resize(MessageLength); + } + + if (!m_Inner->RecvMessage(m_EncryptBuffer.data(), MessageLength)) + { + return 0; + } + + // Decrypt + const size_t BytesToReturn = std::min(static_cast(Header.DataLength), Size); + + // We need a temporary buffer for decryption if we can't decrypt directly into output + std::vector DecryptedBuf(static_cast(Header.DataLength)); + + const int32_t Decrypted = m_Crypto->DecryptMessage(DecryptedBuf.data(), Header.Nonce, m_EncryptBuffer.data(), Header.DataLength); + if (Decrypted == 0) + { + return 0; + } + + memcpy(Data, DecryptedBuf.data(), BytesToReturn); + + // Store remaining data if we couldn't return everything + if (static_cast(Header.DataLength) > BytesToReturn) + { + m_RemainingOffset = 0; + m_RemainingData.assign(DecryptedBuf.begin() + BytesToReturn, 
DecryptedBuf.begin() + Header.DataLength); + } + + return BytesToReturn; +} + +void +AesComputeTransport::MarkComplete() +{ + if (IsValid()) + { + m_Inner->MarkComplete(); + } +} + +void +AesComputeTransport::Close() +{ + if (!m_IsClosed) + { + if (m_Inner && m_Inner->IsValid()) + { + m_Inner->Close(); + } + m_IsClosed = true; + } +} + +} // namespace zen::horde diff --git a/src/zenhorde/hordetransportaes.h b/src/zenhorde/hordetransportaes.h new file mode 100644 index 000000000..efcad9835 --- /dev/null +++ b/src/zenhorde/hordetransportaes.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "hordetransport.h" + +#include +#include +#include +#include + +namespace zen::horde { + +/** AES-256-GCM encrypted transport wrapper. + * + * Wraps an inner ComputeTransport, encrypting all outgoing data and decrypting + * all incoming data using AES-256-GCM. The nonce is mutated per message using + * the Horde nonce mangling scheme: n32[0]++; n32[1]--; n32[2] = n32[0] ^ n32[1]. + * + * Wire format per encrypted message: + * [plaintext length (4B little-endian)][nonce (12B)][ciphertext][GCM tag (16B)] + * + * Uses BCrypt on Windows and OpenSSL EVP on Linux/macOS (selected at compile time). 
+ */ +class AesComputeTransport final : public ComputeTransport +{ +public: + AesComputeTransport(const uint8_t (&Key)[KeySize], std::unique_ptr InnerTransport); + ~AesComputeTransport() override; + + bool IsValid() const override; + size_t Send(const void* Data, size_t Size) override; + size_t Recv(void* Data, size_t Size) override; + void MarkComplete() override; + void Close() override; + +private: + static constexpr size_t NonceBytes = 12; ///< AES-GCM nonce size + static constexpr size_t TagBytes = 16; ///< AES-GCM authentication tag size + + struct CryptoContext; + + std::unique_ptr m_Crypto; + std::unique_ptr m_Inner; + std::vector m_EncryptBuffer; + std::vector m_RemainingData; ///< Buffered decrypted data from a partially consumed Recv + size_t m_RemainingOffset = 0; + std::mutex m_Lock; + bool m_IsClosed = false; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeclient.h b/src/zenhorde/include/zenhorde/hordeclient.h new file mode 100644 index 000000000..201d68b83 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeclient.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +namespace zen { +class HttpClient; +} + +namespace zen::horde { + +static constexpr size_t NonceSize = 64; +static constexpr size_t KeySize = 32; + +/** Port mapping information returned by Horde for a provisioned machine. */ +struct PortInfo +{ + uint16_t Port = 0; + uint16_t AgentPort = 0; +}; + +/** Describes a provisioned compute machine returned by the Horde API. + * + * Contains the network address, encryption credentials, and capabilities + * needed to establish a compute transport connection to the machine. 
+ */ +struct MachineInfo +{ + std::string Ip; + ConnectionMode Mode = ConnectionMode::Direct; + std::string ConnectionAddress; ///< Relay/tunnel address (used when Mode != Direct) + uint16_t Port = 0; + uint16_t LogicalCores = 0; + Encryption EncryptionMode = Encryption::None; + uint8_t Nonce[NonceSize] = {}; ///< 64-byte nonce sent during TCP handshake + uint8_t Key[KeySize] = {}; ///< 32-byte AES key (when EncryptionMode == AES) + bool IsWindows = false; + std::string LeaseId; + + std::map Ports; + + /** Return the address to connect to, accounting for connection mode. */ + const std::string& GetConnectionAddress() const { return Mode == ConnectionMode::Relay ? ConnectionAddress : Ip; } + + /** Return the port to connect to, accounting for connection mode and port mapping. */ + uint16_t GetConnectionPort() const + { + if (Mode == ConnectionMode::Relay) + { + auto It = Ports.find("_horde_compute"); + if (It != Ports.end()) + { + return It->second.Port; + } + } + return Port; + } + + bool IsValid() const { return !Ip.empty() && Port != 0xFFFF; } +}; + +/** Result of cluster auto-resolution via the Horde API. */ +struct ClusterInfo +{ + std::string ClusterId = "default"; +}; + +/** HTTP client for the Horde compute REST API. + * + * Handles cluster resolution and machine provisioning requests. Each call + * is synchronous and returns success/failure. Thread safety: individual + * methods are not thread-safe; callers must synchronize access. + */ +class HordeClient +{ +public: + explicit HordeClient(const HordeConfig& Config); + ~HordeClient(); + + HordeClient(const HordeClient&) = delete; + HordeClient& operator=(const HordeClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the JSON request body for cluster resolution and machine requests. + * Encodes pool, condition, connection mode, encryption, and port requirements. 
*/ + std::string BuildRequestBody() const; + + /** Resolve the best cluster for the given request via POST /api/v2/compute/_cluster. */ + bool ResolveCluster(const std::string& RequestBody, ClusterInfo& OutCluster); + + /** Request a compute machine from the given cluster via POST /api/v2/compute/{clusterId}. + * On success, populates OutMachine with connection details and credentials. */ + bool RequestMachine(const std::string& RequestBody, const std::string& ClusterId, MachineInfo& OutMachine); + + LoggerRef Log() { return m_Log; } + +private: + bool ParseHexBytes(std::string_view Hex, uint8_t* Out, size_t OutSize); + + HordeConfig m_Config; + std::unique_ptr m_Http; + LoggerRef m_Log; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeconfig.h b/src/zenhorde/include/zenhorde/hordeconfig.h new file mode 100644 index 000000000..dd70f9832 --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeconfig.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +namespace zen::horde { + +/** Transport connection mode for Horde compute agents. */ +enum class ConnectionMode +{ + Direct, ///< Connect directly to the agent IP + Tunnel, ///< Connect through a Horde tunnel relay + Relay, ///< Connect through a Horde relay with port mapping +}; + +/** Transport encryption mode for Horde compute channels. */ +enum class Encryption +{ + None, ///< No encryption + AES, ///< AES-256-GCM encryption (required for Relay mode) +}; + +/** Configuration for connecting to an Epic Horde compute cluster. + * + * Specifies the Horde server URL, authentication token, pool selection, + * connection mode, and resource limits. Used by HordeClient and HordeProvisioner. 
+ */ +struct HordeConfig +{ + static constexpr const char* ClusterDefault = "default"; + static constexpr const char* ClusterAuto = "_auto"; + + bool Enabled = false; ///< Whether Horde provisioning is active + std::string ServerUrl; ///< Horde server base URL (e.g. "https://horde.example.com") + std::string AuthToken; ///< Authentication token for the Horde API + std::string Pool; ///< Pool name to request machines from + std::string Cluster = ClusterDefault; ///< Cluster ID, or "_auto" to auto-resolve + std::string Condition; ///< Agent filter expression for machine selection + std::string HostAddress; ///< Address that provisioned agents use to connect back to us + std::string BinariesPath; ///< Path to directory containing zenserver binary for remote upload + uint16_t ZenServicePort = 8558; ///< Port number that provisioned agents should forward to us for Zen service communication + + int MaxCores = 2048; + bool AllowWine = true; ///< Allow running Windows binaries under Wine on Linux agents + ConnectionMode Mode = ConnectionMode::Direct; + Encryption EncryptionMode = Encryption::None; + + /** Validate the configuration. Returns false if the configuration is invalid + * (e.g. Relay mode without AES encryption). */ + bool Validate() const; +}; + +const char* ToString(ConnectionMode Mode); +const char* ToString(Encryption Enc); + +bool FromString(ConnectionMode& OutMode, std::string_view Str); +bool FromString(Encryption& OutEnc, std::string_view Str); + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/hordeprovisioner.h b/src/zenhorde/include/zenhorde/hordeprovisioner.h new file mode 100644 index 000000000..4e2e63bbd --- /dev/null +++ b/src/zenhorde/include/zenhorde/hordeprovisioner.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace zen::horde { + +class HordeClient; + +/** Snapshot of the current provisioning state, returned by HordeProvisioner::GetStats(). */ +struct ProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected once pending requests complete + uint32_t ActiveCoreCount = 0; ///< Cores on machines that are currently running zenserver + uint32_t AgentsActive = 0; ///< Number of agents with a running remote process + uint32_t AgentsRequesting = 0; ///< Number of agents currently requesting a machine from Horde +}; + +/** Multi-agent lifecycle manager for Horde worker provisioning. + * + * Provisions remote compute workers by requesting machines from the Horde API, + * connecting via the Horde compute transport protocol, uploading the zenserver + * binary, and executing it remotely. Each provisioned machine runs zenserver + * in compute mode, which announces itself back to the orchestrator. + * + * Spawns one thread per agent. Each thread handles the full lifecycle: + * HTTP request -> TCP connect -> nonce handshake -> optional AES encryption -> + * channel setup -> binary upload -> remote execution -> poll until exit. + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class HordeProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Horde connection and pool configuration. + * @param BinariesPath Directory containing the zenserver binary to upload. + * @param WorkingDir Local directory for bundle staging and working files. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. 
*/ + HordeProvisioner(const HordeConfig& Config, + const std::filesystem::path& BinariesPath, + const std::filesystem::path& WorkingDir, + std::string_view OrchestratorEndpoint); + + /** Signals all agent threads to exit and joins them. */ + ~HordeProvisioner(); + + HordeProvisioner(const HordeProvisioner&) = delete; + HordeProvisioner& operator=(const HordeProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to HordeConfig::MaxCores. Spawns new agent threads if the + * estimated core count is below the target. Also joins any finished + * agent threads. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + ProvisioningStats GetStats() const; + + uint32_t GetActiveCoreCount() const { return m_ActiveCoreCount.load(); } + uint32_t GetAgentCount() const; + +private: + LoggerRef Log() { return m_Log; } + + struct AgentWrapper; + + void RequestAgent(); + void ThreadAgent(AgentWrapper& Wrapper); + + HordeConfig m_Config; + std::filesystem::path m_BinariesPath; + std::filesystem::path m_WorkingDir; + std::string m_OrchestratorEndpoint; + + std::unique_ptr m_HordeClient; + + std::mutex m_BundleLock; + std::vector> m_Bundles; ///< (locator, bundleDir) pairs + bool m_BundlesCreated = false; + + mutable std::mutex m_AgentsLock; + std::vector> m_Agents; + + std::atomic m_LastRequestFailTime{0}; + std::atomic m_TargetCoreCount{0}; + std::atomic m_EstimatedCoreCount{0}; + std::atomic m_ActiveCoreCount{0}; + std::atomic m_AgentsActive{0}; + std::atomic m_AgentsRequesting{0}; + std::atomic m_AskForAgents{true}; + + LoggerRef m_Log; + + static constexpr uint32_t EstimatedCoresPerAgent = 32; +}; + +} // namespace zen::horde diff --git a/src/zenhorde/include/zenhorde/zenhorde.h b/src/zenhorde/include/zenhorde/zenhorde.h new file mode 100644 index 000000000..35147ff75 --- /dev/null +++ b/src/zenhorde/include/zenhorde/zenhorde.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. 
All Rights Reserved. + +#pragma once + +#include + +#if !defined(ZEN_WITH_HORDE) +# define ZEN_WITH_HORDE 1 +#endif diff --git a/src/zenhorde/xmake.lua b/src/zenhorde/xmake.lua new file mode 100644 index 000000000..48d028e86 --- /dev/null +++ b/src/zenhorde/xmake.lua @@ -0,0 +1,22 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhorde') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zencompute", "zenutil") + add_packages("asio", "json11") + + if is_plat("windows") then + add_syslinks("Ws2_32", "Bcrypt") + end + + if is_plat("linux") or is_plat("macosx") then + add_packages("openssl") + end + + if is_os("macosx") then + add_cxxflags("-Wno-deprecated-declarations") + end diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 62c080a7b..02cccc540 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -103,6 +103,7 @@ public: virtual bool IsLocalMachineRequest() const = 0; virtual std::string_view GetAuthorizationHeader() const = 0; + virtual std::string_view GetRemoteAddress() const { return {}; } /** Respond with payload diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index c4d9ee777..33f182df9 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -544,7 +544,8 @@ public: HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber, - bool IsLocalMachineRequest); + bool IsLocalMachineRequest, + std::string RemoteAddress); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; @@ -552,6 +553,7 @@ public: virtual bool IsLocalMachineRequest() const override; virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void 
/** Return the textual address of the remote peer for this request.
 *
 * The string is the one handed to the constructor (HandleRequest passes
 * `RemoteEndpoint.address().to_string()`, i.e. the address without a port).
 * The returned view points into this request object and must not outlive it.
 */
std::string_view
HttpAsioServerRequest::GetRemoteAddress() const
{
    return m_RemoteAddress;
}
/** Return the textual address of the remote peer for this request (port excluded).
 *
 * The address is formatted on first use and cached in the `mutable`
 * m_RemoteAddress builder; later calls return the cached text. The returned
 * view points into this request object and must not outlive it.
 *
 * NOTE(review): the cache is filled from a const method with no synchronization
 * visible here - presumably a request is only touched by one thread at a time; confirm.
 */
std::string_view
HttpSysServerRequest::GetRemoteAddress() const
{
    // An empty builder means "not formatted yet"
    if (m_RemoteAddress.Size() == 0)
    {
        const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress;
        GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false);
    }
    return m_RemoteAddress.ToView();
}
+ +#pragma once + +#include + +#include + +#include +#include +#include + +namespace zen { +class HttpClient; +} + +namespace zen::nomad { + +/** Summary of a Nomad job returned by the API. */ +struct NomadJobInfo +{ + std::string Id; + std::string Status; ///< "pending", "running", "dead" + std::string StatusDescription; +}; + +/** Summary of a Nomad allocation returned by the API. */ +struct NomadAllocInfo +{ + std::string Id; + std::string ClientStatus; ///< "pending", "running", "complete", "failed" + std::string TaskState; ///< State of the task within the allocation +}; + +/** HTTP client for the Nomad REST API (v1). + * + * Handles job submission, status polling, and job termination. + * All calls are synchronous. Thread safety: individual methods are + * not thread-safe; callers must synchronize access. + */ +class NomadClient +{ +public: + explicit NomadClient(const NomadConfig& Config); + ~NomadClient(); + + NomadClient(const NomadClient&) = delete; + NomadClient& operator=(const NomadClient&) = delete; + + /** Initialize the underlying HTTP client. Must be called before other methods. */ + bool Initialize(); + + /** Build the Nomad job registration JSON for the given job ID and orchestrator endpoint. + * The JSON structure varies based on the configured driver and distribution mode. */ + std::string BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const; + + /** Submit a job via PUT /v1/jobs. On success, populates OutJob with the job info. */ + bool SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob); + + /** Get the status of a job via GET /v1/job/{jobId}. */ + bool GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob); + + /** Get allocations for a job via GET /v1/job/{jobId}/allocations. */ + bool GetAllocations(const std::string& JobId, std::vector& OutAllocs); + + /** Stop a job via DELETE /v1/job/{jobId}. 
/** Configuration for Nomad worker provisioning.
 *
 * Specifies the Nomad server URL, authentication, resource limits, and
 * job configuration. Used by NomadClient and NomadProvisioner.
 */
struct NomadConfig
{
    bool        Enabled = false;        ///< Whether Nomad provisioning is active
    std::string ServerUrl;              ///< Nomad HTTP API URL (e.g. "http://localhost:4646"); required by Validate()
    std::string AclToken;               ///< Nomad ACL token (sent as X-Nomad-Token header; header omitted when empty)
    std::string Datacenter = "dc1";     ///< Target datacenter (becomes the single entry of the job's "Datacenters" array)
    std::string Namespace  = "default"; ///< Nomad namespace (only written to the job when non-empty and not "default")
    std::string Region;                 ///< Nomad region (empty = server default; only written to the job when non-empty)

    Driver             TaskDriver      = Driver::RawExec;                 ///< Task driver for job execution
    BinaryDistribution BinDistribution = BinaryDistribution::PreDeployed; ///< How to distribute the zenserver binary

    std::string BinaryPath;     ///< Path to zenserver on Nomad clients (PreDeployed mode; used as the raw_exec command)
    std::string ArtifactSource; ///< URL to download zenserver binary (Artifact mode; emitted as the artifact "GetterSource")
    std::string DockerImage;    ///< Docker image name (Docker driver mode)

    int MaxJobs     = 64;   ///< Maximum concurrent Nomad jobs
    int CpuMhz      = 1000; ///< CPU MHz allocated per task (job resource "CPU")
    int MemoryMb    = 2048; ///< Memory MB allocated per task (job resource "MemoryMB")
    int CoresPerJob = 32;   ///< Estimated cores per job (for scaling calculations)
    int MaxCores    = 2048; ///< Maximum total cores to provision

    std::string JobPrefix = "zenserver-worker"; ///< Prefix for generated Nomad job IDs

    /** Validate the configuration. Returns false if required fields are missing
     * or incompatible options are set. Concretely, it fails when:
     *   - ServerUrl is empty
     *   - BinDistribution is PreDeployed and BinaryPath is empty
     *   - BinDistribution is Artifact and ArtifactSource is empty
     *   - TaskDriver is Docker and DockerImage is empty
     *
     * NOTE(review): the BinaryPath check does not look at TaskDriver, so a Docker
     * configuration left at the default PreDeployed distribution is rejected unless
     * BinaryPath is set, even though the Docker job JSON never uses it - confirm
     * this is intended. */
    bool Validate() const;
};
/** Manages a Nomad agent process running in dev mode for testing.
 *
 * Spawns `nomad agent -dev` and polls the HTTP API until the agent
 * is ready. On destruction or via StopNomadAgent(), the agent
 * process is killed.
 *
 * NOTE(review): the destructor itself does no explicit work; killing on
 * destruction presumably relies on the process handle's own destructor - confirm.
 */
class NomadProcess
{
public:
    NomadProcess();
    ~NomadProcess();

    NomadProcess(const NomadProcess&) = delete;
    NomadProcess& operator=(const NomadProcess&) = delete;

    /** Spawn a Nomad dev agent and block until the leader endpoint responds.
     *
     * The endpoint is polled every 100 ms for up to 30 s. Nothing is returned:
     * failure to launch or to become ready is only reported as a log warning.
     * Calling this while an agent handle is already held is a no-op. */
    void SpawnNomadAgent();

    /** Kill the Nomad agent process. Does nothing if no agent was spawned. */
    void StopNomadAgent();

private:
    struct Impl;                  // defined in nomadprocess.cpp (pimpl)
    std::unique_ptr<Impl> m_Impl; // owns the process handle and the HTTP client used for polling
};
*/ + std::vector ListJobs(std::string_view Prefix = ""); + +private: + HttpClient m_HttpClient; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/nomadprovisioner.h b/src/zennomad/include/zennomad/nomadprovisioner.h new file mode 100644 index 000000000..750693b3f --- /dev/null +++ b/src/zennomad/include/zennomad/nomadprovisioner.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace zen::nomad { + +class NomadClient; + +/** Snapshot of the current Nomad provisioning state, returned by NomadProvisioner::GetStats(). */ +struct NomadProvisioningStats +{ + uint32_t TargetCoreCount = 0; ///< Requested number of cores (clamped to MaxCores) + uint32_t EstimatedCoreCount = 0; ///< Cores expected from submitted jobs + uint32_t ActiveJobCount = 0; ///< Number of currently tracked Nomad jobs + uint32_t RunningJobCount = 0; ///< Number of jobs in "running" status +}; + +/** Job lifecycle manager for Nomad worker provisioning. + * + * Provisions remote compute workers by submitting batch jobs to a Nomad + * cluster via the REST API. Each job runs zenserver in compute mode, which + * announces itself back to the orchestrator. + * + * Uses a single management thread that periodically: + * 1. Submits new jobs when estimated cores < target cores + * 2. Polls existing jobs for status changes + * 3. Cleans up dead/failed jobs and adjusts counters + * + * Thread safety: SetTargetCoreCount and GetStats may be called from any thread. + */ +class NomadProvisioner +{ +public: + /** Construct a provisioner. + * @param Config Nomad connection and job configuration. + * @param OrchestratorEndpoint URL of the orchestrator that remote workers announce to. */ + NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint); + + /** Signals the management thread to exit and stops all tracked jobs. 
*/ + ~NomadProvisioner(); + + NomadProvisioner(const NomadProvisioner&) = delete; + NomadProvisioner& operator=(const NomadProvisioner&) = delete; + + /** Set the target number of cores to provision. + * Clamped to NomadConfig::MaxCores. The management thread will + * submit new jobs to approach this target. */ + void SetTargetCoreCount(uint32_t Count); + + /** Return a snapshot of the current provisioning counters. */ + NomadProvisioningStats GetStats() const; + +private: + LoggerRef Log() { return m_Log; } + + struct TrackedJob + { + std::string JobId; + std::string Status; ///< "pending", "running", "dead" + int Cores = 0; + }; + + void ManagementThread(); + void SubmitNewJobs(); + void PollExistingJobs(); + void CleanupDeadJobs(); + void StopAllJobs(); + + std::string GenerateJobId(); + + NomadConfig m_Config; + std::string m_OrchestratorEndpoint; + + std::unique_ptr m_Client; + + mutable std::mutex m_JobsLock; + std::vector m_Jobs; + std::atomic m_JobIndex{0}; + + std::atomic m_TargetCoreCount{0}; + std::atomic m_EstimatedCoreCount{0}; + std::atomic m_RunningJobCount{0}; + + std::thread m_Thread; + std::mutex m_WakeMutex; + std::condition_variable m_WakeCV; + std::atomic m_ShouldExit{false}; + + uint32_t m_ProcessId = 0; + + LoggerRef m_Log; +}; + +} // namespace zen::nomad diff --git a/src/zennomad/include/zennomad/zennomad.h b/src/zennomad/include/zennomad/zennomad.h new file mode 100644 index 000000000..09fb98dfe --- /dev/null +++ b/src/zennomad/include/zennomad/zennomad.h @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if !defined(ZEN_WITH_NOMAD) +# define ZEN_WITH_NOMAD 1 +#endif diff --git a/src/zennomad/nomadclient.cpp b/src/zennomad/nomadclient.cpp new file mode 100644 index 000000000..9edcde125 --- /dev/null +++ b/src/zennomad/nomadclient.cpp @@ -0,0 +1,366 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::nomad { + +namespace { + + HttpClient::KeyValueMap MakeNomadHeaders(const NomadConfig& Config) + { + HttpClient::KeyValueMap Headers; + if (!Config.AclToken.empty()) + { + Headers->emplace("X-Nomad-Token", Config.AclToken); + } + return Headers; + } + +} // namespace + +NomadClient::NomadClient(const NomadConfig& Config) : m_Config(Config), m_Log(zen::logging::Get("nomad.client")) +{ +} + +NomadClient::~NomadClient() = default; + +bool +NomadClient::Initialize() +{ + ZEN_TRACE_CPU("NomadClient::Initialize"); + + HttpClientSettings Settings; + Settings.LogCategory = "nomad.http"; + Settings.ConnectTimeout = std::chrono::milliseconds{10000}; + Settings.Timeout = std::chrono::milliseconds{60000}; + Settings.RetryCount = 1; + + // Ensure the base URL ends with a slash so path concatenation works correctly + std::string BaseUrl = m_Config.ServerUrl; + if (!BaseUrl.empty() && BaseUrl.back() != '/') + { + BaseUrl += '/'; + } + + m_Http = std::make_unique(BaseUrl, Settings); + + return true; +} + +std::string +NomadClient::BuildJobJson(const std::string& JobId, const std::string& OrchestratorEndpoint) const +{ + ZEN_TRACE_CPU("NomadClient::BuildJobJson"); + + // Build the task config based on driver and distribution mode + json11::Json::object TaskConfig; + + if (m_Config.TaskDriver == Driver::RawExec) + { + std::string Command; + if (m_Config.BinDistribution == BinaryDistribution::PreDeployed) + { + Command = m_Config.BinaryPath; + } + else + { + // Artifact mode: binary is downloaded to local/zenserver + Command = "local/zenserver"; + } + + TaskConfig["command"] = Command; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + 
Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + else + { + // Docker driver + TaskConfig["image"] = m_Config.DockerImage; + + json11::Json::array Args; + Args.push_back("compute"); + Args.push_back("--http=asio"); + if (!OrchestratorEndpoint.empty()) + { + ExtendableStringBuilder<256> CoordArg; + CoordArg << "--coordinator-endpoint=" << OrchestratorEndpoint; + Args.push_back(std::string(CoordArg.ToView())); + } + { + ExtendableStringBuilder<128> IdArg; + IdArg << "--instance-id=nomad-" << JobId; + Args.push_back(std::string(IdArg.ToView())); + } + TaskConfig["args"] = Args; + } + + // Build resource stanza + json11::Json::object Resources; + Resources["CPU"] = m_Config.CpuMhz; + Resources["MemoryMB"] = m_Config.MemoryMb; + + // Build the task + json11::Json::object Task; + Task["Name"] = "zenserver"; + Task["Driver"] = (m_Config.TaskDriver == Driver::RawExec) ? 
"raw_exec" : "docker"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + // Add artifact stanza if using artifact distribution + if (m_Config.BinDistribution == BinaryDistribution::Artifact && !m_Config.ArtifactSource.empty()) + { + json11::Json::object Artifact; + Artifact["GetterSource"] = m_Config.ArtifactSource; + + json11::Json::array Artifacts; + Artifacts.push_back(Artifact); + Task["Artifacts"] = Artifacts; + } + + json11::Json::array Tasks; + Tasks.push_back(Task); + + // Build the task group + json11::Json::object Group; + Group["Name"] = "zenserver-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + // Build datacenters array + json11::Json::array Datacenters; + Datacenters.push_back(m_Config.Datacenter); + + // Build the job + json11::Json::object Job; + Job["ID"] = JobId; + Job["Name"] = JobId; + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + Job["TaskGroups"] = Groups; + + if (!m_Config.Namespace.empty() && m_Config.Namespace != "default") + { + Job["Namespace"] = m_Config.Namespace; + } + + if (!m_Config.Region.empty()) + { + Job["Region"] = m_Config.Region; + } + + // Wrap in the registration envelope + json11::Json::object Root; + Root["Job"] = Job; + + return json11::Json(Root).dump(); +} + +bool +NomadClient::SubmitJob(const std::string& JobJson, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::SubmitJob"); + + const IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{JobJson.data(), JobJson.size()}, ZenContentType::kJSON); + + const HttpClient::Response Response = m_Http->Put("v1/jobs", Payload, MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job submit failed: {}", Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast(Response.StatusCode); + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job submit failed with HTTP/{}", StatusCode); + return false; + } + + const 
std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON response from Nomad job submit: {}", Err); + return false; + } + + // The response contains EvalID; the job ID is what we submitted + OutJob.Id = Json["JobModifyIndex"].is_number() ? OutJob.Id : ""; + OutJob.Status = "pending"; + + ZEN_INFO("Nomad job submitted: eval_id={}", Json["EvalID"].string_value()); + + return true; +} + +bool +NomadClient::GetJobStatus(const std::string& JobId, NomadJobInfo& OutJob) +{ + ZEN_TRACE_CPU("NomadClient::GetJobStatus"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job status query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + const int StatusCode = static_cast(Response.StatusCode); + + if (StatusCode == 404) + { + ZEN_INFO("Nomad job '{}' not found", JobId); + OutJob.Status = "dead"; + return true; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job status query failed with HTTP/{}", StatusCode); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad job status response: {}", Err); + return false; + } + + OutJob.Id = Json["ID"].string_value(); + OutJob.Status = Json["Status"].string_value(); + if (const json11::Json Desc = Json["StatusDescription"]; Desc.is_string()) + { + OutJob.StatusDescription = Desc.string_value(); + } + + return true; +} + +bool +NomadClient::GetAllocations(const std::string& JobId, std::vector& OutAllocs) +{ + ZEN_TRACE_CPU("NomadClient::GetAllocations"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId << "/allocations"; + + const HttpClient::Response Response = 
m_Http->Get(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad allocation query failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad allocation query failed with HTTP/{}", static_cast(Response.StatusCode)); + return false; + } + + const std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("invalid JSON in Nomad allocation response: {}", Err); + return false; + } + + OutAllocs.clear(); + if (!Json.is_array()) + { + return true; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + // Extract task state if available + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + OutAllocs.push_back(std::move(Alloc)); + } + + return true; +} + +bool +NomadClient::StopJob(const std::string& JobId) +{ + ZEN_TRACE_CPU("NomadClient::StopJob"); + + ExtendableStringBuilder<128> Path; + Path << "v1/job/" << JobId; + + const HttpClient::Response Response = m_Http->Delete(Path.ToView(), MakeNomadHeaders(m_Config)); + + if (Response.Error) + { + ZEN_WARN("Nomad job stop failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return false; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("Nomad job stop failed with HTTP/{}", static_cast(Response.StatusCode)); + return false; + } + + ZEN_INFO("Nomad job '{}' stopped", JobId); + return true; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadconfig.cpp b/src/zennomad/nomadconfig.cpp new file mode 100644 index 000000000..d55b3da9a --- 
/dev/null +++ b/src/zennomad/nomadconfig.cpp @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +namespace zen::nomad { + +bool +NomadConfig::Validate() const +{ + if (ServerUrl.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::PreDeployed && BinaryPath.empty()) + { + return false; + } + + if (BinDistribution == BinaryDistribution::Artifact && ArtifactSource.empty()) + { + return false; + } + + if (TaskDriver == Driver::Docker && DockerImage.empty()) + { + return false; + } + + return true; +} + +const char* +ToString(Driver D) +{ + switch (D) + { + case Driver::RawExec: + return "raw_exec"; + case Driver::Docker: + return "docker"; + } + return "raw_exec"; +} + +const char* +ToString(BinaryDistribution Dist) +{ + switch (Dist) + { + case BinaryDistribution::PreDeployed: + return "predeployed"; + case BinaryDistribution::Artifact: + return "artifact"; + } + return "predeployed"; +} + +bool +FromString(Driver& OutDriver, std::string_view Str) +{ + if (Str == "raw_exec") + { + OutDriver = Driver::RawExec; + return true; + } + if (Str == "docker") + { + OutDriver = Driver::Docker; + return true; + } + return false; +} + +bool +FromString(BinaryDistribution& OutDist, std::string_view Str) +{ + if (Str == "predeployed") + { + OutDist = BinaryDistribution::PreDeployed; + return true; + } + if (Str == "artifact") + { + OutDist = BinaryDistribution::Artifact; + return true; + } + return false; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprocess.cpp b/src/zennomad/nomadprocess.cpp new file mode 100644 index 000000000..1ae968fb7 --- /dev/null +++ b/src/zennomad/nomadprocess.cpp @@ -0,0 +1,354 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +ZEN_THIRD_PARTY_INCLUDES_END + +#include + +namespace zen::nomad { + +////////////////////////////////////////////////////////////////////////// + +struct NomadProcess::Impl +{ + Impl(std::string_view BaseUri) : m_HttpClient(BaseUri) {} + ~Impl() = default; + + void SpawnNomadAgent() + { + ZEN_TRACE_CPU("SpawnNomadAgent"); + + if (m_ProcessHandle.IsValid()) + { + return; + } + + CreateProcOptions Options; + Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + + CreateProcResult Result = CreateProc("nomad" ZEN_EXE_SUFFIX_LITERAL, "nomad" ZEN_EXE_SUFFIX_LITERAL " agent -dev", Options); + + if (Result) + { + m_ProcessHandle.Initialize(Result); + + Stopwatch Timer; + + // Poll to check when the agent is ready + + do + { + Sleep(100); + HttpClient::Response Resp = m_HttpClient.Get("v1/status/leader"); + if (Resp) + { + ZEN_INFO("Nomad agent started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + + return; + } + } while (Timer.GetElapsedTimeMs() < 30000); + } + + // Report failure! 
+ + ZEN_WARN("Nomad agent failed to start within timeout period"); + } + + void StopNomadAgent() + { + if (!m_ProcessHandle.IsValid()) + { + return; + } + + // This waits for the process to exit and also resets the handle + m_ProcessHandle.Kill(); + } + +private: + ProcessHandle m_ProcessHandle; + HttpClient m_HttpClient; +}; + +NomadProcess::NomadProcess() : m_Impl(std::make_unique("http://localhost:4646/")) +{ +} + +NomadProcess::~NomadProcess() +{ +} + +void +NomadProcess::SpawnNomadAgent() +{ + m_Impl->SpawnNomadAgent(); +} + +void +NomadProcess::StopNomadAgent() +{ + m_Impl->StopNomadAgent(); +} + +////////////////////////////////////////////////////////////////////////// + +NomadTestClient::NomadTestClient(std::string_view BaseUri) : m_HttpClient(BaseUri) +{ +} + +NomadTestClient::~NomadTestClient() +{ +} + +NomadJobInfo +NomadTestClient::SubmitJob(std::string_view JobId, std::string_view Command, const std::vector& Args) +{ + ZEN_TRACE_CPU("SubmitNomadJob"); + + NomadJobInfo Result; + + // Build the job JSON for a raw_exec batch job + json11::Json::object TaskConfig; + TaskConfig["command"] = std::string(Command); + + json11::Json::array JsonArgs; + for (const auto& Arg : Args) + { + JsonArgs.push_back(Arg); + } + TaskConfig["args"] = JsonArgs; + + json11::Json::object Resources; + Resources["CPU"] = 100; + Resources["MemoryMB"] = 64; + + json11::Json::object Task; + Task["Name"] = "test-task"; + Task["Driver"] = "raw_exec"; + Task["Config"] = TaskConfig; + Task["Resources"] = Resources; + + json11::Json::array Tasks; + Tasks.push_back(Task); + + json11::Json::object Group; + Group["Name"] = "test-group"; + Group["Count"] = 1; + Group["Tasks"] = Tasks; + + json11::Json::array Groups; + Groups.push_back(Group); + + json11::Json::array Datacenters; + Datacenters.push_back("dc1"); + + json11::Json::object Job; + Job["ID"] = std::string(JobId); + Job["Name"] = std::string(JobId); + Job["Type"] = "batch"; + Job["Datacenters"] = Datacenters; + Job["TaskGroups"] = 
Groups; + + json11::Json::object Root; + Root["Job"] = Job; + + std::string Body = json11::Json(Root).dump(); + + IoBuffer Payload = IoBufferBuilder::MakeFromMemory(MemoryView{Body.data(), Body.size()}, ZenContentType::kJSON); + + HttpClient::Response Response = + m_HttpClient.Put("v1/jobs", Payload, {{"Content-Type", "application/json"}, {"Accept", "application/json"}}); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: SubmitJob failed for '{}'", JobId); + return Result; + } + + std::string ResponseBody(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(ResponseBody, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in SubmitJob response: {}", Err); + return Result; + } + + Result.Id = std::string(JobId); + Result.Status = "pending"; + + ZEN_INFO("NomadTestClient: job '{}' submitted (eval_id={})", JobId, Json["EvalID"].string_value()); + + return Result; +} + +NomadJobInfo +NomadTestClient::GetJobStatus(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadJobStatus"); + + NomadJobInfo Result; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}", JobId)); + + if (Response.Error) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed for '{}': {}", JobId, Response.Error->ErrorMessage); + return Result; + } + + if (static_cast(Response.StatusCode) == 404) + { + Result.Status = "dead"; + return Result; + } + + if (!Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetJobStatus failed with HTTP/{}", static_cast(Response.StatusCode)); + return Result; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty()) + { + ZEN_WARN("NomadTestClient: invalid JSON in GetJobStatus response: {}", Err); + return Result; + } + + Result.Id = Json["ID"].string_value(); + Result.Status = Json["Status"].string_value(); + if (const json11::Json Desc = Json["StatusDescription"]; 
Desc.is_string()) + { + Result.StatusDescription = Desc.string_value(); + } + + return Result; +} + +void +NomadTestClient::StopJob(std::string_view JobId) +{ + ZEN_TRACE_CPU("StopNomadJob"); + + HttpClient::Response Response = m_HttpClient.Delete(fmt::format("v1/job/{}", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: StopJob failed for '{}'", JobId); + return; + } + + ZEN_INFO("NomadTestClient: job '{}' stopped", JobId); +} + +std::vector +NomadTestClient::GetAllocations(std::string_view JobId) +{ + ZEN_TRACE_CPU("GetNomadAllocations"); + + std::vector Allocs; + + HttpClient::Response Response = m_HttpClient.Get(fmt::format("v1/job/{}/allocations", JobId)); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: GetAllocations failed for '{}'", JobId); + return Allocs; + } + + std::string Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Allocs; + } + + for (const json11::Json& AllocVal : Json.array_items()) + { + NomadAllocInfo Alloc; + Alloc.Id = AllocVal["ID"].string_value(); + Alloc.ClientStatus = AllocVal["ClientStatus"].string_value(); + + if (const json11::Json TaskStates = AllocVal["TaskStates"]; TaskStates.is_object()) + { + for (const auto& [TaskName, TaskState] : TaskStates.object_items()) + { + if (TaskState["State"].is_string()) + { + Alloc.TaskState = TaskState["State"].string_value(); + } + } + } + + Allocs.push_back(std::move(Alloc)); + } + + return Allocs; +} + +std::vector +NomadTestClient::ListJobs(std::string_view Prefix) +{ + ZEN_TRACE_CPU("ListNomadJobs"); + + std::vector Jobs; + + std::string Url = "v1/jobs"; + if (!Prefix.empty()) + { + Url = fmt::format("v1/jobs?prefix={}", Prefix); + } + + HttpClient::Response Response = m_HttpClient.Get(Url); + + if (!Response || !Response.IsSuccess()) + { + ZEN_WARN("NomadTestClient: ListJobs failed"); + return Jobs; + } + + std::string 
Body(Response.AsText()); + std::string Err; + const json11::Json Json = json11::Json::parse(Body, Err); + + if (!Err.empty() || !Json.is_array()) + { + return Jobs; + } + + for (const json11::Json& JobVal : Json.array_items()) + { + NomadJobInfo Job; + Job.Id = JobVal["ID"].string_value(); + Job.Status = JobVal["Status"].string_value(); + if (const json11::Json Desc = JobVal["StatusDescription"]; Desc.is_string()) + { + Job.StatusDescription = Desc.string_value(); + } + Jobs.push_back(std::move(Job)); + } + + return Jobs; +} + +} // namespace zen::nomad diff --git a/src/zennomad/nomadprovisioner.cpp b/src/zennomad/nomadprovisioner.cpp new file mode 100644 index 000000000..3fe9c0ac3 --- /dev/null +++ b/src/zennomad/nomadprovisioner.cpp @@ -0,0 +1,264 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace zen::nomad { + +NomadProvisioner::NomadProvisioner(const NomadConfig& Config, std::string_view OrchestratorEndpoint) +: m_Config(Config) +, m_OrchestratorEndpoint(OrchestratorEndpoint) +, m_ProcessId(static_cast(zen::GetCurrentProcessId())) +, m_Log(zen::logging::Get("nomad.provisioner")) +{ + ZEN_DEBUG("initializing provisioner (server: {}, driver: {}, max_cores: {}, cores_per_job: {}, max_jobs: {})", + m_Config.ServerUrl, + ToString(m_Config.TaskDriver), + m_Config.MaxCores, + m_Config.CoresPerJob, + m_Config.MaxJobs); + + m_Client = std::make_unique(m_Config); + if (!m_Client->Initialize()) + { + ZEN_ERROR("failed to initialize Nomad HTTP client"); + return; + } + + ZEN_DEBUG("Nomad HTTP client initialized, starting management thread"); + + m_Thread = std::thread([this] { ManagementThread(); }); +} + +NomadProvisioner::~NomadProvisioner() +{ + ZEN_DEBUG("provisioner shutting down"); + + m_ShouldExit.store(true); + m_WakeCV.notify_all(); + + if (m_Thread.joinable()) + { + m_Thread.join(); + } + + StopAllJobs(); + + ZEN_DEBUG("provisioner shutdown 
complete"); +} + +void +NomadProvisioner::SetTargetCoreCount(uint32_t Count) +{ + const uint32_t Clamped = std::min(Count, static_cast(m_Config.MaxCores)); + const uint32_t Previous = m_TargetCoreCount.exchange(Clamped); + + if (Clamped != Previous) + { + ZEN_DEBUG("target core count changed: {} -> {}", Previous, Clamped); + } + + m_WakeCV.notify_all(); +} + +NomadProvisioningStats +NomadProvisioner::GetStats() const +{ + NomadProvisioningStats Stats; + Stats.TargetCoreCount = m_TargetCoreCount.load(); + Stats.EstimatedCoreCount = m_EstimatedCoreCount.load(); + Stats.RunningJobCount = m_RunningJobCount.load(); + + { + std::lock_guard Lock(m_JobsLock); + Stats.ActiveJobCount = static_cast(m_Jobs.size()); + } + + return Stats; +} + +std::string +NomadProvisioner::GenerateJobId() +{ + const uint32_t Index = m_JobIndex.fetch_add(1); + + ExtendableStringBuilder<128> Builder; + Builder << m_Config.JobPrefix << "-" << m_ProcessId << "-" << Index; + return std::string(Builder.ToView()); +} + +void +NomadProvisioner::ManagementThread() +{ + ZEN_TRACE_CPU("Nomad_Mgmt"); + zen::SetCurrentThreadName("nomad_mgmt"); + + ZEN_INFO("Nomad management thread started"); + + while (!m_ShouldExit.load()) + { + ZEN_DEBUG("management cycle: target={} estimated={} running={} active={}", + m_TargetCoreCount.load(), + m_EstimatedCoreCount.load(), + m_RunningJobCount.load(), + [this] { + std::lock_guard Lock(m_JobsLock); + return m_Jobs.size(); + }()); + + SubmitNewJobs(); + PollExistingJobs(); + CleanupDeadJobs(); + + // Wait up to 5 seconds or until woken + std::unique_lock Lock(m_WakeMutex); + m_WakeCV.wait_for(Lock, std::chrono::seconds(5), [this] { return m_ShouldExit.load(); }); + } + + ZEN_INFO("Nomad management thread exiting"); +} + +void +NomadProvisioner::SubmitNewJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::SubmitNewJobs"); + + const uint32_t CoresPerJob = static_cast(m_Config.CoresPerJob); + + while (m_EstimatedCoreCount.load() < m_TargetCoreCount.load()) + { + { + 
std::lock_guard Lock(m_JobsLock); + if (static_cast(m_Jobs.size()) >= m_Config.MaxJobs) + { + ZEN_INFO("Nomad max jobs limit reached ({})", m_Config.MaxJobs); + break; + } + } + + if (m_ShouldExit.load()) + { + break; + } + + const std::string JobId = GenerateJobId(); + + ZEN_DEBUG("submitting job '{}' (estimated: {}, target: {})", JobId, m_EstimatedCoreCount.load(), m_TargetCoreCount.load()); + + const std::string JobJson = m_Client->BuildJobJson(JobId, m_OrchestratorEndpoint); + + NomadJobInfo JobInfo; + JobInfo.Id = JobId; + + if (!m_Client->SubmitJob(JobJson, JobInfo)) + { + ZEN_WARN("failed to submit Nomad job '{}'", JobId); + break; + } + + TrackedJob Tracked; + Tracked.JobId = JobId; + Tracked.Status = "pending"; + Tracked.Cores = static_cast(CoresPerJob); + + { + std::lock_guard Lock(m_JobsLock); + m_Jobs.push_back(std::move(Tracked)); + } + + m_EstimatedCoreCount.fetch_add(CoresPerJob); + + ZEN_INFO("Nomad job '{}' submitted (estimated cores: {})", JobId, m_EstimatedCoreCount.load()); + } +} + +void +NomadProvisioner::PollExistingJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::PollExistingJobs"); + + std::lock_guard Lock(m_JobsLock); + + for (auto& Job : m_Jobs) + { + if (m_ShouldExit.load()) + { + break; + } + + NomadJobInfo Info; + if (!m_Client->GetJobStatus(Job.JobId, Info)) + { + ZEN_DEBUG("failed to poll status for job '{}'", Job.JobId); + continue; + } + + const std::string PrevStatus = Job.Status; + Job.Status = Info.Status; + + if (PrevStatus != Job.Status) + { + ZEN_INFO("Nomad job '{}' status changed: {} -> {}", Job.JobId, PrevStatus, Job.Status); + + if (Job.Status == "running" && PrevStatus != "running") + { + m_RunningJobCount.fetch_add(1); + } + else if (Job.Status != "running" && PrevStatus == "running") + { + m_RunningJobCount.fetch_sub(1); + } + } + } +} + +void +NomadProvisioner::CleanupDeadJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::CleanupDeadJobs"); + + std::lock_guard Lock(m_JobsLock); + + for (auto It = m_Jobs.begin(); It != 
m_Jobs.end();) + { + if (It->Status == "dead") + { + ZEN_INFO("Nomad job '{}' is dead, removing from tracked jobs", It->JobId); + m_EstimatedCoreCount.fetch_sub(static_cast(It->Cores)); + It = m_Jobs.erase(It); + } + else + { + ++It; + } + } +} + +void +NomadProvisioner::StopAllJobs() +{ + ZEN_TRACE_CPU("NomadProvisioner::StopAllJobs"); + + std::lock_guard Lock(m_JobsLock); + + for (const auto& Job : m_Jobs) + { + ZEN_INFO("stopping Nomad job '{}' during shutdown", Job.JobId); + m_Client->StopJob(Job.JobId); + } + + m_Jobs.clear(); + m_EstimatedCoreCount.store(0); + m_RunningJobCount.store(0); +} + +} // namespace zen::nomad diff --git a/src/zennomad/xmake.lua b/src/zennomad/xmake.lua new file mode 100644 index 000000000..ef1a8b201 --- /dev/null +++ b/src/zennomad/xmake.lua @@ -0,0 +1,10 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zennomad') + set_kind("static") + set_group("libs") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenhttp", "zenutil") + add_packages("json11") diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp b/src/zenremotestore/builds/buildstorageoperations.cpp index f4b4d592b..43a4937f0 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -8186,7 +8186,7 @@ TEST_CASE("buildstorageoperations.partial.block.download" * doctest::skip(true)) Headers); REQUIRE(GetBlobRangesResponse.IsSuccess()); - MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); + [[maybe_unused]] MemoryView RangesMemoryView = GetBlobRangesResponse.ResponsePayload.GetView(); std::vector> PayloadRanges = GetBlobRangesResponse.GetRanges(Ranges); if (PayloadRanges.empty()) diff --git a/src/zenremotestore/chunking/chunkingcache.cpp b/src/zenremotestore/chunking/chunkingcache.cpp index f4e1c7837..e9b783a00 100644 --- a/src/zenremotestore/chunking/chunkingcache.cpp +++ 
b/src/zenremotestore/chunking/chunkingcache.cpp @@ -75,13 +75,13 @@ public: { Lock.ReleaseNow(); RwLock::ExclusiveLockScope EditLock(m_Lock); - if (auto RemoveIt = m_PathHashToEntry.find(PathHash); It != m_PathHashToEntry.end()) + if (auto RemoveIt = m_PathHashToEntry.find(PathHash); RemoveIt != m_PathHashToEntry.end()) { - CachedEntry& DeleteEntry = m_Entries[It->second]; + CachedEntry& DeleteEntry = m_Entries[RemoveIt->second]; DeleteEntry.Chunked = {}; DeleteEntry.ModificationTick = 0; - m_FreeEntryIndexes.push_back(It->second); - m_PathHashToEntry.erase(It); + m_FreeEntryIndexes.push_back(RemoveIt->second); + m_PathHashToEntry.erase(RemoveIt); } } } diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp new file mode 100644 index 000000000..c90ac5d8b --- /dev/null +++ b/src/zenserver-test/compute-tests.cpp @@ -0,0 +1,1700 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#if ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include "zenserver-test.h" + +# include + +namespace zen::tests::compute { + +using namespace std::literals; + +// BuildSystemVersion and function version GUIDs matching zentest-appstub +static constexpr std::string_view kBuildSystemVersion = "17fe280d-ccd8-4be8-a9d1-89c944a70969"; +static constexpr std::string_view kRot13Version = "13131313-1313-1313-1313-131313131313"; +static constexpr std::string_view kSleepVersion = "88888888-8888-8888-8888-888888888888"; + +// In-memory implementation of ChunkResolver for test use. +// Stores compressed data keyed by decompressed content hash. 
+class InMemoryChunkResolver : public ChunkResolver +{ +public: + IoBuffer FindChunkByCid(const IoHash& DecompressedId) override + { + auto It = m_Chunks.find(DecompressedId); + if (It != m_Chunks.end()) + { + return It->second; + } + return {}; + } + + void AddChunk(const IoHash& DecompressedId, IoBuffer Data) { m_Chunks[DecompressedId] = std::move(Data); } + +private: + std::unordered_map m_Chunks; +}; + +// Read, compress, and register zentest-appstub as a worker. +// Returns the WorkerId (hash of the worker package object). +static IoHash +RegisterWorker(HttpClient& Client, ZenServerEnvironment& Env) +{ + std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL); + + FileContents AppStubData = zen::ReadFile(AppStubPath); + REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message())); + + IoBuffer AppStubBuffer = AppStubData.Flatten(); + + CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash AppStubRawHash = AppStubCompressed.DecodeRawHash(); + const uint64_t AppStubRawSize = AppStubBuffer.Size(); + + CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash); + + CbObjectWriter WorkerWriter; + WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion); + WorkerWriter << "path"sv + << "zentest-appstub"sv; + + WorkerWriter.BeginArray("executables"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "zentest-appstub"sv; + WorkerWriter.AddAttachment("hash"sv, AppStubAttachment); + WorkerWriter << "size"sv << AppStubRawSize; + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + WorkerWriter.BeginArray("functions"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Rot13"sv; + WorkerWriter << "version"sv << 
Guid::FromString(kRot13Version); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Sleep"sv; + WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerWriter.Save()); + WorkerPackage.AddAttachment(AppStubAttachment); + + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + + const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); + HttpClient::Response RegisterResp = Client.Post(WorkerUrl, std::move(WorkerPackage)); + REQUIRE_MESSAGE(RegisterResp, + fmt::format("Worker registration failed: status={}, body={}", int(RegisterResp.StatusCode), RegisterResp.ToText())); + + return WorkerId; +} + +// Build a Rot13 action CbPackage for the given input string. +static CbPackage +BuildRot13ActionPackage(std::string_view Input) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Rot13"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Sleep action CbPackage. 
The worker sleeps for SleepTimeMs before returning its input. +static CbPackage +BuildSleepActionPackage(std::string_view Input, uint64_t SleepTimeMs) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Sleep"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "SleepTimeMs"sv << SleepTimeMs; + ActionWriter.EndObject(); + + CbPackage ActionPackage; + ActionPackage.SetObject(ActionWriter.Save()); + ActionPackage.AddAttachment(InputAttachment); + + return ActionPackage; +} + +// Build a Sleep action CbObject and populate the chunk resolver with the input attachment. 
+static CbObject +BuildSleepActionForSession(std::string_view Input, uint64_t SleepTimeMs, InMemoryChunkResolver& Resolver) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Sleep"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kSleepVersion); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + ActionWriter.BeginObject("Constants"sv); + ActionWriter << "SleepTimeMs"sv << SleepTimeMs; + ActionWriter.EndObject(); + + return ActionWriter.Save(); +} + +static HttpClient::Response +PollForResult(HttpClient& Client, const std::string& ResultUrl, uint64_t TimeoutMs = 30'000) +{ + HttpClient::Response Resp; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + Resp = Client.Get(ResultUrl); + + if (Resp.StatusCode == HttpResponseCode::OK) + { + break; + } + + Sleep(100); + } + + return Resp; +} + +static bool +PollForLsnInCompleted(HttpClient& Client, const std::string& CompletedUrl, int Lsn, uint64_t TimeoutMs = 30'000) +{ + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < TimeoutMs) + { + HttpClient::Response Resp = Client.Get(CompletedUrl); + + if (Resp) + { + for (auto& Item : Resp.AsObject()["completed"sv]) + { + if (Item.AsInt32() == Lsn) + { + return true; + } + } + } + + 
Sleep(100); + } + + return false; +} + +static std::string +GetRot13Output(const CbPackage& ResultPackage) +{ + CbObject ResultObj = ResultPackage.GetObject(); + + IoHash OutputHash; + CbFieldView ValuesField = ResultObj["Values"sv]; + + if (CbFieldViewIterator It = begin(ValuesField); It.HasValue()) + { + OutputHash = (*It).AsObjectView()["RawHash"sv].AsHash(); + } + + REQUIRE_MESSAGE(OutputHash != IoHash::Zero, "Expected non-zero output hash in result Values array"); + + const CbAttachment* OutputAttachment = ResultPackage.FindAttachment(OutputHash); + REQUIRE_MESSAGE(OutputAttachment != nullptr, "Output attachment not found in result package"); + + CompressedBuffer OutputCompressed = OutputAttachment->AsCompressedBinary(); + SharedBuffer OutputData = OutputCompressed.Decompress(); + + return std::string(static_cast(OutputData.GetData()), OutputData.GetSize()); +} + +// Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response. +class MockOrchestratorService : public HttpService +{ +public: + MockOrchestratorService() + { + // Initialize with empty worker list + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + Cbo.EndArray(); + m_WorkerList = Cbo.Save(); + } + + const char* BaseUri() const override { return "/orch/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + if (Request.RequestVerb() == HttpVerb::kGet && Request.RelativeUri() == "agents"sv) + { + RwLock::SharedLockScope Lock(m_Lock); + Request.WriteResponse(HttpResponseCode::OK, m_WorkerList); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + } + + void SetWorkerList(CbObject WorkerList) + { + RwLock::ExclusiveLockScope Lock(m_Lock); + m_WorkerList = std::move(WorkerList); + } + +private: + RwLock m_Lock; + CbObject m_WorkerList; +}; + +// Manages in-process ASIO HTTP server lifecycle for mock orchestrator. 
+struct MockOrchestratorFixture +{ + MockOrchestratorService Service; + ScopedTemporaryDirectory TmpDir; + Ref Server; + std::thread ServerThread; + uint16_t Port = 0; + + MockOrchestratorFixture() + { + HttpServerConfig Config; + Config.ServerClass = "asio"; + Config.ForceLoopback = true; + Server = CreateHttpServer(Config); + Server->RegisterService(Service); + Port = static_cast(Server->Initialize(TestEnv.GetNewPortNumber(), TmpDir.Path())); + ZEN_ASSERT(Port != 0); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~MockOrchestratorFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + std::string GetEndpoint() const { return fmt::format("http://localhost:{}", Port); } +}; + +// Build the CbObject response for /orch/agents matching the format UpdateCoordinatorState expects. +static CbObject +BuildAgentListResponse(std::initializer_list> Workers) +{ + CbObjectWriter Cbo; + Cbo.BeginArray("workers"sv); + for (const auto& [Id, Uri] : Workers) + { + Cbo.BeginObject(); + Cbo << "id"sv << Id; + Cbo << "uri"sv << Uri; + Cbo << "hostname"sv + << "localhost"sv; + Cbo << "reachable"sv << true; + Cbo << "dt"sv << uint64_t(0); + Cbo.EndObject(); + } + Cbo.EndArray(); + return Cbo.Save(); +} + +// Build the worker CbPackage for zentest-appstub AND populate the chunk resolver. +// This is the same logic as RegisterWorker() but returns the package instead of POSTing it. 
+static CbPackage +BuildWorkerPackage(ZenServerEnvironment& Env, InMemoryChunkResolver& Resolver) +{ + std::filesystem::path AppStubPath = Env.ProgramBaseDir() / ("zentest-appstub" ZEN_EXE_SUFFIX_LITERAL); + + FileContents AppStubData = zen::ReadFile(AppStubPath); + REQUIRE_MESSAGE(!AppStubData.ErrorCode, fmt::format("Failed to read '{}': {}", AppStubPath.string(), AppStubData.ErrorCode.message())); + + IoBuffer AppStubBuffer = AppStubData.Flatten(); + + CompressedBuffer AppStubCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(AppStubBuffer.GetData(), AppStubBuffer.Size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash AppStubRawHash = AppStubCompressed.DecodeRawHash(); + const uint64_t AppStubRawSize = AppStubBuffer.Size(); + + // Store compressed data in chunk resolver for when the remote runner needs it + Resolver.AddChunk(AppStubRawHash, AppStubCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment AppStubAttachment(std::move(AppStubCompressed), AppStubRawHash); + + CbObjectWriter WorkerWriter; + WorkerWriter << "buildsystem_version"sv << Guid::FromString(kBuildSystemVersion); + WorkerWriter << "path"sv + << "zentest-appstub"sv; + + WorkerWriter.BeginArray("executables"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "zentest-appstub"sv; + WorkerWriter.AddAttachment("hash"sv, AppStubAttachment); + WorkerWriter << "size"sv << AppStubRawSize; + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + WorkerWriter.BeginArray("functions"sv); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Rot13"sv; + WorkerWriter << "version"sv << Guid::FromString(kRot13Version); + WorkerWriter.EndObject(); + WorkerWriter.BeginObject(); + WorkerWriter << "name"sv + << "Sleep"sv; + WorkerWriter << "version"sv << Guid::FromString(kSleepVersion); + WorkerWriter.EndObject(); + WorkerWriter.EndArray(); + + CbPackage WorkerPackage; + WorkerPackage.SetObject(WorkerWriter.Save()); + 
WorkerPackage.AddAttachment(AppStubAttachment); + + return WorkerPackage; +} + +// Build a Rot13 action CbObject (not CbPackage) and populate the chunk resolver with the input attachment. +static CbObject +BuildRot13ActionForSession(std::string_view Input, InMemoryChunkResolver& Resolver) +{ + CompressedBuffer InputCompressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Input.data(), Input.size()), + OodleCompressor::Selkie, + OodleCompressionLevel::HyperFast4); + + const IoHash InputRawHash = InputCompressed.DecodeRawHash(); + const uint64_t InputRawSize = Input.size(); + + // Store compressed data in chunk resolver + Resolver.AddChunk(InputRawHash, InputCompressed.GetCompressed().Flatten().AsIoBuffer()); + + CbAttachment InputAttachment(std::move(InputCompressed), InputRawHash); + + CbObjectWriter ActionWriter; + ActionWriter << "Function"sv + << "Rot13"sv; + ActionWriter << "FunctionVersion"sv << Guid::FromString(kRot13Version); + ActionWriter << "BuildSystemVersion"sv << Guid::FromString(kBuildSystemVersion); + ActionWriter.BeginObject("Inputs"sv); + ActionWriter.BeginObject("Source"sv); + ActionWriter.AddAttachment("RawHash"sv, InputAttachment); + ActionWriter << "RawSize"sv << InputRawSize; + ActionWriter.EndObject(); + ActionWriter.EndObject(); + + return ActionWriter.Save(); +} + +TEST_SUITE_BEGIN("server.function"); + +TEST_CASE("function.rot13") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Submit action via legacy /jobs/{worker} endpoint + const std::string JobUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); + HttpClient::Response SubmitResp 
= Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); + + // Poll for result via legacy /jobs/{lsn} endpoint + const std::string ResultUrl = fmt::format("/jobs/{}", Lsn); + HttpClient::Response ResultResp = PollForResult(Client, ResultUrl); + REQUIRE_MESSAGE( + ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Job did not complete in time. Last status: {}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Action failed (empty result package)\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); +} + +TEST_CASE("function.workers") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + // Before registration, GET /workers should return an empty list + HttpClient::Response EmptyListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(EmptyListResp, "Failed to list workers before registration"); + CHECK_EQ(EmptyListResp.AsObject()["workers"sv].AsArrayView().Num(), 0); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // GET /workers — the registered worker should appear in the listing + HttpClient::Response ListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list workers after 
registration"); + + bool WorkerFound = false; + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerId) + { + WorkerFound = true; + break; + } + } + + REQUIRE_MESSAGE(WorkerFound, fmt::format("Worker {} not found in worker listing", WorkerId.ToHexString())); + + // GET /workers/{worker} — descriptor should match what was registered + const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId.ToHexString()); + HttpClient::Response DescResp = Client.Get(WorkerUrl); + REQUIRE_MESSAGE(DescResp, fmt::format("Failed to get worker descriptor: status={}", int(DescResp.StatusCode))); + + CbObject Desc = DescResp.AsObject(); + CHECK_EQ(Desc["buildsystem_version"sv].AsUuid(), Guid::FromString(kBuildSystemVersion)); + CHECK_EQ(Desc["path"sv].AsString(), "zentest-appstub"sv); + + bool Rot13Found = false; + bool SleepFound = false; + for (auto& Item : Desc["functions"sv]) + { + std::string_view Name = Item.AsObjectView()["name"sv].AsString(); + if (Name == "Rot13"sv) + { + CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kRot13Version)); + Rot13Found = true; + } + else if (Name == "Sleep"sv) + { + CHECK_EQ(Item.AsObjectView()["version"sv].AsUuid(), Guid::FromString(kSleepVersion)); + SleepFound = true; + } + } + + CHECK_MESSAGE(Rot13Found, "Rot13 function not found in worker descriptor"); + CHECK_MESSAGE(SleepFound, "Sleep function not found in worker descriptor"); + + // GET /workers/{unknown} — should return 404 + const std::string UnknownUrl = fmt::format("/workers/{}", IoHash::Zero.ToHexString()); + HttpClient::Response NotFoundResp = Client.Get(UnknownUrl); + CHECK_EQ(NotFoundResp.StatusCode, HttpResponseCode::NotFound); +} + +TEST_CASE("function.queues.lifecycle") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, 
Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, fmt::format("Queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Verify the queue appears in the listing + HttpClient::Response ListResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list queues"); + + bool QueueFound = false; + for (auto& Item : ListResp.AsObject()["queues"sv]) + { + if (Item.AsObjectView()["queue_id"sv].AsInt32() == QueueId) + { + QueueFound = true; + break; + } + } + + REQUIRE_MESSAGE(QueueFound, fmt::format("Queue {} not found in queue listing", QueueId)); + + // Submit action via queue-scoped endpoint + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from queue job submission"); + + // Poll for completion via queue-scoped /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list within timeout\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Retrieve result via queue-scoped /jobs/{lsn} endpoint + const 
std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueId, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + REQUIRE_MESSAGE( + ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result: status={}\nServer log:\n{}", int(ResultResp.StatusCode), Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + // Verify queue status reflects completion + const std::string StatusUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response StatusResp = Client.Get(StatusUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + CHECK_EQ(QueueStatus["failed_count"sv].AsInt32(), 0); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "active"); +} + +TEST_CASE("function.queues.cancel") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Submit a job + const std::string JobUrl = 
fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + // Cancel the queue + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // Verify queue status shows cancelled + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); +} + +TEST_CASE("function.queues.remote") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a remote queue — response includes both an integer queue_id and an OID queue_token + HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); + REQUIRE_MESSAGE(CreateResp, + fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + CbObject CreateObj = CreateResp.AsObject(); + const std::string QueueToken = std::string(CreateObj["queue_token"sv].AsString()); + REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); + + // All subsequent 
requests use the opaque token in place of the integer queue id + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildRot13ActionPackage("Hello World"sv)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Remote queue job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from remote queue job submission"); + + // Poll for completion via the token-addressed /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); + REQUIRE_MESSAGE( + PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in remote queue completed list within timeout\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Retrieve result via the token-addressed /jobs/{lsn} endpoint + const std::string ResultUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, Lsn); + HttpClient::Response ResultResp = Client.Get(ResultUrl); + REQUIRE_MESSAGE(ResultResp.StatusCode == HttpResponseCode::OK, + fmt::format("Failed to retrieve result from remote queue: status={}\nServer log:\n{}", + int(ResultResp.StatusCode), + Instance.GetLogOutput())); + + // Verify result: Rot13("Hello World") == "Uryyb Jbeyq" + CbPackage ResultPackage = ResultResp.AsPackage(); + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); +} + +TEST_CASE("function.queues.cancel_running") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = 
fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id from queue creation"); + + // Submit a Sleep job long enough that it will still be running when we cancel + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); + + // Wait for the worker process to start executing before cancelling + Sleep(1'000); + + // Cancel the queue, which should interrupt the running Sleep job + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // The cancelled job should appear in the /completed endpoint once the process exits + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list after cancel\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the queue reflects one cancelled action + HttpClient::Response StatusResp = Client.Get(QueueUrl); + 
REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); + CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.queues.remote_cancel") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a remote queue to obtain an OID token for token-addressed cancellation + HttpClient::Response CreateResp = Client.Post("/queues/remote"sv); + REQUIRE_MESSAGE(CreateResp, + fmt::format("Remote queue creation failed: status={}, body={}", int(CreateResp.StatusCode), CreateResp.ToText())); + + const std::string QueueToken = std::string(CreateResp.AsObject()["queue_token"sv].AsString()); + REQUIRE_MESSAGE(!QueueToken.empty(), "Expected non-empty queue_token from remote queue creation"); + + // Submit a long-running Sleep job via the token-addressed endpoint + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueToken, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, + fmt::format("Sleep job submission failed: status={}, body={}", int(SubmitResp.StatusCode), SubmitResp.ToText())); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from Sleep job submission"); + + // Wait for the worker process to start executing before cancelling + Sleep(1'000); + + // Cancel the queue via its OID token 
+ const std::string QueueUrl = fmt::format("/queues/{}", QueueToken); + HttpClient::Response CancelResp = Client.Delete(QueueUrl); + REQUIRE_MESSAGE(CancelResp.StatusCode == HttpResponseCode::NoContent, + fmt::format("Remote queue cancellation failed: status={}, body={}", int(CancelResp.StatusCode), CancelResp.ToText())); + + // The cancelled job should appear in the token-addressed /completed endpoint + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueToken); + REQUIRE_MESSAGE( + PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in remote queue completed list after cancel\nServer log:\n{}", Lsn, Instance.GetLogOutput())); + + // Verify the queue status reflects the cancellation + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get remote queue status after cancel"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "cancelled"); + CHECK_EQ(QueueStatus["cancelled_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); +} + +TEST_CASE("function.queues.drain") +{ + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + + // Submit a long-running job so we can verify it completes even after drain + const std::string 
JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response Submit1 = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 2'000)); + REQUIRE_MESSAGE(Submit1, fmt::format("First job submission failed: status={}", int(Submit1.StatusCode))); + const int Lsn1 = Submit1.AsObject()["lsn"sv].AsInt32(); + + // Drain the queue + const std::string DrainUrl = fmt::format("/queues/{}/drain", QueueId); + HttpClient::Response DrainResp = Client.Post(DrainUrl); + REQUIRE_MESSAGE(DrainResp, fmt::format("Drain failed: status={}, body={}", int(DrainResp.StatusCode), DrainResp.ToText())); + CHECK_EQ(std::string(DrainResp.AsObject()["state"sv].AsString()), "draining"); + + // Second submission should be rejected with 424 + HttpClient::Response Submit2 = Client.Post(JobUrl, BuildRot13ActionPackage("Hello"sv)); + CHECK_EQ(Submit2.StatusCode, HttpResponseCode::FailedDependency); + CHECK_EQ(std::string(Submit2.AsObject()["error"sv].AsString()), "queue is draining"); + + // First job should still complete + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn1), + fmt::format("LSN {} did not complete after drain\nServer log:\n{}", Lsn1, Instance.GetLogOutput())); + + // Queue status should show draining + complete + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(std::string(QueueStatus["state"sv].AsString()), "draining"); + CHECK(QueueStatus["is_complete"sv].AsBool()); +} + +TEST_CASE("function.priority") +{ + // Spawn server with max-actions=1 to guarantee serialized action execution, + // which lets us deterministically verify that higher-priority pending jobs + // are scheduled before lower-priority ones. 
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1"); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue for all test jobs + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + // Submit a blocker Sleep job to occupy the single execution slot. + // Once the blocker is running, the scheduler must choose among the pending + // jobs by priority when the slot becomes free. + const std::string BlockerJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + HttpClient::Response BlockerResp = Client.Post(BlockerJobUrl, BuildSleepActionPackage("data"sv, 1'000)); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker job submission failed: status={}", int(BlockerResp.StatusCode))); + + // Submit 3 low-priority Rot13 jobs + const std::string LowJobUrl = fmt::format("/queues/{}/jobs/{}?priority=0", QueueId, WorkerId.ToHexString()); + + HttpClient::Response LowResp1 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low1"sv)); + REQUIRE_MESSAGE(LowResp1, "Low-priority job 1 submission failed"); + const int LsnLow1 = LowResp1.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp2 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low2"sv)); + REQUIRE_MESSAGE(LowResp2, "Low-priority job 2 submission failed"); + const int LsnLow2 = LowResp2.AsObject()["lsn"sv].AsInt32(); + + HttpClient::Response LowResp3 = Client.Post(LowJobUrl, BuildRot13ActionPackage("low3"sv)); + 
REQUIRE_MESSAGE(LowResp3, "Low-priority job 3 submission failed"); + const int LsnLow3 = LowResp3.AsObject()["lsn"sv].AsInt32(); + + // Submit 1 high-priority Rot13 job — should execute before the low-priority ones + const std::string HighJobUrl = fmt::format("/queues/{}/jobs/{}?priority=10", QueueId, WorkerId.ToHexString()); + HttpClient::Response HighResp = Client.Post(HighJobUrl, BuildRot13ActionPackage("high"sv)); + REQUIRE_MESSAGE(HighResp, "High-priority job submission failed"); + const int LsnHigh = HighResp.AsObject()["lsn"sv].AsInt32(); + + // Wait for all 4 priority-test jobs to appear in the queue's completed list. + // This avoids any snapshot-timing race: by the time we compare timestamps, all + // jobs have already finished and their history entries are stable. + const std::string CompletedUrl = fmt::format("/queues/{}/completed", QueueId); + + { + bool AllCompleted = false; + Stopwatch WaitTimer; + + while (!AllCompleted && WaitTimer.GetElapsedTimeMs() < 30'000) + { + HttpClient::Response Resp = Client.Get(CompletedUrl); + + if (Resp) + { + bool FoundHigh = false; + bool FoundLow1 = false; + bool FoundLow2 = false; + bool FoundLow3 = false; + + CbObject RespObj = Resp.AsObject(); + + for (auto& Item : RespObj["completed"sv]) + { + const int Lsn = Item.AsInt32(); + if (Lsn == LsnHigh) + { + FoundHigh = true; + } + else if (Lsn == LsnLow1) + { + FoundLow1 = true; + } + else if (Lsn == LsnLow2) + { + FoundLow2 = true; + } + else if (Lsn == LsnLow3) + { + FoundLow3 = true; + } + } + + AllCompleted = FoundHigh && FoundLow1 && FoundLow2 && FoundLow3; + } + + if (!AllCompleted) + { + Sleep(100); + } + } + + REQUIRE_MESSAGE( + AllCompleted, + fmt::format( + "Not all priority test jobs completed within timeout (lsnHigh={} lsnLow1={} lsnLow2={} lsnLow3={})\nServer log:\n{}", + LsnHigh, + LsnLow1, + LsnLow2, + LsnLow3, + Instance.GetLogOutput())); + } + + // Query the queue-scoped history to obtain the time_Completed timestamp for each + // job. 
The history endpoint records when each RunnerAction::State transition + // occurred, so time_Completed is the wall-clock tick at which the action finished. + // Using the queue-scoped endpoint avoids exposing history from other queues. + const std::string HistoryUrl = fmt::format("/queues/{}/history", QueueId); + HttpClient::Response HistoryResp = Client.Get(HistoryUrl); + REQUIRE_MESSAGE(HistoryResp, "Failed to query queue action history"); + + CbObject HistoryObj = HistoryResp.AsObject(); + + auto GetCompletedTimestamp = [&](int Lsn) -> uint64_t { + for (auto& Item : HistoryObj["history"sv]) + { + if (Item.AsObjectView()["lsn"sv].AsInt32() == Lsn) + { + return Item.AsObjectView()["time_Completed"sv].AsUInt64(); + } + } + return 0; + }; + + const uint64_t TimeHigh = GetCompletedTimestamp(LsnHigh); + const uint64_t TimeLow1 = GetCompletedTimestamp(LsnLow1); + const uint64_t TimeLow2 = GetCompletedTimestamp(LsnLow2); + const uint64_t TimeLow3 = GetCompletedTimestamp(LsnLow3); + + REQUIRE_MESSAGE(TimeHigh != 0, fmt::format("lsnHigh={} not found in action history", LsnHigh)); + REQUIRE_MESSAGE(TimeLow1 != 0, fmt::format("lsnLow1={} not found in action history", LsnLow1)); + REQUIRE_MESSAGE(TimeLow2 != 0, fmt::format("lsnLow2={} not found in action history", LsnLow2)); + REQUIRE_MESSAGE(TimeLow3 != 0, fmt::format("lsnLow3={} not found in action history", LsnLow3)); + + // The high-priority job must have completed strictly before every low-priority job + CHECK_MESSAGE(TimeHigh < TimeLow1, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow1={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow1, + TimeLow1)); + CHECK_MESSAGE(TimeHigh < TimeLow2, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} but lsnLow2={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow2, + TimeLow2)); + CHECK_MESSAGE(TimeHigh < TimeLow3, + fmt::format("Priority ordering violated: lsnHigh={} completed at t={} 
but lsnLow3={} completed at t={} (expected later)", + LsnHigh, + TimeHigh, + LsnLow3, + TimeLow3)); +} + +////////////////////////////////////////////////////////////////////////// +// Remote worker synchronization tests +// +// These tests exercise the orchestrator discovery path where new compute +// nodes appear over time and must receive previously registered workers +// via SyncWorkersToRunner(). + +TEST_CASE("function.remote.worker_sync_on_discovery") +{ + // Spawn real zenserver in compute mode + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t ServerPort = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(ServerPort != 0, Instance.GetLogOutput()); + + const std::string ServerUri = fmt::format("http://localhost:{}", ServerPort); + + // Start mock orchestrator with empty worker list + MockOrchestratorFixture MockOrch; + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session (stored locally, no runners yet) + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Update mock orchestrator to advertise the real server + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}})); + + // Wait for scheduler to discover the runner (~5s throttle + margin) + Sleep(7'000); + + // Submit Rot13 action via session + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action 
enqueue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.late_runner_discovery") +{ + // Spawn first server + ZenServerInstance Instance1(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance1.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port1 = Instance1.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port1 != 0, Instance1.GetLogOutput()); + + const std::string ServerUri1 = fmt::format("http://localhost:{}", Port1); + + // Start mock orchestrator advertising W1 + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}})); + + // Create session and register worker + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for W1 discovery + Sleep(7'000); + + // Baseline: submit Rot13 action and verify it completes on W1 + { + CbObject ActionObj = BuildRot13ActionForSession("Hello 
World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance1.GetLogOutput())); + + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + } + + // Spawn second server + ZenServerInstance Instance2(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance2.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port2 = Instance2.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port2 != 0, Instance2.GetLogOutput()); + + const std::string ServerUri2 = fmt::format("http://localhost:{}", Port2); + + // Update mock orchestrator to include both W1 and W2 + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}})); + + // Wait for W2 discovery + Sleep(7'000); + + // Verify W2 received the worker by querying its /compute/workers endpoint directly + { + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port2); + HttpClient Client(ComputeBaseUri); + HttpClient::Response ListResp = Client.Get("/workers"sv); + REQUIRE_MESSAGE(ListResp, "Failed to list workers on W2"); + + bool WorkerFound = false; + for (auto& Item : ListResp.AsObject()["workers"sv]) + { + if (Item.AsHash() == WorkerPackage.GetObjectHash()) + { + WorkerFound = true; + break; + } + } + + REQUIRE_MESSAGE(WorkerFound, + fmt::format("Worker not found on W2 after discovery — SyncWorkersToRunner may have failed\nW2 log:\n{}", + Instance2.GetLogOutput())); + 
} + + // Submit another action and verify it completes (could run on either W1 or W2) + { + CbObject ActionObj = BuildRot13ActionForSession("Second Test"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Second action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Second action did not complete in time\nW1 log:\n{}\nW2 log:\n{}", + Instance1.GetLogOutput(), + Instance2.GetLogOutput())); + + // Rot13("Second Test") = "Frpbaq Grfg" + CHECK_EQ(GetRot13Output(ResultPackage), "Frpbaq Grfg"sv); + } + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_association") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + 
Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit action to it + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Action enqueue to queue failed"); + + // Poll for result + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE( + ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete in time. 
Last status: {}\nServer log:\n{}", int(ResultCode), Instance.GetLogOutput())); + + REQUIRE_MESSAGE(bool(ResultPackage), fmt::format("Empty result package\nServer log:\n{}", Instance.GetLogOutput())); + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + + // Verify that a non-implicit remote queue was created on the compute node + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueFound = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueFound = true; + break; + } + } + + CHECK_MESSAGE(RemoteQueueFound, "Expected a non-implicit remote queue on the compute node"); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.queue_cancel_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit 
a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + zen::compute::ComputeServiceSession::EnqueueResult EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Cancel the local queue — this should propagate to the remote + Session.CancelQueue(QueueId); + + // Poll for the action to complete (as cancelled) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + // Verify the local queue shows cancelled + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK(QueueStatus.State == zen::compute::ComputeServiceSession::QueueState::Cancelled); + + // Verify the remote queue on the compute node is also cancelled + HttpClient Client(Instance.GetBaseUri() + "/compute"); + + HttpClient::Response QueuesResp = Client.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server"); + + bool RemoteQueueCancelled = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled"; + break; + } + } + + CHECK_MESSAGE(RemoteQueueCancelled, "Expected the remote queue to be cancelled"); + + Session.Shutdown(); +} + +TEST_CASE("function.abandon_running_http") +{ + // Spawn a real zenserver to execute a long-running action, then abandon via HTTP endpoint + ZenServerInstance 
Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Create a queue and submit a long-running Sleep job + HttpClient::Response CreateResp = Client.Post("/queues"sv); + REQUIRE_MESSAGE(CreateResp, "Queue creation failed"); + + const int QueueId = CreateResp.AsObject()["queue_id"sv].AsInt32(); + REQUIRE_MESSAGE(QueueId != 0, "Expected non-zero queue_id"); + + const std::string JobUrl = fmt::format("/queues/{}/jobs/{}", QueueId, WorkerId.ToHexString()); + HttpClient::Response SubmitResp = Client.Post(JobUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Sleep job submission failed: status={}", int(SubmitResp.StatusCode))); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN"); + + // Wait for the process to start running + Sleep(1'000); + + // Verify the ready endpoint returns OK before abandon + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::OK); + } + + // Trigger abandon via the HTTP endpoint + HttpClient::Response AbandonResp = Client.Post("/abandon"sv); + REQUIRE_MESSAGE(AbandonResp.StatusCode == HttpResponseCode::OK, + fmt::format("Abandon request failed: status={}, body={}", int(AbandonResp.StatusCode), AbandonResp.ToText())); + + // Ready endpoint should now return 503 + { + HttpClient::Response ReadyResp = Client.Get("/ready"sv); + CHECK(ReadyResp.StatusCode == HttpResponseCode::ServiceUnavailable); + } + + // The abandoned action should appear in the completed endpoint once the process exits + const std::string CompletedUrl 
= fmt::format("/queues/{}/completed", QueueId); + REQUIRE_MESSAGE(PollForLsnInCompleted(Client, CompletedUrl, Lsn), + fmt::format("LSN {} did not appear in queue {} completed list after abandon\nServer log:\n{}", + Lsn, + QueueId, + Instance.GetLogOutput())); + + // Verify the queue reflects one abandoned action + const std::string QueueUrl = fmt::format("/queues/{}", QueueId); + HttpClient::Response StatusResp = Client.Get(QueueUrl); + REQUIRE_MESSAGE(StatusResp, "Failed to get queue status after abandon"); + + CbObject QueueStatus = StatusResp.AsObject(); + CHECK_EQ(QueueStatus["abandoned_count"sv].AsInt32(), 1); + CHECK_EQ(QueueStatus["completed_count"sv].AsInt32(), 0); + CHECK_EQ(QueueStatus["active_count"sv].AsInt32(), 0); + + // Submitting new work should be rejected + HttpClient::Response RejectedResp = Client.Post(JobUrl, BuildRot13ActionPackage("rejected"sv)); + CHECK_MESSAGE(RejectedResp.StatusCode != HttpResponseCode::OK, "Expected action submission to be rejected in Abandoned state"); +} + +TEST_CASE("function.session.abandon_pending") +{ + // Create a session with no runners so actions stay pending + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Enqueue several actions — they will stay pending because there are no runners + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + + CbObject ActionObj = BuildRot13ActionForSession("abandon-test"sv, Resolver); + + auto Enqueue1 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue2 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + auto Enqueue3 = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + 
REQUIRE_MESSAGE(Enqueue1, "Failed to enqueue action 1"); + REQUIRE_MESSAGE(Enqueue2, "Failed to enqueue action 2"); + REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); + + // Transition to Abandoned — should mark all pending actions as Abandoned + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK(!Session.IsHealthy()); + + // Give the scheduler thread time to process the state changes + Sleep(2'000); + + // All three actions should now be in the results map as abandoned + for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn}) + { + CbPackage Result; + HttpResponseCode Code = Session.GetActionResult(Lsn, Result); + CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); + } + + // Queue should show 0 active, 3 abandoned + auto Status = Session.GetQueueStatus(QueueResult.QueueId); + CHECK_EQ(Status.ActiveCount, 0); + CHECK_EQ(Status.AbandonedCount, 3); + + // New actions should be rejected + auto Rejected = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected in Abandoned state"); + + // Abandoned → Sunset should be valid + CHECK(Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Sunset)); + + Session.Shutdown(); +} + +TEST_CASE("function.session.abandon_running") +{ + // Spawn a real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + 
MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action + bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + CHECK(!Session.IsHealthy()); + + // Poll for the action to complete (as abandoned) + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the queue shows 
abandoned, not completed + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + CHECK_EQ(QueueStatus.CompletedCount, 0); + + Session.Shutdown(); +} + +TEST_CASE("function.remote.abandon_propagation") +{ + // Spawn real zenserver as a remote compute node + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + // Start mock orchestrator advertising the server + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + // Create session infrastructure + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + + // Register worker on session + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for scheduler to discover the runner + Sleep(7'000); + + // Create a local queue and submit a long-running Sleep action + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Transition to Abandoned — should abandon the running action and propagate + bool Transitioned = 
Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); + + // Poll for the action to complete + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if (ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Action did not complete within timeout\nServer log:\n{}", Instance.GetLogOutput())); + + // Verify the local queue shows abandoned + auto QueueStatus = Session.GetQueueStatus(QueueId); + CHECK_EQ(QueueStatus.ActiveCount, 0); + CHECK_EQ(QueueStatus.AbandonedCount, 1); + + // Session should not be healthy + CHECK(!Session.IsHealthy()); + + // The remote compute node should still be healthy (only the parent abandoned) + HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); + HttpClient::Response ReadyResp = RemoteClient.Get("/ready"sv); + CHECK_MESSAGE(ReadyResp.StatusCode == HttpResponseCode::OK, "Remote compute node should still be healthy"); + + Session.Shutdown(); +} + +TEST_SUITE_END(); + +} // namespace zen::tests::compute + +#endif diff --git a/src/zenserver-test/function-tests.cpp b/src/zenserver-test/function-tests.cpp deleted file mode 100644 index 82848c6ad..000000000 --- a/src/zenserver-test/function-tests.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#include - -#if ZEN_WITH_TESTS - -# include -# include -# include -# include -# include - -# include "zenserver-test.h" - -namespace zen::tests { - -using namespace std::literals; - -TEST_SUITE_BEGIN("server.function"); - -TEST_CASE("function.run") -{ - std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); - - ZenServerInstance Instance(TestEnv); - Instance.SetDataDir(TestDir); - Instance.SpawnServer(13337); - - ZEN_INFO("Waiting..."); - - Instance.WaitUntilReady(); -} - -TEST_SUITE_END(); - -} // namespace zen::tests - -#endif diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp new file mode 100644 index 000000000..fe39e14c0 --- /dev/null +++ b/src/zenserver-test/logging-tests.cpp @@ -0,0 +1,257 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#if ZEN_WITH_TESTS + +# include "zenserver-test.h" + +# include +# include +# include +# include + +namespace zen::tests { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +static bool +LogContains(const std::string& Log, std::string_view Needle) +{ + return Log.find(Needle) != std::string::npos; +} + +static std::string +ReadFileToString(const std::filesystem::path& Path) +{ + FileContents Contents = ReadFile(Path); + if (Contents.ErrorCode) + { + return {}; + } + + IoBuffer Content = Contents.Flatten(); + if (!Content) + { + return {}; + } + + return std::string(static_cast(Content.Data()), Content.Size()); +} + +////////////////////////////////////////////////////////////////////////// + +// Verify that a log file is created at the default location (DataDir/logs/zenserver.log) +// even without --abslog. The file must contain "server session id" (logged at INFO +// to all registered loggers during init) and "log starting at" (emitted once a file +// sink is first opened). 
+TEST_CASE("logging.file.default") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path DefaultLogFile = TestDir / "logs" / "zenserver.log"; + CHECK_MESSAGE(std::filesystem::exists(DefaultLogFile), "Default log file was not created"); + const std::string FileLog = ReadFileToString(DefaultLogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --quiet sets the console sink level to WARN. The formatted "[info] ..." +// entry written by the default logger's console sink must therefore not appear +// in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_* +// macros — may still emit plain-text messages without a level marker, so we +// check for the absence of the full_formatter "[info]" prefix rather than the +// message text itself.) +TEST_CASE("logging.console.quiet") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--quiet"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "[info] server session id"), Log); +} + +// --noconsole removes the stdout sink entirely, so the captured console output +// must not contain any log entries from the logging system. 
+TEST_CASE("logging.console.disabled") +{ + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--noconsole"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::string Log = Instance.GetLogOutput(); + CHECK_MESSAGE(!LogContains(Log, "server session id"), Log); +} + +// --abslog creates a rotating log file at the specified path. +// The file must contain "server session id" (logged at INFO to all loggers +// during init) and "log starting at" (emitted once a file sink is active). +TEST_CASE("logging.file.basic") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "log starting at"), FileLog); +} + +// --abslog with a .json extension selects the JSON formatter. +// Each log entry must be a JSON object containing at least the "message" +// and "source" fields. 
+TEST_CASE("logging.file.json") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "\"message\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "\"source\": \"zenserver\""), FileLog); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); +} + +// --log-id is automatically set to the server instance name in test mode. +// The JSON formatter emits this value as the "id" field, so every entry in a +// .json log file must carry a non-empty "id". +TEST_CASE("logging.log_id") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.json"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {}", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "JSON log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + // The JSON formatter writes the log-id as: "id": "", + CHECK_MESSAGE(LogContains(FileLog, "\"id\": \""), FileLog); +} + +// --log-warn raises the level threshold above INFO so that INFO messages +// are filtered. 
"server session id" is broadcast at INFO to all loggers: it must +// appear in the main file sink (default logger unaffected) but must NOT appear in +// http.log where the http_requests logger now has a WARN threshold. +TEST_CASE("logging.level.warn_suppresses_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-warn http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-info sets an explicit INFO threshold. The INFO "server session id" +// broadcast must still land in http.log, confirming that INFO messages are not +// filtered when the logger level is exactly INFO. 
+TEST_CASE("logging.level.info_allows_info") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-info http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(LogContains(HttpLog, "server session id"), HttpLog); +} + +// --log-off silences a named logger entirely. +// "server session id" is broadcast at INFO to all registered loggers via +// spdlog::apply_all during init. When the "http_requests" logger is set to +// OFF its dedicated http.log file must not contain that message. +// The main file sink (via --abslog) must be unaffected. 
+TEST_CASE("logging.level.off_specific_logger") +{ + const std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const std::filesystem::path LogFile = TestDir / "test.log"; + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + + const std::string LogArg = fmt::format("--abslog {} --log-off http_requests", LogFile.string()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(LogArg); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + Instance.Shutdown(); + + // Main log file must still have the startup message + CHECK_MESSAGE(std::filesystem::exists(LogFile), "Log file was not created"); + const std::string FileLog = ReadFileToString(LogFile); + CHECK_MESSAGE(LogContains(FileLog, "server session id"), FileLog); + + // http.log is created by the RotatingFileSink but the logger is OFF, so + // the broadcast "server session id" message must not have been written to it + const std::filesystem::path HttpLogFile = TestDir / "logs" / "http.log"; + CHECK_MESSAGE(std::filesystem::exists(HttpLogFile), "http.log was not created"); + const std::string HttpLog = ReadFileToString(HttpLogFile); + CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); +} + +} // namespace zen::tests + +#endif diff --git a/src/zenserver-test/nomad-tests.cpp b/src/zenserver-test/nomad-tests.cpp new file mode 100644 index 000000000..6eb99bc3a --- /dev/null +++ b/src/zenserver-test/nomad-tests.cpp @@ -0,0 +1,126 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#if ZEN_WITH_TESTS && ZEN_WITH_NOMAD +# include "zenserver-test.h" +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +namespace zen::tests::nomad_tests { + +using namespace std::literals; + +TEST_CASE("nomad.client.lifecycle" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + zen::nomad::NomadTestClient Client("http://localhost:4646/"); + + // Submit a simple batch job that sleeps briefly +# if ZEN_PLATFORM_WINDOWS + auto Job = Client.SubmitJob("zen-test-job", "cmd.exe", {"/C", "timeout /t 10 /nobreak"}); +# else + auto Job = Client.SubmitJob("zen-test-job", "/bin/sleep", {"10"}); +# endif + REQUIRE(!Job.Id.empty()); + CHECK_EQ(Job.Status, "pending"); + + // Poll until the job is running (or dead) + { + Stopwatch Timer; + bool FoundRunning = false; + while (Timer.GetElapsedTimeMs() < 15000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "running") + { + FoundRunning = true; + break; + } + if (Status.Status == "dead") + { + break; + } + Sleep(500); + } + CHECK(FoundRunning); + } + + // Verify allocations exist + auto Allocs = Client.GetAllocations("zen-test-job"); + CHECK(!Allocs.empty()); + + // Stop the job + Client.StopJob("zen-test-job"); + + // Verify it reaches dead state + { + Stopwatch Timer; + bool FoundDead = false; + while (Timer.GetElapsedTimeMs() < 10000) + { + auto Status = Client.GetJobStatus("zen-test-job"); + if (Status.Status == "dead") + { + FoundDead = true; + break; + } + Sleep(500); + } + CHECK(FoundDead); + } + + NomadProc.StopNomadAgent(); +} + +TEST_CASE("nomad.provisioner.integration" * doctest::skip()) +{ + zen::nomad::NomadProcess NomadProc; + NomadProc.SpawnNomadAgent(); + + // Spawn zenserver in compute mode with Nomad provisioning enabled + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + + std::filesystem::path 
ZenServerPath = TestEnv.ProgramBaseDir() / "zenserver" ZEN_EXE_SUFFIX_LITERAL; + + std::string NomadArgs = fmt::format( + "--nomad-enabled=true" + " --nomad-server=http://localhost:4646" + " --nomad-driver=raw_exec" + " --nomad-binary-path={}" + " --nomad-max-cores=32" + " --nomad-cores-per-job=32", + ZenServerPath.string()); + + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady(NomadArgs); + REQUIRE(Port != 0); + + // Give the provisioner time to submit jobs. + // The management thread has a 5s wait between cycles, and the HTTP client has + // a 10s connect timeout, so we need to allow enough time for at least one full cycle. + Sleep(15000); + + // Verify jobs were submitted to Nomad + zen::nomad::NomadTestClient NomadClient("http://localhost:4646/"); + + auto Jobs = NomadClient.ListJobs("zenserver-worker"); + + ZEN_INFO("nomad.provisioner.integration: found {} jobs with prefix 'zenserver-worker'", Jobs.size()); + CHECK_MESSAGE(!Jobs.empty(), Instance.GetLogOutput()); + + Instance.Shutdown(); + NomadProc.StopNomadAgent(); +} + +} // namespace zen::tests::nomad_tests +#endif diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua index 2a269cea1..7b208bbc7 100644 --- a/src/zenserver-test/xmake.lua +++ b/src/zenserver-test/xmake.lua @@ -6,10 +6,15 @@ target("zenserver-test") add_headerfiles("**.h") add_files("*.cpp") add_files("zenserver-test.cpp", {unity_ignored = true }) - add_deps("zencore", "zenremotestore", "zenhttp") + add_deps("zencore", "zenremotestore", "zenhttp", "zencompute", "zenstore") add_deps("zenserver", {inherit=false}) + add_deps("zentest-appstub", {inherit=false}) add_packages("http_parser") + if has_config("zennomad") then + add_deps("zennomad") + end + if is_plat("macosx") then add_ldflags("-framework CoreFoundation") add_ldflags("-framework Security") diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 0f9ef0287..802d06caf 100644 --- 
a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "computeserver.h" -#include -#include "computeservice.h" - +#include +#include +#include #if ZEN_WITH_COMPUTE_SERVICES # include @@ -13,10 +13,20 @@ # include # include # include +# include # include +# include # include # include # include +# if ZEN_WITH_HORDE +# include +# include +# endif +# if ZEN_WITH_NOMAD +# include +# include +# endif ZEN_THIRD_PARTY_INCLUDES_START # include @@ -27,6 +37,13 @@ namespace zen { void ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) { + Options.add_option("compute", + "", + "max-actions", + "Maximum number of concurrent local actions (0 = auto)", + cxxopts::value(m_ServerOptions.MaxConcurrentActions)->default_value("0"), + ""); + Options.add_option("compute", "", "upstream-notification-endpoint", @@ -40,6 +57,236 @@ ZenComputeServerConfigurator::AddCliOptions(cxxopts::Options& Options) "Instance ID for use in notifications", cxxopts::value(m_ServerOptions.InstanceId)->default_value(""), ""); + + Options.add_option("compute", + "", + "coordinator-endpoint", + "Endpoint URL for coordinator service", + cxxopts::value(m_ServerOptions.CoordinatorEndpoint)->default_value(""), + ""); + + Options.add_option("compute", + "", + "idms", + "Enable IDMS cloud detection; optionally specify a custom probe endpoint", + cxxopts::value(m_ServerOptions.IdmsEndpoint)->default_value("")->implicit_value("auto"), + ""); + + Options.add_option("compute", + "", + "worker-websocket", + "Use WebSocket for worker-orchestrator link (instant reachability detection)", + cxxopts::value(m_ServerOptions.EnableWorkerWebSocket)->default_value("false"), + ""); + +# if ZEN_WITH_HORDE + // Horde provisioning options + Options.add_option("horde", + "", + "horde-enabled", + "Enable Horde worker provisioning", + 
cxxopts::value(m_ServerOptions.HordeConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("horde", + "", + "horde-server", + "Horde server URL", + cxxopts::value(m_ServerOptions.HordeConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-token", + "Horde authentication token", + cxxopts::value(m_ServerOptions.HordeConfig.AuthToken)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-pool", + "Horde pool name", + cxxopts::value(m_ServerOptions.HordeConfig.Pool)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-cluster", + "Horde cluster ID ('default' or '_auto' for auto-resolve)", + cxxopts::value(m_ServerOptions.HordeConfig.Cluster)->default_value("default"), + ""); + + Options.add_option("horde", + "", + "horde-mode", + "Horde connection mode (direct, tunnel, relay)", + cxxopts::value(m_HordeModeStr)->default_value("direct"), + ""); + + Options.add_option("horde", + "", + "horde-encryption", + "Horde transport encryption (none, aes)", + cxxopts::value(m_HordeEncryptionStr)->default_value("none"), + ""); + + Options.add_option("horde", + "", + "horde-max-cores", + "Maximum number of Horde cores to provision", + cxxopts::value(m_ServerOptions.HordeConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("horde", + "", + "horde-host", + "Host address for Horde agents to connect back to", + cxxopts::value(m_ServerOptions.HordeConfig.HostAddress)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-condition", + "Additional Horde agent filter condition", + cxxopts::value(m_ServerOptions.HordeConfig.Condition)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-binaries", + "Path to directory containing zenserver binary for remote upload", + cxxopts::value(m_ServerOptions.HordeConfig.BinariesPath)->default_value(""), + ""); + + Options.add_option("horde", + "", + "horde-zen-service-port", + "Port number for Zen 
service communication", + cxxopts::value(m_ServerOptions.HordeConfig.ZenServicePort)->default_value("8558"), + ""); +# endif + +# if ZEN_WITH_NOMAD + // Nomad provisioning options + Options.add_option("nomad", + "", + "nomad-enabled", + "Enable Nomad worker provisioning", + cxxopts::value(m_ServerOptions.NomadConfig.Enabled)->default_value("false"), + ""); + + Options.add_option("nomad", + "", + "nomad-server", + "Nomad HTTP API URL", + cxxopts::value(m_ServerOptions.NomadConfig.ServerUrl)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-token", + "Nomad ACL token", + cxxopts::value(m_ServerOptions.NomadConfig.AclToken)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-datacenter", + "Nomad target datacenter", + cxxopts::value(m_ServerOptions.NomadConfig.Datacenter)->default_value("dc1"), + ""); + + Options.add_option("nomad", + "", + "nomad-namespace", + "Nomad namespace", + cxxopts::value(m_ServerOptions.NomadConfig.Namespace)->default_value("default"), + ""); + + Options.add_option("nomad", + "", + "nomad-region", + "Nomad region (empty for server default)", + cxxopts::value(m_ServerOptions.NomadConfig.Region)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-driver", + "Nomad task driver (raw_exec, docker)", + cxxopts::value(m_NomadDriverStr)->default_value("raw_exec"), + ""); + + Options.add_option("nomad", + "", + "nomad-distribution", + "Binary distribution mode (predeployed, artifact)", + cxxopts::value(m_NomadDistributionStr)->default_value("predeployed"), + ""); + + Options.add_option("nomad", + "", + "nomad-binary-path", + "Path to zenserver on Nomad clients (predeployed mode)", + cxxopts::value(m_ServerOptions.NomadConfig.BinaryPath)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-artifact-source", + "URL to download zenserver binary (artifact mode)", + cxxopts::value(m_ServerOptions.NomadConfig.ArtifactSource)->default_value(""), + ""); + + 
Options.add_option("nomad", + "", + "nomad-docker-image", + "Docker image for zenserver (docker driver)", + cxxopts::value(m_ServerOptions.NomadConfig.DockerImage)->default_value(""), + ""); + + Options.add_option("nomad", + "", + "nomad-max-jobs", + "Maximum concurrent Nomad jobs", + cxxopts::value(m_ServerOptions.NomadConfig.MaxJobs)->default_value("64"), + ""); + + Options.add_option("nomad", + "", + "nomad-cpu-mhz", + "CPU MHz allocated per Nomad task", + cxxopts::value(m_ServerOptions.NomadConfig.CpuMhz)->default_value("1000"), + ""); + + Options.add_option("nomad", + "", + "nomad-memory-mb", + "Memory MB allocated per Nomad task", + cxxopts::value(m_ServerOptions.NomadConfig.MemoryMb)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-cores-per-job", + "Estimated cores per Nomad job (for scaling)", + cxxopts::value(m_ServerOptions.NomadConfig.CoresPerJob)->default_value("32"), + ""); + + Options.add_option("nomad", + "", + "nomad-max-cores", + "Maximum total cores to provision via Nomad", + cxxopts::value(m_ServerOptions.NomadConfig.MaxCores)->default_value("2048"), + ""); + + Options.add_option("nomad", + "", + "nomad-job-prefix", + "Prefix for generated Nomad job IDs", + cxxopts::value(m_ServerOptions.NomadConfig.JobPrefix)->default_value("zenserver-worker"), + ""); +# endif } void @@ -63,6 +310,15 @@ ZenComputeServerConfigurator::OnConfigFileParsed(LuaConfig::Options& LuaOptions) void ZenComputeServerConfigurator::ValidateOptions() { +# if ZEN_WITH_HORDE + horde::FromString(m_ServerOptions.HordeConfig.Mode, m_HordeModeStr); + horde::FromString(m_ServerOptions.HordeConfig.EncryptionMode, m_HordeEncryptionStr); +# endif + +# if ZEN_WITH_NOMAD + nomad::FromString(m_ServerOptions.NomadConfig.TaskDriver, m_NomadDriverStr); + nomad::FromString(m_ServerOptions.NomadConfig.BinDistribution, m_NomadDistributionStr); +# endif } /////////////////////////////////////////////////////////////////////////// @@ -90,10 +346,14 @@ 
ZenComputeServer::Initialize(const ZenComputeServerConfig& ServerConfig, ZenServ return EffectiveBasePort; } + m_CoordinatorEndpoint = ServerConfig.CoordinatorEndpoint; + m_InstanceId = ServerConfig.InstanceId; + m_EnableWorkerWebSocket = ServerConfig.EnableWorkerWebSocket; + // This is a workaround to make sure we can have automated tests. Without // this the ranges for different child zen compute processes could overlap with // the main test range. - ZenServerEnvironment::SetBaseChildId(1000); + ZenServerEnvironment::SetBaseChildId(2000); m_DebugOptionForcedCrash = ServerConfig.ShouldCrash; @@ -113,6 +373,46 @@ ZenComputeServer::Cleanup() ZEN_INFO(ZEN_APP_NAME " cleaning up"); try { + // Cancel the maintenance timer so it stops re-enqueuing before we + // tear down the provisioners it references. + m_ProvisionerMaintenanceTimer.cancel(); + m_AnnounceTimer.cancel(); + +# if ZEN_WITH_HORDE + // Shut down Horde provisioner first — this signals all agent threads + // to exit and joins them before we tear down HTTP services. + m_HordeProvisioner.reset(); +# endif + +# if ZEN_WITH_NOMAD + // Shut down Nomad provisioner — stops the management thread and + // sends stop requests for all tracked jobs. + m_NomadProvisioner.reset(); +# endif + + // Close the orchestrator WebSocket client before stopping the io_context + m_WsReconnectTimer.cancel(); + if (m_OrchestratorWsClient) + { + m_OrchestratorWsClient->Close(); + m_OrchestratorWsClient.reset(); + } + m_OrchestratorWsHandler.reset(); + + ResolveCloudMetadata(); + m_CloudMetadata.reset(); + + // Shut down services that own threads or use the io_context before we + // stop the io_context and close the HTTP server. 
+ if (m_OrchestratorService) + { + m_OrchestratorService->Shutdown(); + } + if (m_ComputeService) + { + m_ComputeService->Shutdown(); + } + m_IoContext.stop(); if (m_IoRunner.joinable()) { @@ -139,7 +439,8 @@ ZenComputeServer::InitializeState(const ZenComputeServerConfig& ServerConfig) void ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) { - ZEN_INFO("initializing storage"); + ZEN_TRACE_CPU("ZenComputeServer::InitializeServices"); + ZEN_INFO("initializing compute services"); CidStoreConfiguration Config; Config.RootDirectory = m_DataRoot / "cas"; @@ -147,46 +448,405 @@ ZenComputeServer::InitializeServices(const ZenComputeServerConfig& ServerConfig) m_CidStore = std::make_unique(m_GcManager); m_CidStore->Initialize(Config); + if (!ServerConfig.IdmsEndpoint.empty()) + { + ZEN_INFO("detecting cloud environment (async)"); + if (ServerConfig.IdmsEndpoint == "auto") + { + m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir] { + return std::make_unique(DataDir / "cloud"); + }); + } + else + { + ZEN_INFO("using custom IDMS endpoint: {}", ServerConfig.IdmsEndpoint); + m_CloudMetadataFuture = std::async(std::launch::async, [DataDir = ServerConfig.DataDir, Endpoint = ServerConfig.IdmsEndpoint] { + return std::make_unique(DataDir / "cloud", Endpoint); + }); + } + } + ZEN_INFO("instantiating API service"); m_ApiService = std::make_unique(*m_Http); - ZEN_INFO("instantiating compute service"); - m_ComputeService = std::make_unique(ServerConfig.DataDir / "compute"); + ZEN_INFO("instantiating orchestrator service"); + m_OrchestratorService = + std::make_unique(ServerConfig.DataDir / "orch", ServerConfig.EnableWorkerWebSocket); + + ZEN_INFO("instantiating function service"); + m_ComputeService = std::make_unique(*m_CidStore, + m_StatsService, + ServerConfig.DataDir / "functions", + ServerConfig.MaxConcurrentActions); - // Ref Runner; - // Runner = zen::compute::CreateLocalRunner(*m_CidStore, ServerConfig.DataDir / 
"runner"); + m_FrontendService = std::make_unique(m_ContentRoot, m_StatusService); - // TODO: (re)implement default configuration here +# if ZEN_WITH_NOMAD + // Nomad provisioner + if (ServerConfig.NomadConfig.Enabled && !ServerConfig.NomadConfig.ServerUrl.empty()) + { + ZEN_INFO("instantiating Nomad provisioner (server: {})", ServerConfig.NomadConfig.ServerUrl); - ZEN_INFO("instantiating function service"); - m_FunctionService = - std::make_unique(*m_CidStore, m_StatsService, ServerConfig.DataDir / "functions"); + const auto& NomadCfg = ServerConfig.NomadConfig; + + if (!NomadCfg.Validate()) + { + ZEN_ERROR("invalid Nomad configuration"); + } + else + { + ExtendableStringBuilder<256> OrchestratorEndpoint; + OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get()); + if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/') + { + OrchestratorEndpoint << '/'; + } + + m_NomadProvisioner = std::make_unique(NomadCfg, OrchestratorEndpoint); + } + } +# endif + +# if ZEN_WITH_HORDE + // Horde provisioner + if (ServerConfig.HordeConfig.Enabled && !ServerConfig.HordeConfig.ServerUrl.empty()) + { + ZEN_INFO("instantiating Horde provisioner (server: {})", ServerConfig.HordeConfig.ServerUrl); + + const auto& HordeConfig = ServerConfig.HordeConfig; + + if (!HordeConfig.Validate()) + { + ZEN_ERROR("invalid Horde configuration"); + } + else + { + ExtendableStringBuilder<256> OrchestratorEndpoint; + OrchestratorEndpoint << m_Http->GetServiceUri(m_OrchestratorService.get()); + if (auto View = OrchestratorEndpoint.ToView(); !View.empty() && View.back() != '/') + { + OrchestratorEndpoint << '/'; + } + + // If no binaries path is specified, just use the running executable's directory + std::filesystem::path BinariesPath = HordeConfig.BinariesPath.empty() ? 
GetRunningExecutablePath().parent_path() + : std::filesystem::path(HordeConfig.BinariesPath); + std::filesystem::path WorkingDir = ServerConfig.DataDir / "horde"; + + m_HordeProvisioner = std::make_unique(HordeConfig, BinariesPath, WorkingDir, OrchestratorEndpoint); + } + } +# endif +} + +void +ZenComputeServer::ResolveCloudMetadata() +{ + if (m_CloudMetadataFuture.valid()) + { + m_CloudMetadata = m_CloudMetadataFuture.get(); + } +} + +std::string +ZenComputeServer::GetInstanceId() const +{ + if (!m_InstanceId.empty()) + { + return m_InstanceId; + } + return fmt::format("{}-{}", GetMachineName(), GetCurrentProcessId()); +} + +std::string +ZenComputeServer::GetAnnounceUrl() const +{ + return m_Http->GetServiceUri(nullptr); } void ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig) { + ZEN_TRACE_CPU("ZenComputeServer::RegisterServices"); ZEN_UNUSED(ServerConfig); + m_Http->RegisterService(m_StatsService); + + if (m_ApiService) + { + m_Http->RegisterService(*m_ApiService); + } + + if (m_OrchestratorService) + { + m_Http->RegisterService(*m_OrchestratorService); + } + if (m_ComputeService) { m_Http->RegisterService(*m_ComputeService); } - if (m_ApiService) + if (m_FrontendService) { - m_Http->RegisterService(*m_ApiService); + m_Http->RegisterService(*m_FrontendService); + } +} + +CbObject +ZenComputeServer::BuildAnnounceBody() +{ + CbObjectWriter AnnounceBody; + AnnounceBody << "id" << GetInstanceId(); + AnnounceBody << "uri" << GetAnnounceUrl(); + AnnounceBody << "hostname" << GetMachineName(); + AnnounceBody << "platform" << GetRuntimePlatformName(); + + ExtendedSystemMetrics Sm = ApplyReportingOverrides(m_MetricsTracker.Query()); + + AnnounceBody.BeginObject("metrics"); + Describe(Sm, AnnounceBody); + AnnounceBody.EndObject(); + + AnnounceBody << "cpu_usage" << Sm.CpuUsagePercent; + AnnounceBody << "memory_total" << Sm.SystemMemoryMiB * 1024 * 1024; + AnnounceBody << "memory_used" << (Sm.SystemMemoryMiB - Sm.AvailSystemMemoryMiB) * 1024 * 
1024; + + AnnounceBody << "bytes_received" << m_Http->GetTotalBytesReceived(); + AnnounceBody << "bytes_sent" << m_Http->GetTotalBytesSent(); + + auto Actions = m_ComputeService->GetActionCounts(); + AnnounceBody << "actions_pending" << Actions.Pending; + AnnounceBody << "actions_running" << Actions.Running; + AnnounceBody << "actions_completed" << Actions.Completed; + AnnounceBody << "active_queues" << Actions.ActiveQueues; + + // Derive provisioner from instance ID prefix (e.g. "horde-xxx" or "nomad-xxx") + if (m_InstanceId.starts_with("horde-")) + { + AnnounceBody << "provisioner" + << "horde"; + } + else if (m_InstanceId.starts_with("nomad-")) + { + AnnounceBody << "provisioner" + << "nomad"; + } + + ResolveCloudMetadata(); + if (m_CloudMetadata) + { + m_CloudMetadata->Describe(AnnounceBody); + } + + return AnnounceBody.Save(); +} + +void +ZenComputeServer::PostAnnounce() +{ + ZEN_TRACE_CPU("ZenComputeServer::PostAnnounce"); + + if (!m_ComputeService || m_CoordinatorEndpoint.empty()) + { + return; + } + + ZEN_INFO("notifying coordinator at '{}' of our availability at '{}'", m_CoordinatorEndpoint, GetAnnounceUrl()); + + try + { + CbObject Body = BuildAnnounceBody(); + + // If we have an active WebSocket connection, send via that instead of HTTP POST + if (m_OrchestratorWsClient && m_OrchestratorWsClient->IsOpen()) + { + MemoryView View = Body.GetView(); + m_OrchestratorWsClient->SendBinary(std::span(reinterpret_cast(View.GetData()), View.GetSize())); + ZEN_INFO("announced to coordinator via WebSocket"); + return; + } + + HttpClient CoordinatorHttp(m_CoordinatorEndpoint); + HttpClient::Response Result = CoordinatorHttp.Post("announce", std::move(Body)); + + if (Result.Error) + { + ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}", + m_CoordinatorEndpoint, + Result.Error->ErrorCode, + Result.Error->ErrorMessage); + } + else if (!IsHttpOk(Result.StatusCode)) + { + ZEN_ERROR("failed to notify coordinator at '{}': unexpected HTTP status code {}", + 
m_CoordinatorEndpoint, + static_cast(Result.StatusCode)); + } + else + { + ZEN_INFO("successfully notified coordinator at '{}'", m_CoordinatorEndpoint); + } + } + catch (const std::exception& Ex) + { + ZEN_ERROR("failed to notify coordinator at '{}': {}", m_CoordinatorEndpoint, Ex.what()); + } +} + +void +ZenComputeServer::EnqueueAnnounceTimer() +{ + if (!m_ComputeService || m_CoordinatorEndpoint.empty()) + { + return; + } + + m_AnnounceTimer.expires_after(std::chrono::seconds(15)); + m_AnnounceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + PostAnnounce(); + EnqueueAnnounceTimer(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::InitializeOrchestratorWebSocket() +{ + if (!m_EnableWorkerWebSocket || m_CoordinatorEndpoint.empty()) + { + return; + } + + // Convert http://host:port → ws://host:port/orch/ws + std::string WsUrl = m_CoordinatorEndpoint; + if (WsUrl.starts_with("http://")) + { + WsUrl = "ws://" + WsUrl.substr(7); + } + else if (WsUrl.starts_with("https://")) + { + WsUrl = "wss://" + WsUrl.substr(8); + } + if (!WsUrl.empty() && WsUrl.back() != '/') + { + WsUrl += '/'; + } + WsUrl += "orch/ws"; + + ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl); + + m_OrchestratorWsHandler = std::make_unique(*this); + m_OrchestratorWsClient = + std::make_unique(WsUrl, *m_OrchestratorWsHandler, m_IoContext, HttpWsClientSettings{.LogCategory = "orch_ws"}); + + m_OrchestratorWsClient->Connect(); + EnsureIoRunner(); +} + +void +ZenComputeServer::EnqueueWsReconnect() +{ + m_WsReconnectTimer.expires_after(std::chrono::seconds(5)); + m_WsReconnectTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec && m_OrchestratorWsClient) + { + ZEN_INFO("attempting WebSocket reconnect to orchestrator"); + m_OrchestratorWsClient->Connect(); + } + }); + EnsureIoRunner(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsOpen() +{ + ZEN_INFO("WebSocket link to orchestrator established"); + + // Send initial announce 
immediately over the WebSocket + Server.PostAnnounce(); +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // Orchestrator does not push messages to workers; ignore +} + +void +ZenComputeServer::OrchestratorWsHandler::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket link to orchestrator closed (code {}), falling back to HTTP announce", Code); + + // Trigger an immediate HTTP announce so the orchestrator has fresh state, + // then schedule a reconnect attempt. + Server.PostAnnounce(); + Server.EnqueueWsReconnect(); +} + +void +ZenComputeServer::ProvisionerMaintenanceTick() +{ +# if ZEN_WITH_HORDE + if (m_HordeProvisioner) + { + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_DEBUG("Horde maintenance: target={}, estimated={}, active={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.ActiveCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_DEBUG("Nomad maintenance: target={}, estimated={}, running jobs={}", + Stats.TargetCoreCount, + Stats.EstimatedCoreCount, + Stats.RunningJobCount); } +# endif +} + +void +ZenComputeServer::EnqueueProvisionerMaintenanceTimer() +{ + bool HasProvisioner = false; +# if ZEN_WITH_HORDE + HasProvisioner = HasProvisioner || (m_HordeProvisioner != nullptr); +# endif +# if ZEN_WITH_NOMAD + HasProvisioner = HasProvisioner || (m_NomadProvisioner != nullptr); +# endif - if (m_FunctionService) + if (!HasProvisioner) { - m_Http->RegisterService(*m_FunctionService); + return; } + + m_ProvisionerMaintenanceTimer.expires_after(std::chrono::seconds(15)); + m_ProvisionerMaintenanceTimer.async_wait([this](const asio::error_code& Ec) { + if (!Ec) + { + ProvisionerMaintenanceTick(); + EnqueueProvisionerMaintenanceTimer(); + } + 
}); + EnsureIoRunner(); } void ZenComputeServer::Run() { + ZEN_TRACE_CPU("ZenComputeServer::Run"); + if (m_ProcessMonitor.IsActive()) { CheckOwnerPid(); @@ -236,6 +896,35 @@ ZenComputeServer::Run() OnReady(); + PostAnnounce(); + EnqueueAnnounceTimer(); + InitializeOrchestratorWebSocket(); + +# if ZEN_WITH_HORDE + // Start Horde provisioning if configured — request maximum allowed cores. + // SetTargetCoreCount clamps to HordeConfig::MaxCores internally. + if (m_HordeProvisioner) + { + ZEN_INFO("Horde provisioning starting"); + m_HordeProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_HordeProvisioner->GetStats(); + ZEN_INFO("Horde provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + +# if ZEN_WITH_NOMAD + // Start Nomad provisioning if configured — request maximum allowed cores. + // SetTargetCoreCount clamps to NomadConfig::MaxCores internally. + if (m_NomadProvisioner) + { + m_NomadProvisioner->SetTargetCoreCount(UINT32_MAX); + auto Stats = m_NomadProvisioner->GetStats(); + ZEN_INFO("Nomad provisioning started (target cores: {})", Stats.TargetCoreCount); + } +# endif + + EnqueueProvisionerMaintenanceTimer(); + m_Http->Run(IsInteractiveMode); SetNewState(kShuttingDown); @@ -254,6 +943,8 @@ ZenComputeServerMain::ZenComputeServerMain(ZenComputeServerConfig& ServerOptions void ZenComputeServerMain::DoRun(ZenServerState::ZenServerEntry* Entry) { + ZEN_TRACE_CPU("ZenComputeServerMain::DoRun"); + ZenComputeServer Server; Server.SetDataRoot(m_ServerOptions.DataDir); Server.SetContentRoot(m_ServerOptions.ContentDir); diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index 625140b23..e4a6b01d5 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -6,7 +6,11 @@ #if ZEN_WITH_COMPUTE_SERVICES +# include +# include +# include # include +# include "frontend/frontend.h" namespace cxxopts { class Options; @@ -16,19 +20,46 @@ struct Options; } namespace 
zen::compute { -class HttpFunctionService; -} +class CloudMetadata; +class HttpComputeService; +class HttpOrchestratorService; +} // namespace zen::compute + +# if ZEN_WITH_HORDE +# include +namespace zen::horde { +class HordeProvisioner; +} // namespace zen::horde +# endif + +# if ZEN_WITH_NOMAD +# include +namespace zen::nomad { +class NomadProvisioner; +} // namespace zen::nomad +# endif namespace zen { class CidStore; class HttpApiService; -class HttpComputeService; struct ZenComputeServerConfig : public ZenServerConfig { std::string UpstreamNotificationEndpoint; std::string InstanceId; // For use in notifications + std::string CoordinatorEndpoint; + std::string IdmsEndpoint; + int32_t MaxConcurrentActions = 0; // 0 = auto (LogicalProcessorCount * 2) + bool EnableWorkerWebSocket = false; // Use WebSocket for worker↔orchestrator link + +# if ZEN_WITH_HORDE + horde::HordeConfig HordeConfig; +# endif + +# if ZEN_WITH_NOMAD + nomad::NomadConfig NomadConfig; +# endif }; struct ZenComputeServerConfigurator : public ZenServerConfiguratorBase @@ -49,6 +80,16 @@ private: virtual void ValidateOptions() override; ZenComputeServerConfig& m_ServerOptions; + +# if ZEN_WITH_HORDE + std::string m_HordeModeStr = "direct"; + std::string m_HordeEncryptionStr = "none"; +# endif + +# if ZEN_WITH_NOMAD + std::string m_NomadDriverStr = "raw_exec"; + std::string m_NomadDistributionStr = "predeployed"; +# endif }; class ZenComputeServerMain : public ZenServerMain @@ -88,17 +129,59 @@ public: void Cleanup(); private: - HttpStatsService m_StatsService; - GcManager m_GcManager; - GcScheduler m_GcScheduler{m_GcManager}; - std::unique_ptr m_CidStore; - std::unique_ptr m_ComputeService; - std::unique_ptr m_ApiService; - std::unique_ptr m_FunctionService; - - void InitializeState(const ZenComputeServerConfig& ServerConfig); - void InitializeServices(const ZenComputeServerConfig& ServerConfig); - void RegisterServices(const ZenComputeServerConfig& ServerConfig); + HttpStatsService 
m_StatsService; + GcManager m_GcManager; + GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr m_CidStore; + std::unique_ptr m_ApiService; + std::unique_ptr m_ComputeService; + std::unique_ptr m_OrchestratorService; + std::unique_ptr m_CloudMetadata; + std::future> m_CloudMetadataFuture; + std::unique_ptr m_FrontendService; +# if ZEN_WITH_HORDE + std::unique_ptr m_HordeProvisioner; +# endif +# if ZEN_WITH_NOMAD + std::unique_ptr m_NomadProvisioner; +# endif + SystemMetricsTracker m_MetricsTracker; + std::string m_CoordinatorEndpoint; + std::string m_InstanceId; + + asio::steady_timer m_AnnounceTimer{m_IoContext}; + asio::steady_timer m_ProvisionerMaintenanceTimer{m_IoContext}; + + void InitializeState(const ZenComputeServerConfig& ServerConfig); + void InitializeServices(const ZenComputeServerConfig& ServerConfig); + void RegisterServices(const ZenComputeServerConfig& ServerConfig); + void ResolveCloudMetadata(); + void PostAnnounce(); + void EnqueueAnnounceTimer(); + void EnqueueProvisionerMaintenanceTimer(); + void ProvisionerMaintenanceTick(); + std::string GetAnnounceUrl() const; + std::string GetInstanceId() const; + CbObject BuildAnnounceBody(); + + // Worker→orchestrator WebSocket client + struct OrchestratorWsHandler : public IWsClientHandler + { + ZenComputeServer& Server; + explicit OrchestratorWsHandler(ZenComputeServer& S) : Server(S) {} + + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + }; + + std::unique_ptr m_OrchestratorWsHandler; + std::unique_ptr m_OrchestratorWsClient; + asio::steady_timer m_WsReconnectTimer{m_IoContext}; + bool m_EnableWorkerWebSocket = false; + + void InitializeOrchestratorWebSocket(); + void EnqueueWsReconnect(); }; } // namespace zen diff --git a/src/zenserver/compute/computeservice.cpp b/src/zenserver/compute/computeservice.cpp deleted file mode 100644 index 2c0bc0ae9..000000000 --- 
a/src/zenserver/compute/computeservice.cpp +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "computeservice.h" - -#if ZEN_WITH_COMPUTE_SERVICES - -# include -# include -# include -# include -# include -# include - -ZEN_THIRD_PARTY_INCLUDES_START -# include -# include -ZEN_THIRD_PARTY_INCLUDES_END - -# include - -namespace zen { - -////////////////////////////////////////////////////////////////////////// - -struct ResourceMetrics -{ - uint64_t DiskUsageBytes = 0; - uint64_t MemoryUsageBytes = 0; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpComputeService::Impl -{ - Impl(const Impl&) = delete; - Impl& operator=(const Impl&) = delete; - - Impl(); - ~Impl(); - - void Initialize(std::filesystem::path BaseDir) { ZEN_UNUSED(BaseDir); } - - void Cleanup() {} - -private: -}; - -HttpComputeService::Impl::Impl() -{ -} - -HttpComputeService::Impl::~Impl() -{ -} - -/////////////////////////////////////////////////////////////////////////// - -HttpComputeService::HttpComputeService(std::filesystem::path BaseDir) : m_Impl(std::make_unique()) -{ - using namespace std::literals; - - m_Impl->Initialize(BaseDir); - - m_Router.RegisterRoute( - "status", - [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Obj.BeginArray("modules"); - Obj.EndArray(); - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); - }, - HttpVerb::kGet); - - m_Router.RegisterRoute( - "stats", - [this](HttpRouterRequest& Req) { - CbObjectWriter Obj; - Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); - }, - HttpVerb::kGet); -} - -HttpComputeService::~HttpComputeService() -{ -} - -const char* -HttpComputeService::BaseUri() const -{ - return "/compute/"; -} - -void -HttpComputeService::HandleRequest(zen::HttpServerRequest& Request) -{ - m_Router.HandleRequest(Request); -} - -} // namespace zen -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git 
a/src/zenserver/compute/computeservice.h b/src/zenserver/compute/computeservice.h deleted file mode 100644 index 339200dd8..000000000 --- a/src/zenserver/compute/computeservice.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if ZEN_WITH_COMPUTE_SERVICES -namespace zen { - -/** ZenServer Compute Service - * - * Manages a set of compute workers for use in UEFN content worker - * - */ -class HttpComputeService : public zen::HttpService -{ -public: - HttpComputeService(std::filesystem::path BaseDir); - ~HttpComputeService(); - - HttpComputeService(const HttpComputeService&) = delete; - HttpComputeService& operator=(const HttpComputeService&) = delete; - - virtual const char* BaseUri() const override; - virtual void HandleRequest(zen::HttpServerRequest& Request) override; - -private: - HttpRequestRouter m_Router; - - struct Impl; - - std::unique_ptr m_Impl; -}; - -} // namespace zen -#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index 4767029c0..c167cc70e 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ diff --git a/src/zenserver/frontend/html/404.html b/src/zenserver/frontend/html/404.html new file mode 100644 index 000000000..829ef2097 --- /dev/null +++ b/src/zenserver/frontend/html/404.html @@ -0,0 +1,486 @@ + + + + + +Ooops + + + + + + +
+ +
+ +
+

404 NOT FOUND

+
+ + + + + + diff --git a/src/zenserver/frontend/html/compute.html b/src/zenserver/frontend/html/compute.html deleted file mode 100644 index 668189fe5..000000000 --- a/src/zenserver/frontend/html/compute.html +++ /dev/null @@ -1,991 +0,0 @@ - - - - - - Zen Compute Dashboard - - - - -
-
-
-

Zen Compute Dashboard

-
Last updated: Never
-
-
-
- Checking... -
-
- -
- - -
Action Queue
-
-
-
Pending Actions
-
-
-
Waiting to be scheduled
-
-
-
Running Actions
-
-
-
Currently executing
-
-
-
Completed Actions
-
-
-
Results available
-
-
- - -
-
Action Queue History
-
- -
-
- - -
Performance Metrics
-
-
Completion Rate
-
-
-
-
-
1 min rate
-
-
-
-
-
5 min rate
-
-
-
-
-
15 min rate
-
-
-
-
- Total Retired - - -
-
- Mean Rate - - -
-
-
- - -
Workers
-
-
Worker Status
-
- Registered Workers - - -
- -
- - -
Recent Actions
-
-
Action History
-
No actions recorded yet.
- -
- - -
System Resources
-
-
-
CPU Usage
-
-
-
Percent
-
-
-
-
- -
-
-
- Packages - - -
-
- Physical Cores - - -
-
- Logical Processors - - -
-
-
-
-
Memory
-
- Used - - -
-
- Total - - -
-
-
-
-
-
-
Disk
-
- Used - - -
-
- Total - - -
-
-
-
-
-
-
- - - - diff --git a/src/zenserver/frontend/html/compute/banner.js b/src/zenserver/frontend/html/compute/banner.js new file mode 100644 index 000000000..61c7ce21f --- /dev/null +++ b/src/zenserver/frontend/html/compute/banner.js @@ -0,0 +1,321 @@ +/** + * zen-banner.js — Zen Compute dashboard banner Web Component + * + * Usage: + * + * + * + * + * + * + * Attributes: + * variant "full" (default) | "compact" + * cluster-status "nominal" (default) | "degraded" | "offline" + * load 0–100 integer, shown as a percentage (default: hidden) + * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator") + * subtitle text after "ZEN" in the wordmark (default: "COMPUTE") + */ + +class ZenBanner extends HTMLElement { + + static get observedAttributes() { + return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle']; + } + + attributeChangedCallback() { + if (this.shadowRoot) this._render(); + } + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + // ───────────────────────────────────────────── + // Derived values + // ───────────────────────────────────────────── + + get _variant() { return this.getAttribute('variant') || 'full'; } + get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); } + get _load() { return this.getAttribute('load'); } // null → hidden + get _tagline() { return this.getAttribute('tagline'); } // null → default + get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE" + + get _statusColor() { + return { nominal: '#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] ?? '#7ecfb8'; + } + + get _statusLabel() { + return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 
'NOMINAL'; + } + + get _loadColor() { + const v = parseInt(this._load, 10); + if (isNaN(v)) return '#7ecfb8'; + if (v >= 85) return '#c0504d'; + if (v >= 60) return '#d4a84b'; + return '#7ecfb8'; + } + + // ───────────────────────────────────────────── + // Render + // ───────────────────────────────────────────── + + _render() { + const compact = this._variant === 'compact'; + this.shadowRoot.innerHTML = ` + + ${this._html(compact)} + `; + } + + // ───────────────────────────────────────────── + // CSS + // ───────────────────────────────────────────── + + _css(compact) { + const height = compact ? '60px' : '100px'; + const padding = compact ? '0 24px' : '0 32px'; + const gap = compact ? '16px' : '24px'; + const markSize = compact ? '34px' : '52px'; + const divH = compact ? '32px' : '48px'; + const nameSize = compact ? '15px' : '22px'; + const tagSize = compact ? '9px' : '11px'; + const sc = this._statusColor; + const lc = this._loadColor; + + return ` + @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap'); + + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + font-family: 'Space Mono', monospace; + } + + .banner { + width: 100%; + height: ${height}; + background: #0b0d10; + border: 1px solid #1e2330; + border-radius: 6px; + display: flex; + align-items: center; + padding: ${padding}; + gap: ${gap}; + position: relative; + overflow: hidden; + } + + /* scan-line texture */ + .banner::before { + content: ''; + position: absolute; + inset: 0; + background: repeating-linear-gradient( + 0deg, + transparent, transparent 3px, + rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px + ); + pointer-events: none; + } + + /* ambient glow */ + .banner::after { + content: ''; + position: absolute; + right: -60px; + top: 50%; + transform: translateY(-50%); + width: 280px; + height: 280px; + background: radial-gradient(circle, rgba(130,200,180,0.06) 
0%, transparent 70%); + pointer-events: none; + } + + .logo-mark { + flex-shrink: 0; + width: ${markSize}; + height: ${markSize}; + } + + .logo-mark svg { width: 100%; height: 100%; } + + .divider { + width: 1px; + height: ${divH}; + background: linear-gradient(to bottom, transparent, #2a3040, transparent); + flex-shrink: 0; + } + + .text-block { + display: flex; + flex-direction: column; + gap: 4px; + } + + .wordmark { + font-weight: 700; + font-size: ${nameSize}; + letter-spacing: 0.12em; + color: #e8e4dc; + text-transform: uppercase; + line-height: 1; + } + + .wordmark span { color: #7ecfb8; } + + .tagline { + font-family: 'Noto Serif JP', serif; + font-weight: 300; + font-size: ${tagSize}; + letter-spacing: 0.3em; + color: #4a5a68; + text-transform: uppercase; + } + + .spacer { flex: 1; } + + /* ── right-side decorative circuit ── */ + .circuit { flex-shrink: 0; opacity: 0.22; } + + /* ── status cluster ── */ + .status-cluster { + display: flex; + flex-direction: column; + align-items: flex-end; + gap: 6px; + } + + .status-row { + display: flex; + align-items: center; + gap: 8px; + } + + .status-lbl { + font-size: 9px; + letter-spacing: 0.18em; + color: #3a4555; + text-transform: uppercase; + } + + .pill { + display: flex; + align-items: center; + gap: 5px; + border-radius: 20px; + padding: 2px 10px; + font-size: 10px; + letter-spacing: 0.1em; + } + + .pill.cluster { + color: ${sc}; + background: color-mix(in srgb, ${sc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${sc} 28%, transparent); + } + + .pill.load-pill { + color: ${lc}; + background: color-mix(in srgb, ${lc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${lc} 28%, transparent); + } + + .dot { + width: 5px; + height: 5px; + border-radius: 50%; + animation: pulse 2.4s ease-in-out infinite; + } + + .dot.cluster { background: ${sc}; } + .dot.load-dot { background: ${lc}; animation-delay: 0.5s; } + + @keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.25; } + } + `; + } 
+ + // ───────────────────────────────────────────── + // HTML template + // ───────────────────────────────────────────── + + _html(compact) { + const loadAttr = this._load; + const showStatus = !compact; + + const rightSide = showStatus ? ` + + + + + + + + + +
+
+ Cluster +
+
+ ${this._statusLabel} +
+
+ ${loadAttr !== null ? ` +
+ Load +
+
+ ${parseInt(loadAttr, 10)} % +
+
` : ''} +
+ ` : ''; + + return ` + + `; + } + + // ───────────────────────────────────────────── + // SVG logo mark + // ───────────────────────────────────────────── + + _svgMark() { + return ` + + + + + + + + + + + + + + + + + + `; + } +} + +customElements.define('zen-banner', ZenBanner); diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html new file mode 100644 index 000000000..1e101d839 --- /dev/null +++ b/src/zenserver/frontend/html/compute/compute.html @@ -0,0 +1,1072 @@ + + + + + + Zen Compute Dashboard + + + + + + +
+ + + Node + Orchestrator + +
Last updated: Never
+ +
+ + +
Action Queue
+
+
+
Pending Actions
+
-
+
Waiting to be scheduled
+
+
+
Running Actions
+
-
+
Currently executing
+
+
+
Completed Actions
+
-
+
Results available
+
+
+ + +
+
Action Queue History
+
+ +
+
+ + +
Performance Metrics
+
+
Completion Rate
+
+
+
-
+
1 min rate
+
+
+
-
+
5 min rate
+
+
+
-
+
15 min rate
+
+
+
+
+ Total Retired + - +
+
+ Mean Rate + - +
+
+
+ + +
Workers
+
+
Worker Status
+
+ Registered Workers + - +
+ +
+ + +
Queues
+
+
Queue Status
+
No queues.
+ +
+ + +
Recent Actions
+
+
Action History
+
No actions recorded yet.
+ +
+ + +
System Resources
+
+
+
CPU Usage
+
-
+
Percent
+
+
+
+
+ +
+
+
+ Packages + - +
+
+ Physical Cores + - +
+
+ Logical Processors + - +
+
+
+
+
Memory
+
+ Used + - +
+
+ Total + - +
+
+
+
+
+
+
Disk
+
+ Used + - +
+
+ Total + - +
+
+
+
+
+
+
+ + + + diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html new file mode 100644 index 000000000..f66ba94d5 --- /dev/null +++ b/src/zenserver/frontend/html/compute/hub.html @@ -0,0 +1,310 @@ + + + + + + + + Zen Hub Dashboard + + + +
+ + + Hub + +
Last updated: Never
+ +
+ +
Capacity
+
+
+
Active Modules
+
-
+
Currently provisioned
+
+
+
Peak Modules
+
-
+
High watermark
+
+
+
Instance Limit
+
-
+
Maximum allowed
+
+
+
+
+
+ +
Modules
+
+
Storage Server Instances
+
No modules provisioned.
+ + + + + + + + + +
+
+ + + + diff --git a/src/zenserver/frontend/html/compute/index.html b/src/zenserver/frontend/html/compute/index.html new file mode 100644 index 000000000..9597fd7f3 --- /dev/null +++ b/src/zenserver/frontend/html/compute/index.html @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/zenserver/frontend/html/compute/nav.js b/src/zenserver/frontend/html/compute/nav.js new file mode 100644 index 000000000..8ec42abd0 --- /dev/null +++ b/src/zenserver/frontend/html/compute/nav.js @@ -0,0 +1,79 @@ +/** + * zen-nav.js — Zen dashboard navigation bar Web Component + * + * Usage: + * + * + * + * Node + * Orchestrator + * + * + * Each child becomes a nav link. The current page is + * highlighted automatically based on the href. + */ + +class ZenNav extends HTMLElement { + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + _render() { + const currentPath = window.location.pathname; + const items = Array.from(this.querySelectorAll(':scope > a')); + + const links = items.map(a => { + const href = a.getAttribute('href') || ''; + const label = a.textContent.trim(); + const active = currentPath.endsWith(href); + return `${label}`; + }).join(''); + + this.shadowRoot.innerHTML = ` + + + `; + } +} + +customElements.define('zen-nav', ZenNav); diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html new file mode 100644 index 000000000..2ee57b6b3 --- /dev/null +++ b/src/zenserver/frontend/html/compute/orchestrator.html @@ -0,0 +1,831 @@ + + + + + + + + Zen Orchestrator Dashboard + + + +
+ + + Node + Orchestrator + +
+
+
Last updated: Never
+
+
+ Agents: + - +
+
+ +
+ +
+
Compute Agents
+
No agents registered.
+ + + + + + + + + + + + + + + + + + +
+
+
Connected Clients
+
No clients connected.
+ + + + + + + + + + + + +
+
+
+
Event History
+
+ + +
+
+
+
No provisioning events recorded.
+ + + + + + + + + + + +
+ +
+
+ + + + diff --git a/src/zenserver/frontend/html/pages/page.js b/src/zenserver/frontend/html/pages/page.js index 3c2d3619a..592b699dc 100644 --- a/src/zenserver/frontend/html/pages/page.js +++ b/src/zenserver/frontend/html/pages/page.js @@ -3,6 +3,7 @@ "use strict"; import { WidgetHost } from "../util/widgets.js" +import { Fetcher } from "../util/fetcher.js" //////////////////////////////////////////////////////////////////////////////// export class PageBase extends WidgetHost @@ -63,6 +64,7 @@ export class ZenPage extends PageBase super(parent, ...args); super.set_title("zen"); this.add_branding(parent); + this.add_service_nav(parent); this.generate_crumbs(); } @@ -78,6 +80,40 @@ export class ZenPage extends PageBase root.tag("img").attr("src", "epicgames.ico").id("epic_logo"); } + add_service_nav(parent) + { + const nav = parent.tag().id("service_nav"); + + // Map service base URIs to dashboard links, this table is also used to determine + // which links to show based on the services that are currently registered. 
+ + const service_dashboards = [ + { base_uri: "/compute/", label: "Compute", href: "/dashboard/compute/compute.html" }, + { base_uri: "/orch/", label: "Orchestrator", href: "/dashboard/compute/orchestrator.html" }, + { base_uri: "/hub/", label: "Hub", href: "/dashboard/compute/hub.html" }, + ]; + + new Fetcher().resource("/api/").json().then((data) => { + const services = data.services || []; + const uris = new Set(services.map(s => s.base_uri)); + + const links = service_dashboards.filter(d => uris.has(d.base_uri)); + + if (links.length === 0) + { + nav.inner().style.display = "none"; + return; + } + + for (const link of links) + { + nav.tag("a").text(link.label).attr("href", link.href); + } + }).catch(() => { + nav.inner().style.display = "none"; + }); + } + set_title(...args) { super.set_title(...args); diff --git a/src/zenserver/frontend/html/zen.css b/src/zenserver/frontend/html/zen.css index 702bf9aa6..a80a1a4f6 100644 --- a/src/zenserver/frontend/html/zen.css +++ b/src/zenserver/frontend/html/zen.css @@ -80,6 +80,33 @@ input { } } +/* service nav -------------------------------------------------------------- */ + +#service_nav { + display: flex; + justify-content: center; + gap: 0.3em; + margin-bottom: 1.5em; + padding: 0.3em; + background-color: var(--theme_g3); + border: 1px solid var(--theme_g2); + border-radius: 0.4em; + + a { + padding: 0.3em 0.9em; + border-radius: 0.3em; + font-size: 0.85em; + color: var(--theme_g1); + text-decoration: none; + } + + a:hover { + background-color: var(--theme_p4); + color: var(--theme_g0); + text-decoration: none; + } +} + /* links -------------------------------------------------------------------- */ a { diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index bf0e294c5..a757cd594 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -845,7 +845,7 @@ HttpHubService::HttpHubService(std::filesystem::path HubBaseDir, std::filesystem Obj << 
"currentInstanceCount" << m_Impl->GetInstanceCount(); Obj << "maxInstanceCount" << m_Impl->GetMaxInstanceCount(); Obj << "instanceLimit" << m_Impl->GetInstanceLimit(); - Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); }, HttpVerb::kGet); } diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index d0a0db417..c63c618df 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -143,6 +143,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) ZEN_INFO("instantiating hub service"); m_HubService = std::make_unique(ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers"); m_HubService->SetNotificationEndpoint(ServerConfig.UpstreamNotificationEndpoint, ServerConfig.InstanceId); + + m_FrontendService = std::make_unique(m_ContentRoot, m_StatusService); } void @@ -159,6 +161,11 @@ ZenHubServer::RegisterServices(const ZenHubServerConfig& ServerConfig) { m_Http->RegisterService(*m_ApiService); } + + if (m_FrontendService) + { + m_Http->RegisterService(*m_FrontendService); + } } void diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index ac14362f0..4c56fdce5 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -2,6 +2,7 @@ #pragma once +#include "frontend/frontend.h" #include "zenserver.h" namespace cxxopts { @@ -81,8 +82,9 @@ private: std::filesystem::path m_ContentRoot; bool m_DebugOptionForcedCrash = false; - std::unique_ptr m_HubService; - std::unique_ptr m_ApiService; + std::unique_ptr m_HubService; + std::unique_ptr m_ApiService; + std::unique_ptr m_FrontendService; void InitializeState(const ZenHubServerConfig& ServerConfig); void InitializeServices(const ZenHubServerConfig& ServerConfig); diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index 3d81db656..bca26e87a 100644 --- 
a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -183,10 +183,15 @@ ZenStorageServer::RegisterServices() m_Http->RegisterService(*m_AdminService); + if (m_ApiService) + { + m_Http->RegisterService(*m_ApiService); + } + #if ZEN_WITH_COMPUTE_SERVICES - if (m_HttpFunctionService) + if (m_HttpComputeService) { - m_Http->RegisterService(*m_HttpFunctionService); + m_Http->RegisterService(*m_HttpComputeService); } #endif } @@ -279,8 +284,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions { ZEN_OTEL_SPAN("InitializeComputeService"); - m_HttpFunctionService = - std::make_unique(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); + m_HttpComputeService = + std::make_unique(*m_CidStore, m_StatsService, ServerOptions.DataDir / "functions"); } #endif @@ -316,6 +321,8 @@ ZenStorageServer::InitializeServices(const ZenStorageServerConfig& ServerOptions .AttachmentPassCount = ServerOptions.GcConfig.AttachmentPassCount}; m_GcScheduler.Initialize(GcConfig); + m_ApiService = std::make_unique(*m_Http); + // Create and register admin interface last to make sure all is properly initialized m_AdminService = std::make_unique( m_GcScheduler, @@ -832,7 +839,7 @@ ZenStorageServer::Cleanup() Flush(); #if ZEN_WITH_COMPUTE_SERVICES - m_HttpFunctionService.reset(); + m_HttpComputeService.reset(); #endif m_AdminService.reset(); diff --git a/src/zenserver/storage/zenstorageserver.h b/src/zenserver/storage/zenstorageserver.h index 456447a2a..5b163fc8e 100644 --- a/src/zenserver/storage/zenstorageserver.h +++ b/src/zenserver/storage/zenstorageserver.h @@ -25,7 +25,7 @@ #include "workspaces/httpworkspaces.h" #if ZEN_WITH_COMPUTE_SERVICES -# include +# include #endif namespace zen { @@ -93,7 +93,7 @@ private: std::unique_ptr m_ApiService; #if ZEN_WITH_COMPUTE_SERVICES - std::unique_ptr m_HttpFunctionService; + std::unique_ptr m_HttpComputeService; #endif }; diff --git 
a/src/zenserver/trace/tracerecorder.cpp b/src/zenserver/trace/tracerecorder.cpp new file mode 100644 index 000000000..5dec20e18 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.cpp @@ -0,0 +1,565 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "tracerecorder.h" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceSession : public std::enable_shared_from_this +{ + TraceSession(asio::ip::tcp::socket&& Socket, const std::filesystem::path& OutputDir) + : m_Socket(std::move(Socket)) + , m_OutputDir(OutputDir) + , m_SessionId(Oid::NewOid()) + { + try + { + m_RemoteAddress = m_Socket.remote_endpoint().address().to_string(); + } + catch (...) + { + m_RemoteAddress = "unknown"; + } + + ZEN_INFO("Trace session {} started from {}", m_SessionId, m_RemoteAddress); + } + + ~TraceSession() + { + if (m_TraceFile.IsOpen()) + { + m_TraceFile.Close(); + } + + ZEN_INFO("Trace session {} ended, {} bytes recorded to '{}'", m_SessionId, m_TotalBytesRecorded, m_TraceFilePath); + } + + void Start() { ReadPreambleHeader(); } + + bool IsActive() const { return m_Socket.is_open(); } + + TraceSessionInfo GetInfo() const + { + TraceSessionInfo Info; + Info.SessionGuid = m_SessionGuid; + Info.TraceGuid = m_TraceGuid; + Info.ControlPort = m_ControlPort; + Info.TransportVersion = m_TransportVersion; + Info.ProtocolVersion = m_ProtocolVersion; + Info.RemoteAddress = m_RemoteAddress; + Info.BytesRecorded = m_TotalBytesRecorded; + Info.TraceFilePath = m_TraceFilePath; + return Info; + } + +private: + // Preamble format: + // [magic: 4 bytes][metadata_size: 2 bytes][metadata fields: variable][version: 2 bytes] + // + // Magic bytes: [0]=version_char ('2'-'9'), [1]='C', [2]='R', [3]='T' + // + // Metadata fields (repeated): + // [size: 1 byte][id: 1 byte][data: bytes] + // Field 0: ControlPort 
(uint16) + // Field 1: SessionGuid (16 bytes) + // Field 2: TraceGuid (16 bytes) + // + // Version: [transport: 1 byte][protocol: 1 byte] + + static constexpr size_t kMagicSize = 4; + static constexpr size_t kMetadataSizeFieldSize = 2; + static constexpr size_t kPreambleHeaderSize = kMagicSize + kMetadataSizeFieldSize; + static constexpr size_t kVersionSize = 2; + static constexpr size_t kPreambleBufferSize = 256; + static constexpr size_t kReadBufferSize = 64 * 1024; + + void ReadPreambleHeader() + { + auto Self = shared_from_this(); + + // Read the first 6 bytes: 4 magic + 2 metadata size + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer, kPreambleHeaderSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble header", Ec); + return; + } + + if (!ValidateMagic()) + { + ZEN_WARN("Trace session {}: invalid trace magic header", m_SessionId); + CloseSocket(); + return; + } + + ReadPreambleMetadata(); + }); + } + + bool ValidateMagic() + { + const uint8_t* Cursor = m_PreambleBuffer; + + // Validate magic: bytes are version, 'C', 'R', 'T' + if (Cursor[3] != 'T' || Cursor[2] != 'R' || Cursor[1] != 'C') + { + return false; + } + + if (Cursor[0] < '2' || Cursor[0] > '9') + { + return false; + } + + // Extract the metadata fields size (does not include the trailing version bytes) + std::memcpy(&m_MetadataFieldsSize, Cursor + kMagicSize, sizeof(m_MetadataFieldsSize)); + + if (m_MetadataFieldsSize + kVersionSize > kPreambleBufferSize - kPreambleHeaderSize) + { + return false; + } + + return true; + } + + void ReadPreambleMetadata() + { + auto Self = shared_from_this(); + size_t ReadSize = m_MetadataFieldsSize + kVersionSize; + + // Read metadata fields + 2 version bytes + asio::async_read(m_Socket, + asio::buffer(m_PreambleBuffer + kPreambleHeaderSize, ReadSize), + [this, Self](const asio::error_code& Ec, std::size_t /*BytesRead*/) { + if (Ec) + { + HandleReadError("preamble metadata", Ec); + return; + 
} + + if (!ParseMetadata()) + { + ZEN_WARN("Trace session {}: malformed trace metadata", m_SessionId); + CloseSocket(); + return; + } + + if (!CreateTraceFile()) + { + CloseSocket(); + return; + } + + // Write the full preamble to the trace file so it remains a valid .utrace + size_t PreambleSize = kPreambleHeaderSize + m_MetadataFieldsSize + kVersionSize; + std::error_code WriteEc; + m_TraceFile.Write(m_PreambleBuffer, PreambleSize, 0, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: failed to write preamble: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded = PreambleSize; + + ZEN_INFO("Trace session {}: metadata - TransportV{} ProtocolV{} ControlPort:{} SessionGuid:{} TraceGuid:{}", + m_SessionId, + m_TransportVersion, + m_ProtocolVersion, + m_ControlPort, + m_SessionGuid, + m_TraceGuid); + + // Begin streaming trace data to disk + ReadMore(); + }); + } + + bool ParseMetadata() + { + const uint8_t* Cursor = m_PreambleBuffer + kPreambleHeaderSize; + int32_t Remaining = static_cast(m_MetadataFieldsSize); + + while (Remaining >= 2) + { + uint8_t FieldSize = Cursor[0]; + uint8_t FieldId = Cursor[1]; + Cursor += 2; + Remaining -= 2; + + if (Remaining < FieldSize) + { + return false; + } + + switch (FieldId) + { + case 0: // ControlPort + if (FieldSize >= sizeof(uint16_t)) + { + std::memcpy(&m_ControlPort, Cursor, sizeof(uint16_t)); + } + break; + case 1: // SessionGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_SessionGuid, Cursor, sizeof(Guid)); + } + break; + case 2: // TraceGuid + if (FieldSize >= sizeof(Guid)) + { + std::memcpy(&m_TraceGuid, Cursor, sizeof(Guid)); + } + break; + } + + Cursor += FieldSize; + Remaining -= FieldSize; + } + + // Metadata should be fully consumed + if (Remaining != 0) + { + return false; + } + + // Version bytes follow immediately after the metadata fields + const uint8_t* VersionPtr = m_PreambleBuffer + kPreambleHeaderSize + m_MetadataFieldsSize; + m_TransportVersion 
= VersionPtr[0]; + m_ProtocolVersion = VersionPtr[1]; + + return true; + } + + bool CreateTraceFile() + { + m_TraceFilePath = m_OutputDir / fmt::format("{}.utrace", m_SessionId); + + try + { + m_TraceFile.Open(m_TraceFilePath, BasicFile::Mode::kTruncate); + ZEN_INFO("Trace session {} writing to '{}'", m_SessionId, m_TraceFilePath); + return true; + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Trace session {}: failed to create trace file '{}': {}", m_SessionId, m_TraceFilePath, Ex.what()); + return false; + } + } + + void ReadMore() + { + auto Self = shared_from_this(); + + m_Socket.async_read_some(asio::buffer(m_ReadBuffer, kReadBufferSize), + [this, Self](const asio::error_code& Ec, std::size_t BytesRead) { + if (!Ec) + { + if (BytesRead > 0 && m_TraceFile.IsOpen()) + { + std::error_code WriteEc; + const uint64_t FileOffset = m_TotalBytesRecorded; + m_TraceFile.Write(m_ReadBuffer, BytesRead, FileOffset, WriteEc); + + if (WriteEc) + { + ZEN_ERROR("Trace session {}: write error: {}", m_SessionId, WriteEc.message()); + CloseSocket(); + return; + } + + m_TotalBytesRecorded += BytesRead; + } + + ReadMore(); + } + else if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {} connection closed by peer", m_SessionId); + CloseSocket(); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {} operation aborted", m_SessionId); + } + else + { + ZEN_WARN("Trace session {} read error: {}", m_SessionId, Ec.message()); + CloseSocket(); + } + }); + } + + void HandleReadError(const char* Phase, const asio::error_code& Ec) + { + if (Ec == asio::error::eof) + { + ZEN_DEBUG("Trace session {}: connection closed during {}", m_SessionId, Phase); + } + else if (Ec == asio::error::operation_aborted) + { + ZEN_DEBUG("Trace session {}: operation aborted during {}", m_SessionId, Phase); + } + else + { + ZEN_WARN("Trace session {}: error during {}: {}", m_SessionId, Phase, Ec.message()); + } + + CloseSocket(); + } + + void CloseSocket() + { + 
std::error_code Ec; + m_Socket.close(Ec); + + if (m_TraceFile.IsOpen()) + { + m_TraceFile.Close(); + } + } + + asio::ip::tcp::socket m_Socket; + std::filesystem::path m_OutputDir; + std::filesystem::path m_TraceFilePath; + BasicFile m_TraceFile; + Oid m_SessionId; + std::string m_RemoteAddress; + + // Preamble parsing + uint8_t m_PreambleBuffer[kPreambleBufferSize] = {}; + uint16_t m_MetadataFieldsSize = 0; + + // Extracted metadata + Guid m_SessionGuid{}; + Guid m_TraceGuid{}; + uint16_t m_ControlPort = 0; + uint8_t m_TransportVersion = 0; + uint8_t m_ProtocolVersion = 0; + + // Streaming + uint8_t m_ReadBuffer[kReadBufferSize]; + uint64_t m_TotalBytesRecorded = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct TraceRecorder::Impl +{ + Impl() : m_IoContext(), m_Acceptor(m_IoContext) {} + + ~Impl() { Shutdown(); } + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) + { + std::lock_guard Lock(m_Mutex); + + if (m_IsRunning) + { + ZEN_WARN("TraceRecorder already initialized"); + return; + } + + m_OutputDir = OutputDir; + + try + { + // Create output directory if it doesn't exist + CreateDirectories(m_OutputDir); + + // Configure acceptor + m_Acceptor.open(asio::ip::tcp::v4()); + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::tcp::v4(), InPort)); + m_Acceptor.listen(); + + m_Port = m_Acceptor.local_endpoint().port(); + + ZEN_INFO("TraceRecorder listening on port {}, output directory: '{}'", m_Port, m_OutputDir); + + m_IsRunning = true; + + // Start accepting connections + StartAccept(); + + // Start IO thread + m_IoThread = std::thread([this]() { + try + { + m_IoContext.run(); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("TraceRecorder IO thread exception: {}", Ex.what()); + } + }); + } + catch (const std::exception& Ex) + { + ZEN_ERROR("Failed to initialize TraceRecorder: {}", Ex.what()); + m_IsRunning = false; + 
throw; + } + } + + void Shutdown() + { + std::lock_guard Lock(m_Mutex); + + if (!m_IsRunning) + { + return; + } + + ZEN_INFO("TraceRecorder shutting down"); + + m_IsRunning = false; + + std::error_code Ec; + m_Acceptor.close(Ec); + + m_IoContext.stop(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + + { + std::lock_guard SessionLock(m_SessionsMutex); + m_Sessions.clear(); + } + + ZEN_INFO("TraceRecorder shutdown complete"); + } + + bool IsRunning() const { return m_IsRunning; } + + uint16_t GetPort() const { return m_Port; } + + std::vector GetActiveSessions() const + { + std::lock_guard Lock(m_SessionsMutex); + + std::vector Result; + for (const auto& WeakSession : m_Sessions) + { + if (auto Session = WeakSession.lock()) + { + if (Session->IsActive()) + { + Result.push_back(Session->GetInfo()); + } + } + } + return Result; + } + +private: + void StartAccept() + { + auto Socket = std::make_shared(m_IoContext); + + m_Acceptor.async_accept(*Socket, [this, Socket](const asio::error_code& Ec) { + if (!Ec) + { + auto Session = std::make_shared(std::move(*Socket), m_OutputDir); + + { + std::lock_guard Lock(m_SessionsMutex); + + // Prune expired sessions while adding the new one + std::erase_if(m_Sessions, [](const std::weak_ptr& Wp) { return Wp.expired(); }); + m_Sessions.push_back(Session); + } + + Session->Start(); + } + else if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("Accept error: {}", Ec.message()); + } + + // Continue accepting if still running + if (m_IsRunning) + { + StartAccept(); + } + }); + } + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + std::thread m_IoThread; + std::filesystem::path m_OutputDir; + std::mutex m_Mutex; + std::atomic m_IsRunning{false}; + uint16_t m_Port = 0; + + mutable std::mutex m_SessionsMutex; + std::vector> m_Sessions; +}; + +//////////////////////////////////////////////////////////////////////////////// + +TraceRecorder::TraceRecorder() : m_Impl(std::make_unique()) +{ +} + 
+TraceRecorder::~TraceRecorder() +{ + Shutdown(); +} + +void +TraceRecorder::Initialize(uint16_t InPort, const std::filesystem::path& OutputDir) +{ + m_Impl->Initialize(InPort, OutputDir); +} + +void +TraceRecorder::Shutdown() +{ + m_Impl->Shutdown(); +} + +bool +TraceRecorder::IsRunning() const +{ + return m_Impl->IsRunning(); +} + +uint16_t +TraceRecorder::GetPort() const +{ + return m_Impl->GetPort(); +} + +std::vector +TraceRecorder::GetActiveSessions() const +{ + return m_Impl->GetActiveSessions(); +} + +} // namespace zen diff --git a/src/zenserver/trace/tracerecorder.h b/src/zenserver/trace/tracerecorder.h new file mode 100644 index 000000000..48857aec8 --- /dev/null +++ b/src/zenserver/trace/tracerecorder.h @@ -0,0 +1,46 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace zen { + +struct TraceSessionInfo +{ + Guid SessionGuid{}; + Guid TraceGuid{}; + uint16_t ControlPort = 0; + uint8_t TransportVersion = 0; + uint8_t ProtocolVersion = 0; + std::string RemoteAddress; + uint64_t BytesRecorded = 0; + std::filesystem::path TraceFilePath; +}; + +class TraceRecorder +{ +public: + TraceRecorder(); + ~TraceRecorder(); + + void Initialize(uint16_t InPort, const std::filesystem::path& OutputDir); + void Shutdown(); + + bool IsRunning() const; + uint16_t GetPort() const; + + std::vector GetActiveSessions() const; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen \ No newline at end of file diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 9ab51beb2..915b6a3b1 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -27,6 +27,7 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("nomad") if has_config("zenmimalloc") then add_packages("mimalloc") @@ -36,6 +37,14 @@ target("zenserver") add_packages("sentry-native") end + if has_config("zenhorde") then + 
add_deps("zenhorde") + end + + if has_config("zennomad") then + add_deps("zennomad") + end + if is_mode("release") then set_optimize("fastest") end @@ -145,4 +154,14 @@ target("zenserver") end copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + + local nomad_pkg = target:pkg("nomad") + if nomad_pkg then + local installdir = nomad_pkg:installdir() + local nomad_bin = "nomad" + if is_plat("windows") then + nomad_bin = "nomad.exe" + end + copy_if_newer(path.join(installdir, "bin", nomad_bin), path.join(target:targetdir(), nomad_bin), nomad_bin) + end end) diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp index 67fbef532..509629739 100644 --- a/src/zentest-appstub/zentest-appstub.cpp +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -106,6 +106,11 @@ DescribeFunctions() << "Reverse"sv; Versions << "Version"sv << Guid::FromString("31313131-3131-3131-3131-313131313131"sv); Versions.EndObject(); + Versions.BeginObject(); + Versions << "Name"sv + << "Sleep"sv; + Versions << "Version"sv << Guid::FromString("88888888-8888-8888-8888-888888888888"sv); + Versions.EndObject(); Versions.EndArray(); return Versions.Save(); @@ -190,6 +195,12 @@ ExecuteFunction(CbObject Action, ContentResolver ChunkResolver) { return Apply(NullFunction); } + else if (Function == "Sleep"sv) + { + uint64_t SleepTimeMs = Action["Constants"sv].AsObjectView()["SleepTimeMs"sv].AsUInt64(); + zen::Sleep(static_cast(SleepTimeMs)); + return Apply(IdentityFunction); + } else { return {}; diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 7dc68c126..5f74fa82b 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include #include diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 
e81b154e8..2a8617162 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -84,6 +84,7 @@ struct ZenServerInstance { kStorageServer, // default kHubServer, + kComputeServer, }; ZenServerInstance(ZenServerEnvironment& TestEnvironment, ServerMode Mode = ServerMode::kStorageServer); diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index e127a92d7..b09c2d89a 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -787,6 +787,8 @@ ToString(ZenServerInstance::ServerMode Mode) return "storage"sv; case ZenServerInstance::ServerMode::kHubServer: return "hub"sv; + case ZenServerInstance::ServerMode::kComputeServer: + return "compute"sv; default: return "invalid"sv; } @@ -808,6 +810,10 @@ ZenServerInstance::SpawnServerInternal(int ChildId, std::string_view ServerArgs, { CommandLine << " hub"; } + else if (m_ServerMode == ServerMode::kComputeServer) + { + CommandLine << " compute"; + } CommandLine << " --child-id " << ChildEventName; -- cgit v1.2.3 From 6926c04dc4d7c5c0f0310b66c17c9a4e94d2e341 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Wed, 4 Mar 2026 16:07:14 +0100 Subject: more feedback during auth option parsing (#806) * remove stray std::unique_ptr Auth; causing crashes * add more feedback during parsing of auth options --- src/zen/authutils.cpp | 80 ++++++++++++++++++++++-------------- src/zen/cmds/builds_cmd.cpp | 2 - src/zencore/include/zencore/string.h | 2 + src/zencore/string.cpp | 56 +++++++++++++++++++++---- 4 files changed, 99 insertions(+), 41 deletions(-) (limited to 'src') diff --git a/src/zen/authutils.cpp b/src/zen/authutils.cpp index 16427acf5..23ac70965 100644 --- a/src/zen/authutils.cpp +++ b/src/zen/authutils.cpp @@ -154,21 +154,34 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, ZEN_ASSERT(!SystemRootDir.empty()); if (!Auth) { - if (m_EncryptionKey.empty()) + static const std::string_view 
DefaultEncryptionKey("abcdefghijklmnopqrstuvxyz0123456"); + static const std::string_view DefaultEncryptionIV("0123456789abcdef"); + if (m_EncryptionKey.empty() && m_EncryptionIV.empty()) { - m_EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456"; + m_EncryptionKey = DefaultEncryptionKey; + m_EncryptionIV = DefaultEncryptionIV; if (!Quiet) { - ZEN_CONSOLE_WARN("Using default encryption key"); + ZEN_CONSOLE_WARN("Auth: Using default encryption key and initialization vector for auth storage"); } } - - if (m_EncryptionIV.empty()) + else { - m_EncryptionIV = "0123456789abcdef"; - if (!Quiet) + if (m_EncryptionKey.empty()) + { + m_EncryptionKey = DefaultEncryptionKey; + if (!Quiet) + { + ZEN_CONSOLE_WARN("Auth: Using default encryption key for auth storage"); + } + } + if (m_EncryptionIV.empty()) { - ZEN_CONSOLE_WARN("Using default encryption initialization vector"); + m_EncryptionIV = DefaultEncryptionIV; + if (!Quiet) + { + ZEN_CONSOLE_WARN("Auth: Using default encryption initialization vector for auth storage"); + } } } @@ -187,9 +200,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, { ExtendableStringBuilder<128> SB; SB << "\n RootDirectory: " << AuthMgrConfig.RootDirectory.string(); - SB << "\n EncryptionKey: " << m_EncryptionKey; - SB << "\n EncryptionIV: " << m_EncryptionIV; - ZEN_CONSOLE("Creating auth manager with:{}", SB.ToString()); + SB << "\n EncryptionKey: " << HideSensitiveString(m_EncryptionKey); + SB << "\n EncryptionIV: " << HideSensitiveString(m_EncryptionIV); + ZEN_CONSOLE("Auth: Creating auth manager with:{}", SB.ToString()); } Auth = AuthMgr::Create(AuthMgrConfig); } @@ -204,13 +217,18 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, ExtendableStringBuilder<128> SB; SB << "\n Name: " << ProviderName; SB << "\n Url: " << m_OpenIdProviderUrl; - SB << "\n ClientId: " << m_OpenIdClientId; - ZEN_CONSOLE("Adding openid auth provider:{}", SB.ToString()); + SB << "\n ClientId: " << HideSensitiveString(m_OpenIdClientId); + 
ZEN_CONSOLE("Auth: Adding Open ID auth provider:{}", SB.ToString()); } Auth->AddOpenIdProvider({.Name = ProviderName, .Url = m_OpenIdProviderUrl, .ClientId = m_OpenIdClientId}); if (!m_OpenIdRefreshToken.empty()) { - ZEN_CONSOLE("Adding open id refresh token {} to provider {}", m_OpenIdRefreshToken, ProviderName); + if (!Quiet) + { + ZEN_CONSOLE("Auth: Adding open id refresh token {} to provider {}", + HideSensitiveString(m_OpenIdRefreshToken), + ProviderName); + } Auth->AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = m_OpenIdRefreshToken}); } } @@ -225,9 +243,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, if (!m_AccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Adding static auth token: {}", m_AccessToken); + ZEN_CONSOLE("Auth: Using static auth token: {}", HideSensitiveString(m_AccessToken)); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(m_AccessToken); } @@ -237,9 +255,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, std::string ResolvedAccessToken = ReadAccessTokenFromJsonFile(m_AccessTokenPath); if (!ResolvedAccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Adding static auth token from {}: {}", m_AccessTokenPath, ResolvedAccessToken); + ZEN_CONSOLE("Auth: Adding static auth token from {}: {}", m_AccessTokenPath, HideSensitiveString(ResolvedAccessToken)); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken); } @@ -250,9 +268,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, { ExtendableStringBuilder<128> SB; SB << "\n Url: " << m_OAuthUrl; - SB << "\n ClientId: " << m_OAuthClientId; - SB << "\n ClientSecret: " << m_OAuthClientSecret; - ZEN_CONSOLE("Adding oauth provider:{}", SB.ToString()); + SB << "\n ClientId: " << HideSensitiveString(m_OAuthClientId); + SB << "\n ClientSecret: " << HideSensitiveString(m_OAuthClientSecret); + ZEN_CONSOLE("Auth: Adding oauth provider:{}", 
SB.ToString()); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOAuthClientCredentials( {.Url = m_OAuthUrl, .ClientId = m_OAuthClientId, .ClientSecret = m_OAuthClientSecret}); @@ -260,25 +278,27 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, else if (!m_OpenIdProviderName.empty()) { CreateAuthMgr(); - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using openid provider: {}", m_OpenIdProviderName); + ZEN_CONSOLE("Auth: Using OpenId provider: {}", m_OpenIdProviderName); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOpenIdProvider(*Auth, m_OpenIdProviderName); } else if (std::string ResolvedAccessToken = GetEnvAccessToken(m_AccessTokenEnv); !ResolvedAccessToken.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using environment variable '{}' as access token '{}'", m_AccessTokenEnv, ResolvedAccessToken); + ZEN_CONSOLE("Auth: Resolved environment variable '{}' to access token '{}'", + m_AccessTokenEnv, + HideSensitiveString(ResolvedAccessToken)); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken); } - else if (std::filesystem::path OidcTokenExePath = FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); !OidcTokenExePath.empty()) + else if (std::filesystem::path OidcTokenExePath = FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); OidcTokenExePath.empty()) { - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Running oidctoken exe from path '{}'", m_OidcTokenAuthExecutablePath); + ZEN_CONSOLE("Auth: Using oidctoken exe from path '{}'", OidcTokenExePath); } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromOidcTokenExecutable(OidcTokenExePath, HostUrl, Quiet, m_OidcTokenUnattended, Hidden); @@ -291,9 +311,9 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, if (!ClientSettings.AccessTokenProvider) { CreateAuthMgr(); - if (Verbose) + if (!Quiet) { - ZEN_CONSOLE("Using default openid provider"); + ZEN_CONSOLE("Auth: Using default Open ID provider"); } 
ClientSettings.AccessTokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(*Auth); } diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index ffdc5fe48..0722e9714 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -2808,8 +2808,6 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) .Verbose = m_VerboseHttp, .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}; - std::unique_ptr Auth; - std::string StorageDescription; std::string CacheDescription; diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index 5a12ba5d2..250eb9f56 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -1265,6 +1265,8 @@ private: uint64_t LoMask, HiMask; }; +std::string HideSensitiveString(std::string_view String); + ////////////////////////////////////////////////////////////////////////// void string_forcelink(); // internal diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index 27635a86c..3d0451e27 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -539,10 +539,33 @@ UrlDecode(std::string_view InUrl) return std::string(Url.ToView()); } -////////////////////////////////////////////////////////////////////////// -// -// Unit tests -// +std::string +HideSensitiveString(std::string_view String) +{ + const size_t Length = String.length(); + const size_t SourceLength = Length > 16 ? 
4 : 0; + const size_t PadLength = Min(Length - SourceLength, 4u); + const bool AddEllipsis = (SourceLength + PadLength) < Length; + StringBuilder<16> SB; + if (SourceLength > 0) + { + SB << String.substr(0, SourceLength); + } + if (PadLength > 0) + { + SB << std::string(PadLength, 'X'); + } + if (AddEllipsis) + { + SB << "..."; + } + return SB.ToString(); +}; + + ////////////////////////////////////////////////////////////////////////// + // + // Unit tests + // #if ZEN_WITH_TESTS @@ -814,11 +837,6 @@ TEST_CASE("niceNum") } } -void -string_forcelink() -{ -} - TEST_CASE("StringBuilder") { StringBuilder<64> sb; @@ -1224,8 +1242,28 @@ TEST_CASE("string") } } +TEST_CASE("hidesensitivestring") +{ + using namespace std::literals; + + CHECK_EQ(HideSensitiveString(""sv), ""sv); + CHECK_EQ(HideSensitiveString("A"sv), "X"sv); + CHECK_EQ(HideSensitiveString("ABCD"sv), "XXXX"sv); + CHECK_EQ(HideSensitiveString("ABCDE"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGH"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOP"sv), "XXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQ"sv), "ABCDXXXX..."sv); + CHECK_EQ(HideSensitiveString("ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"sv), "ABCDXXXX..."sv); + CHECK_EQ(HideSensitiveString("1234567890123456789"sv), "1234XXXX..."sv); +} + TEST_SUITE_END(); +void +string_forcelink() +{ +} + #endif } // namespace zen -- cgit v1.2.3 From 1f83b48a20bf90f41e18867620c5774f3be6280d Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 4 Mar 2026 17:23:36 +0100 Subject: Fixing various compiler issues (#807) Compile fixes for various versions of gcc,clang (non-UE) --- src/zencore/filesystem.cpp | 11 ++++++----- src/zencore/include/zencore/hashutils.h | 1 + src/zencore/xmake.lua | 7 ++++++- src/zenserver/xmake.lua | 6 ++++++ 4 files changed, 19 insertions(+), 6 deletions(-) (limited to 'src') diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 9885b2ada..8ed63565c 100644 --- a/src/zencore/filesystem.cpp 
+++ b/src/zencore/filesystem.cpp @@ -194,7 +194,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) FindClose(hFind); } - return true; + return Success; } bool @@ -1022,7 +1022,7 @@ TryCloneFile(const std::filesystem::path& FromPath, const std::filesystem::path& return false; } fchmod(ToFd, 0666); - ScopedFd $To = { FromFd }; + ScopedFd $To = { ToFd }; ioctl(ToFd, FICLONE, FromFd); @@ -1112,7 +1112,8 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP size_t FileSizeBytes = Stat.st_size; - fchown(ToFd, Stat.st_uid, Stat.st_gid); + int $Ignore = fchown(ToFd, Stat.st_uid, Stat.st_gid); + ZEN_UNUSED($Ignore); // What's the appropriate error handling here? // Copy impl const size_t BufferSize = Min(FileSizeBytes, 64u << 10); @@ -1398,7 +1399,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer const uint64_t ChunkSize = Min(WriteSize, uint64_t(2) * 1024 * 1024 * 1024); #if ZEN_PLATFORM_WINDOWS - hRes = Outfile.Write(DataPtr, gsl::narrow_cast(WriteSize)); + hRes = Outfile.Write(DataPtr, gsl::narrow_cast(ChunkSize)); if (FAILED(hRes)) { Outfile.Close(); @@ -1407,7 +1408,7 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str()); } #else - if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize)) + if (write(Fd, DataPtr, ChunkSize) != int64_t(ChunkSize)) { close(Fd); std::error_code DummyEc; diff --git a/src/zencore/include/zencore/hashutils.h b/src/zencore/include/zencore/hashutils.h index 6b9902b3a..8abfd4b6e 100644 --- a/src/zencore/include/zencore/hashutils.h +++ b/src/zencore/include/zencore/hashutils.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include namespace zen { diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index 2f81b7ec8..ab842f6ed 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -14,7 +14,12 @@ target('zencore') end) 
set_configdir("include/zencore") add_files("**.cpp") - add_files("trace.cpp", {unity_ignored = true }) + if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then + -- GCC false positives in thirdparty trace.h (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) + add_files("trace.cpp", {unity_ignored = true, force = {cxxflags = {"-Wno-stringop-overread", "-Wno-dangling-pointer"}} }) + else + add_files("trace.cpp", {unity_ignored = true }) + end add_files("testing.cpp", {unity_ignored = true }) if has_config("zenrpmalloc") then diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 915b6a3b1..7a9031782 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -19,6 +19,12 @@ target("zenserver") add_files("**.cpp") add_files("frontend/*.zip") add_files("zenserver.cpp", {unity_ignored = true }) + + if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then + -- GCC false positives in deeply inlined code (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) + add_files("storage/projectstore/httpprojectstore.cpp", {force = {cxxflags = "-Wno-stringop-overflow"} }) + add_files("storage/storageconfig.cpp", {force = {cxxflags = "-Wno-array-bounds"} }) + end add_includedirs(".") set_symbols("debug") -- cgit v1.2.3 From d8940b27c8a5c070c3b48ca9e575929df8d1d888 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Thu, 5 Mar 2026 00:08:19 +0100 Subject: added TEST_SUITE_BEGIN/END around some TEST_CASEs which didn't have them (#809) * added TEST_SUITE_BEGIN/END around some TEST_CASEs which didn't have them * fixed some stats issues * ScopedSpan should Initialize * annotated classes in stats.h with some documentation comments --- src/zencompute/cloudmetadata.cpp | 4 + src/zencompute/runners/deferreddeleter.cpp | 4 + src/zencore/xxhash.cpp | 4 + .../projectstore/remoteprojectstore.cpp | 4 + src/zenserver-test/logging-tests.cpp | 4 + src/zenserver-test/nomad-tests.cpp | 4 + 
src/zentelemetry/include/zentelemetry/otlptrace.h | 9 +- src/zentelemetry/include/zentelemetry/stats.h | 202 +++++++++++++++------ src/zentelemetry/stats.cpp | 2 +- 9 files changed, 180 insertions(+), 57 deletions(-) (limited to 'src') diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp index b3b3210d9..65bac895f 100644 --- a/src/zencompute/cloudmetadata.cpp +++ b/src/zencompute/cloudmetadata.cpp @@ -622,6 +622,8 @@ CloudMetadata::PollGCPTermination() namespace zen::compute { +TEST_SUITE_BEGIN("compute.cloudmetadata"); + // --------------------------------------------------------------------------- // Test helper — spins up a local ASIO HTTP server hosting a MockImdsService // --------------------------------------------------------------------------- @@ -1000,6 +1002,8 @@ TEST_CASE("cloudmetadata.sentinel_files") } } +TEST_SUITE_END(); + void cloudmetadata_forcelink() { diff --git a/src/zencompute/runners/deferreddeleter.cpp b/src/zencompute/runners/deferreddeleter.cpp index 00977d9fa..4fad2cf70 100644 --- a/src/zencompute/runners/deferreddeleter.cpp +++ b/src/zencompute/runners/deferreddeleter.cpp @@ -231,6 +231,8 @@ deferreddeleter_forcelink() namespace zen::compute { +TEST_SUITE_BEGIN("compute.deferreddeleter"); + TEST_CASE("DeferredDirectoryDeleter.DeletesSingleDirectory") { ScopedTemporaryDirectory TempDir; @@ -331,6 +333,8 @@ TEST_CASE("DeferredDirectoryDeleter.MarkReadyShortensDeferral") CHECK(!std::filesystem::exists(Dir)); } +TEST_SUITE_END(); + } // namespace zen::compute #endif // ZEN_WITH_TESTS && ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp index 6d1050531..88a48dd68 100644 --- a/src/zencore/xxhash.cpp +++ b/src/zencore/xxhash.cpp @@ -59,6 +59,8 @@ xxhash_forcelink() { } +TEST_SUITE_BEGIN("core.xxhash"); + TEST_CASE("XXH3_128") { using namespace std::literals; @@ -96,6 +98,8 @@ TEST_CASE("XXH3_128") } } +TEST_SUITE_END(); + #endif } // namespace zen diff --git 
a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 570025b6d..78f6014df 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -4240,6 +4240,8 @@ namespace projectstore_testutils { } // namespace projectstore_testutils +TEST_SUITE_BEGIN("remotestore.projectstore"); + struct ExportForceDisableBlocksTrue_ForceTempBlocksFalse { static const bool ForceDisableBlocks = true; @@ -4395,6 +4397,8 @@ TEST_CASE_TEMPLATE("project.store.export", CHECK(ImportForceCleanResult.ErrorCode == 0); } +TEST_SUITE_END(); + #endif // ZEN_WITH_TESTS void diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp index fe39e14c0..f284f0371 100644 --- a/src/zenserver-test/logging-tests.cpp +++ b/src/zenserver-test/logging-tests.cpp @@ -15,6 +15,8 @@ namespace zen::tests { using namespace std::literals; +TEST_SUITE_BEGIN("server.logging"); + ////////////////////////////////////////////////////////////////////////// static bool @@ -252,6 +254,8 @@ TEST_CASE("logging.level.off_specific_logger") CHECK_MESSAGE(!LogContains(HttpLog, "server session id"), HttpLog); } +TEST_SUITE_END(); + } // namespace zen::tests #endif diff --git a/src/zenserver-test/nomad-tests.cpp b/src/zenserver-test/nomad-tests.cpp index 6eb99bc3a..f8f5a9a30 100644 --- a/src/zenserver-test/nomad-tests.cpp +++ b/src/zenserver-test/nomad-tests.cpp @@ -17,6 +17,8 @@ namespace zen::tests::nomad_tests { using namespace std::literals; +TEST_SUITE_BEGIN("server.nomad"); + TEST_CASE("nomad.client.lifecycle" * doctest::skip()) { zen::nomad::NomadProcess NomadProc; @@ -122,5 +124,7 @@ TEST_CASE("nomad.provisioner.integration" * doctest::skip()) NomadProc.StopNomadAgent(); } +TEST_SUITE_END(); + } // namespace zen::tests::nomad_tests #endif diff --git a/src/zentelemetry/include/zentelemetry/otlptrace.h b/src/zentelemetry/include/zentelemetry/otlptrace.h index 
49dd90358..95718af55 100644 --- a/src/zentelemetry/include/zentelemetry/otlptrace.h +++ b/src/zentelemetry/include/zentelemetry/otlptrace.h @@ -317,6 +317,7 @@ public: ExtendableStringBuilder<128> NameBuilder; NamingFunction(NameBuilder); + Initialize(NameBuilder); } /** Construct a new span with a naming function AND initializer function @@ -350,7 +351,13 @@ public: // Execute a function with the span pointer if valid. This can // be used to add attributes or events to the span after creation - inline void WithSpan(auto Func) const { Func(*m_Span); } + inline void WithSpan(auto Func) const + { + if (m_Span) + { + Func(*m_Span); + } + } private: void Initialize(std::string_view Name); diff --git a/src/zentelemetry/include/zentelemetry/stats.h b/src/zentelemetry/include/zentelemetry/stats.h index 3e67bac1c..d58846a3b 100644 --- a/src/zentelemetry/include/zentelemetry/stats.h +++ b/src/zentelemetry/include/zentelemetry/stats.h @@ -16,6 +16,11 @@ class CbObjectWriter; namespace zen::metrics { +/** A single atomic value that can be set and read at any time. + * + * Useful for point-in-time readings such as queue depth, active connection count, + * or any value where only the current state matters rather than history. + */ template class Gauge { @@ -29,12 +34,12 @@ private: std::atomic m_Value; }; -/** Stats counter +/** Monotonically increasing (or decreasing) counter. * - * A counter is modified by adding or subtracting a value from a current value. - * This would typically be used to track number of requests in flight, number - * of active jobs etc + * Suitable for tracking quantities that go up and down over time, such as + * requests in flight or active jobs. All operations are lock-free via atomics. * + * Unlike a Meter, a Counter does not track rates — it only records a running total. */ class Counter { @@ -50,34 +55,56 @@ private: std::atomic m_count{0}; }; -/** Exponential Weighted Moving Average - - This is very raw, to use as little state as possible. 
If we - want to use this more broadly in user code we should perhaps - add a more user-friendly wrapper +/** Low-level exponential weighted moving average. + * + * Tracks a smoothed rate using the standard EWMA recurrence: + * + * rate = rate + alpha * (instantRate - rate) + * + * where instantRate = Count / Interval. The alpha value controls how quickly + * the average responds to changes — higher alpha means more weight on recent + * samples. Typical alphas are derived from a decay half-life (e.g. 1, 5, 15 + * minutes) and a fixed tick interval. + * + * This class is intentionally minimal to keep per-instance state to a single + * atomic double. See Meter for a more convenient wrapper. */ - class RawEWMA { public: - /// - /// Update EWMA with new measure - /// - /// Smoothing factor (between 0 and 1) - /// Elapsed time since last - /// Value - /// Whether this is the first update or not - void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate); + /** Update the EWMA with a new observation. + * + * @param Alpha Smoothing factor in (0, 1). Smaller values give a + * slower-moving average; larger values track recent + * changes more aggressively. + * @param Interval Elapsed hi-freq timer ticks since the last Tick call. + * Used to compute the instantaneous rate as Count/Interval. + * @param Count Number of events observed during this interval. + * @param IsInitialUpdate True on the very first call: seeds the rate directly + * from the instantaneous rate rather than blending it in. + */ + void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate); + + /** Returns the current smoothed rate in events per second. */ double Rate() const; private: std::atomic m_Rate = 0; }; -/// -/// Tracks rate of events over time (i.e requests/sec), using -/// exponential moving averages -/// +/** Tracks the rate of events over time using exponential moving averages. 
+ * + * Maintains three EWMA windows (1, 5, 15 minutes) in addition to a simple + * mean rate computed from the total count and elapsed wall time since + * construction. This mirrors the load-average conventions familiar from Unix. + * + * Rate updates are batched: Mark() accumulates a pending count and the EWMA + * is only advanced every ~5 seconds (controlled by kTickIntervalInSeconds), + * keeping contention low even under heavy call rates. Rates are returned in + * events per second. + * + * All operations are thread-safe via lock-free atomics. + */ class Meter { public: @@ -85,18 +112,18 @@ public: ~Meter(); inline uint64_t Count() const { return m_TotalCount; } - double Rate1(); // One-minute rate - double Rate5(); // Five-minute rate - double Rate15(); // Fifteen-minute rate - double MeanRate() const; // Mean rate since instantiation of this meter + double Rate1(); // One-minute EWMA rate (events/sec) + double Rate5(); // Five-minute EWMA rate (events/sec) + double Rate15(); // Fifteen-minute EWMA rate (events/sec) + double MeanRate() const; // Mean rate since instantiation (events/sec) void Mark(uint64_t Count = 1); // Register one or more events private: std::atomic m_TotalCount{0}; // Accumulator counting number of marks since beginning - std::atomic m_PendingCount{0}; // Pending EWMA update accumulator - std::atomic m_StartTick{0}; // Time this was instantiated (for mean) - std::atomic m_LastTick{0}; // Timestamp of last EWMA tick - std::atomic m_Remainder{0}; // Tracks the "modulo" of tick time + std::atomic m_PendingCount{0}; // Pending EWMA update accumulator; drained on each tick + std::atomic m_StartTick{0}; // Hi-freq timer value at construction (for MeanRate) + std::atomic m_LastTick{0}; // Hi-freq timer value of the last EWMA tick + std::atomic m_Remainder{0}; // Accumulated ticks not yet consumed by EWMA updates bool m_IsFirstTick = true; RawEWMA m_RateM1; RawEWMA m_RateM5; @@ -106,7 +133,14 @@ private: void Tick(); }; -/** Moment-in-time 
snapshot of a distribution +/** Immutable sorted snapshot of a reservoir sample. + * + * Constructed from a vector of sampled values which are sorted on construction. + * Percentiles are computed on demand via linear interpolation between adjacent + * sorted values, following the standard R-7 quantile method. + * + * Because this is a copy of the reservoir at a point in time, it can be held + * and queried without holding any locks on the source UniformSample. */ class SampleSnapshot { @@ -128,12 +162,19 @@ private: std::vector m_Values; }; -/** Randomly selects samples from a stream. Uses Vitter's - Algorithm R to produce a statistically representative sample. - - http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir +/** Reservoir sampler for probabilistic distribution tracking. + * + * Maintains a fixed-size reservoir of samples drawn uniformly from the full + * history of values using Vitter's Algorithm R. This gives an unbiased + * statistical representation of the value distribution regardless of how many + * total values have been observed, at the cost of O(ReservoirSize) memory. + * + * A larger reservoir improves accuracy of tail percentiles (P99, P999) but + * increases memory and snapshot cost. The default of 1028 gives good accuracy + * for most telemetry uses. + * + * http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir */ - class UniformSample { public: @@ -159,7 +200,14 @@ private: std::vector> m_Values; }; -/** Track (probabilistic) sample distribution along with min/max +/** Tracks the statistical distribution of a stream of values. + * + * Records exact min, max, count and mean across all values ever seen, plus a + * reservoir sample (via UniformSample) used to compute percentiles. Percentiles + * are therefore probabilistic — they reflect the distribution of a representative + * sample rather than the full history. + * + * All operations are thread-safe via lock-free atomics. 
*/ class Histogram { @@ -183,11 +231,28 @@ private: std::atomic m_Count{0}; }; -/** Track timing and frequency of some operation - - Example usage would be to track frequency and duration of network - requests, or function calls. - +/** Combines a Histogram and a Meter to track both the distribution and rate + * of a recurring operation. + * + * Duration values are stored in hi-freq timer ticks. Use GetHifreqTimerToSeconds() + * when converting for display. + * + * Typical usage via the RAII Scope helper: + * + * OperationTiming MyTiming; + * + * { + * OperationTiming::Scope Scope(MyTiming); + * DoWork(); + * // Scope destructor calls Stop() automatically + * } + * + * // Or cancel if the operation should not be counted: + * { + * OperationTiming::Scope Scope(MyTiming); + * if (CacheHit) { Scope.Cancel(); return; } + * DoExpensiveWork(); + * } */ class OperationTiming { @@ -207,13 +272,19 @@ public: double Rate15() { return m_Meter.Rate15(); } double MeanRate() const { return m_Meter.MeanRate(); } + /** RAII helper that records duration from construction to Stop() or destruction. + * + * Call Cancel() to discard the measurement (e.g. for cache hits that should + * not skew latency statistics). After Stop() or Cancel() the destructor is a + * no-op. + */ struct Scope { Scope(OperationTiming& Outer); ~Scope(); - void Stop(); - void Cancel(); + void Stop(); // Record elapsed time and mark the meter + void Cancel(); // Discard this measurement; destructor becomes a no-op private: OperationTiming& m_Outer; @@ -225,6 +296,7 @@ private: Histogram m_Histogram; }; +/** Immutable snapshot of a Meter's state at a point in time. */ struct MeterSnapshot { uint64_t Count; @@ -234,6 +306,12 @@ struct MeterSnapshot double Rate15; }; +/** Immutable snapshot of a Histogram's state at a point in time. + * + * Count and all statistical values have been scaled by the ConversionFactor + * supplied when the snapshot was taken (e.g. 
GetHifreqTimerToSeconds() to + * convert timer ticks to seconds). + */ struct HistogramSnapshot { double Count; @@ -246,24 +324,29 @@ struct HistogramSnapshot double P999; }; +/** Combined snapshot of a Meter and Histogram pair. */ struct StatsSnapshot { MeterSnapshot Meter; HistogramSnapshot Histogram; }; +/** Combined snapshot of request timing and byte transfer statistics. */ struct RequestStatsSnapshot { StatsSnapshot Requests; StatsSnapshot Bytes; }; -/** Metrics for network requests - - Aggregates tracking of duration, payload sizes into a single - class - - */ +/** Tracks both the timing and payload size of network requests. + * + * Maintains two independent histogram+meter pairs: one for request duration + * (in hi-freq timer ticks) and one for transferred bytes. Both dimensions + * share the same request count — a single Update() call advances both. + * + * Duration accessors return values in hi-freq timer ticks. Multiply by + * GetHifreqTimerToSeconds() to convert to seconds. + */ class RequestStats { public: @@ -275,9 +358,9 @@ public: // Timing - int64_t MaxDuration() const { return m_BytesHistogram.Max(); } - int64_t MinDuration() const { return m_BytesHistogram.Min(); } - double MeanDuration() const { return m_BytesHistogram.Mean(); } + int64_t MaxDuration() const { return m_RequestTimeHistogram.Max(); } + int64_t MinDuration() const { return m_RequestTimeHistogram.Min(); } + double MeanDuration() const { return m_RequestTimeHistogram.Mean(); } SampleSnapshot DurationSnapshot() const { return m_RequestTimeHistogram.Snapshot(); } double Rate1() { return m_RequestMeter.Rate1(); } double Rate5() { return m_RequestMeter.Rate5(); } @@ -295,14 +378,23 @@ public: double ByteRate15() { return m_BytesMeter.Rate15(); } double ByteMeanRate() const { return m_BytesMeter.MeanRate(); } + /** RAII helper that records duration and byte count from construction to Stop() + * or destruction. 
+ * + * The byte count can be supplied at construction or updated at any point via + * SetBytes() before the scope ends — useful when the response size is not + * known until the operation completes. + * + * Call Cancel() to discard the measurement entirely. + */ struct Scope { Scope(RequestStats& Outer, int64_t Bytes); ~Scope(); void SetBytes(int64_t Bytes) { m_Bytes = Bytes; } - void Stop(); - void Cancel(); + void Stop(); // Record elapsed time and byte count + void Cancel(); // Discard this measurement; destructor becomes a no-op private: RequestStats& m_Outer; diff --git a/src/zentelemetry/stats.cpp b/src/zentelemetry/stats.cpp index fcfcaf45e..a417bb52c 100644 --- a/src/zentelemetry/stats.cpp +++ b/src/zentelemetry/stats.cpp @@ -631,7 +631,7 @@ EmitSnapshot(const HistogramSnapshot& Snapshot, CbObjectWriter& Cbo) { Cbo << "t_count" << Snapshot.Count << "t_avg" << Snapshot.Avg; Cbo << "t_min" << Snapshot.Min << "t_max" << Snapshot.Max; - Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P999; + Cbo << "t_p75" << Snapshot.P75 << "t_p95" << Snapshot.P95 << "t_p99" << Snapshot.P99 << "t_p999" << Snapshot.P999; } void -- cgit v1.2.3 From 2f0d60cb431ffefecf3e0a383528691be74af21b Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Thu, 5 Mar 2026 14:31:27 +0100 Subject: oidctoken tool package (#810) * added OidcToken binary to the build process. The binary is mirrored from p4 and is placed next to the output of the build process. It is also placed in the release zip archives. 
* also fixed issue with Linux symbol stripping which was introduced in toolchain changes yesterday --- src/zenserver/xmake.lua | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'src') diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 7a9031782..f2ed17f05 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -33,6 +33,7 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("oidctoken") add_packages("nomad") if has_config("zenmimalloc") then @@ -161,6 +162,16 @@ target("zenserver") copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + local oidctoken_pkg = target:pkg("oidctoken") + if oidctoken_pkg then + local installdir = oidctoken_pkg:installdir() + local oidctoken_bin = "OidcToken" + if is_plat("windows") then + oidctoken_bin = "OidcToken.exe" + end + copy_if_newer(path.join(installdir, "bin", oidctoken_bin), path.join(target:targetdir(), oidctoken_bin), oidctoken_bin) + end + local nomad_pkg = target:pkg("nomad") if nomad_pkg then local installdir = nomad_pkg:installdir() -- cgit v1.2.3 From 2275a88da7d0dbcfbc70c6050b7a1417036ea98d Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 6 Mar 2026 07:45:02 +0100 Subject: fix oidctoken exe lookup check (#811) --- src/zen/authutils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/zen/authutils.cpp b/src/zen/authutils.cpp index 23ac70965..534f7952b 100644 --- a/src/zen/authutils.cpp +++ b/src/zen/authutils.cpp @@ -294,7 +294,7 @@ AuthCommandLineOptions::ParseOptions(cxxopts::Options& Ops, } ClientSettings.AccessTokenProvider = httpclientauth::CreateFromStaticToken(ResolvedAccessToken); } - else if (std::filesystem::path OidcTokenExePath = FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); OidcTokenExePath.empty()) + else if (std::filesystem::path OidcTokenExePath = 
FindOidcTokenExePath(m_OidcTokenAuthExecutablePath); !OidcTokenExePath.empty()) { if (!Quiet) { -- cgit v1.2.3 From 1e731796187ad73b2dee44b48fcecdd487616394 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 6 Mar 2026 10:11:51 +0100 Subject: Claude config, some bug fixes (#813) * Claude config updates * Bug fixes and hardening across `zencore` and `zenhttp`, identified via static analysis. ### zencore - **`ZEN_ASSERT` macro** -- extended to accept an optional string message literal; added `ZEN_ASSERT_MSG_` helper for message formatting. Callers needing runtime fmt-style formatting should use `ZEN_ASSERT_FORMAT`. - **`MpscQueue`** -- fixed `TypeCompatibleStorage` to use a properly-sized `char Storage[sizeof(T)]` array instead of a single `char`; corrected `Data()` to cast `&Storage` rather than `this`; switched cache-line alignment to a fixed constant to avoid GCC's `-Winterference-size` warning. Enabled previously-disabled tests. - **`StringBuilderImpl`** -- initialized `m_Base`/`m_CurPos`/`m_End` to `nullptr`. Fixed `StringCompare` return type (`bool` -> `int`). Fixed `ParseInt` to reject strings with trailing non-numeric characters. Removed deprecated `` include. - **`NiceNumGeneral`** -- replaced `powl()` with integer `IntPow()` to avoid floating-point precision issues. - **`RwLock::ExclusiveLockScope`** -- added move constructor/assignment; initialized `m_Lock` to `nullptr`. - **`Latch::AddCount`** -- fixed variable type (`std::atomic_ptrdiff_t` -> `std::ptrdiff_t` for the return value of `fetch_add`). - **`thread.cpp`** -- fixed Linux `pthread_setname_np` 16-byte name truncation; added null check before dereferencing in `Event::Close()`; fixed `NamedEvent::Close()` to call `close(Fd)` outside the lock region; added null guard in `NamedMutex` destructor; `Sleep()` now returns early for non-positive durations. - **`MD5Stream`** -- was entirely stubbed out (no-op); now correctly calls `MD5Init`/`MD5Update`/`MD5Final`. 
Fixed `ToHexString` to use the correct string length. Fixed forward declarations. Fixed tests to compare `compare() == 0`. - **`sentryintegration.cpp`** -- guard against null `filename`/`funcname` in spdlog message handler to prevent a crash in `fmt::format`. - **`jobqueue.cpp`** -- fixed lost job ID when `IdGenerator` wraps around zero; fixed raw `Job*` in `RunningJobs` map (potential use-after-free) to `RefPtr`; fixed range-loop copies; fixed format string typo. - **`trace.cpp`** -- suppress GCC false-positive warnings in third-party `trace.h` include. ### zenhttp - **WebSocket close race** (`wsasio`, `wshttpsys`, `httpwsclient`) -- `m_CloseSent` promoted from `bool` to `std::atomic`; close check changed to `exchange(true)` to eliminate the check-then-set data race. - **`wsframecodec.cpp`** -- reject WebSocket frames with payload > 256 MB to prevent OOM from malformed/malicious frames. - **`oidc.cpp`** -- URL-encode refresh token and client ID in token requests (`FormUrlEncode`); parse `end_session_endpoint` and `device_authorization_endpoint` from OIDC discovery document. - **`httpclientcommon.cpp`** -- propagate error code from `AppendData` when flushing the cache buffer. - **`httpclient.h`** -- initialize all uninitialized members (`ErrorCode`, `UploadedBytes`, `DownloadedBytes`, `ElapsedSeconds`, `MultipartBoundary` fields). - **`httpserver.h`** -- fix `operator=` return type for `HttpRpcHandler` (missing `&`). - **`packageformat.h`** -- fix `~0u` (32-bit truncation) to `~uint64_t(0)` for a `uint64_t` field. - **`httpparser`** -- initialize `m_RequestVerb` in both declaration and `ResetState()`. - **`httpplugin.cpp`** -- initialize `m_BasePort`; fix format string missing quotes around connection name. - **`httptracer.h`** -- move `#pragma once` before includes. - **`websocket.h`** -- initialize `WebSocketMessage::Opcode`. 
### zenserver - **`hubservice.cpp`** -- fix two `ZEN_ASSERT` calls that incorrectly used fmt-style format args; converted to `ZEN_ASSERT_FORMAT`. --- src/zencore/blake3.cpp | 2 +- src/zencore/commandline.cpp | 1 + src/zencore/include/zencore/md5.h | 2 ++ src/zencore/include/zencore/mpscqueue.h | 20 +++++++------- src/zencore/include/zencore/string.h | 22 +++++++-------- src/zencore/include/zencore/thread.h | 16 ++++++++--- src/zencore/include/zencore/xxhash.h | 2 +- src/zencore/include/zencore/zencore.h | 34 ++++++++++++++--------- src/zencore/jobqueue.cpp | 20 ++++++-------- src/zencore/logging.cpp | 2 +- src/zencore/md5.cpp | 24 ++++++++++------- src/zencore/memtrack/tagtrace.cpp | 2 +- src/zencore/mpscqueue.cpp | 4 +-- src/zencore/sentryintegration.cpp | 9 +++++-- src/zencore/string.cpp | 15 +++++++++-- src/zencore/thread.cpp | 42 +++++++++++++++++++---------- src/zencore/trace.cpp | 9 +++++++ src/zencore/xmake.lua | 7 +---- src/zenhttp/auth/oidc.cpp | 24 ++++++++++++++++- src/zenhttp/clients/httpclientcommon.cpp | 5 +++- src/zenhttp/clients/httpwsclient.cpp | 8 +++--- src/zenhttp/include/zenhttp/httpclient.h | 14 +++++----- src/zenhttp/include/zenhttp/httpserver.h | 2 +- src/zenhttp/include/zenhttp/packageformat.h | 2 +- src/zenhttp/include/zenhttp/websocket.h | 2 +- src/zenhttp/servers/httpparser.cpp | 8 +----- src/zenhttp/servers/httpparser.h | 2 +- src/zenhttp/servers/httpplugin.cpp | 4 +-- src/zenhttp/servers/httptracer.h | 4 +-- src/zenhttp/servers/wsasio.cpp | 6 ++--- src/zenhttp/servers/wsasio.h | 2 +- src/zenhttp/servers/wsframecodec.cpp | 7 +++++ src/zenhttp/servers/wshttpsys.cpp | 6 ++--- src/zenhttp/servers/wshttpsys.h | 2 +- src/zenserver/hub/hubservice.cpp | 5 ++-- 35 files changed, 208 insertions(+), 128 deletions(-) (limited to 'src') diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp index 123918de5..55f9b74af 100644 --- a/src/zencore/blake3.cpp +++ b/src/zencore/blake3.cpp @@ -123,7 +123,7 @@ BLAKE3::ToHexString(StringBuilderBase& 
outBuilder) const char str[65]; ToHexString(str); - outBuilder.AppendRange(str, &str[65]); + outBuilder.AppendRange(str, &str[StringLength]); return outBuilder; } diff --git a/src/zencore/commandline.cpp b/src/zencore/commandline.cpp index 426cf23d6..718ef9678 100644 --- a/src/zencore/commandline.cpp +++ b/src/zencore/commandline.cpp @@ -14,6 +14,7 @@ ZEN_THIRD_PARTY_INCLUDES_END # include #endif +#include #include namespace zen { diff --git a/src/zencore/include/zencore/md5.h b/src/zencore/include/zencore/md5.h index d934dd86b..3b0b7cae6 100644 --- a/src/zencore/include/zencore/md5.h +++ b/src/zencore/include/zencore/md5.h @@ -43,6 +43,8 @@ public: MD5 GetHash(); private: + // Opaque storage for MD5_CTX (104 bytes, aligned to uint32_t) + alignas(4) uint8_t m_Context[104]; }; void md5_forcelink(); // internal diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h index 19e410d85..d97c433fd 100644 --- a/src/zencore/include/zencore/mpscqueue.h +++ b/src/zencore/include/zencore/mpscqueue.h @@ -22,10 +22,10 @@ namespace zen { template struct TypeCompatibleStorage { - ElementType* Data() { return (ElementType*)this; } - const ElementType* Data() const { return (const ElementType*)this; } + ElementType* Data() { return reinterpret_cast(&Storage); } + const ElementType* Data() const { return reinterpret_cast(&Storage); } - alignas(ElementType) char DataMember; + alignas(ElementType) char Storage[sizeof(ElementType)]; }; /** Fast multi-producer/single-consumer unbounded concurrent queue. @@ -58,7 +58,7 @@ public: Tail = Next; Next = Tail->Next.load(std::memory_order_relaxed); - std::destroy_at((ElementType*)&Tail->Value); + std::destroy_at(Tail->Value.Data()); delete Tail; } } @@ -67,7 +67,7 @@ public: void Enqueue(ArgTypes&&... 
Args) { Node* New = new Node; - new (&New->Value) ElementType(std::forward(Args)...); + new (New->Value.Data()) ElementType(std::forward(Args)...); Node* Prev = Head.exchange(New, std::memory_order_acq_rel); Prev->Next.store(New, std::memory_order_release); @@ -82,7 +82,7 @@ public: return {}; } - ElementType* ValuePtr = (ElementType*)&Next->Value; + ElementType* ValuePtr = Next->Value.Data(); std::optional Res{std::move(*ValuePtr)}; std::destroy_at(ValuePtr); @@ -100,9 +100,11 @@ private: }; private: - std::atomic Head; // accessed only by producers - alignas(hardware_constructive_interference_size) - Node* Tail; // accessed only by consumer, hence should be on a different cache line than `Head` + // Use a fixed constant to avoid GCC's -Winterference-size warning with std::hardware_destructive_interference_size + static constexpr std::size_t CacheLineSize = 64; + + alignas(CacheLineSize) std::atomic Head; // accessed only by producers + alignas(CacheLineSize) Node* Tail; // accessed only by consumer, separate cache line from Head }; void mpscqueue_forcelink(); diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index 250eb9f56..4deca63ed 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -8,7 +8,6 @@ #include #include #include -#include #include #include #include @@ -51,7 +50,7 @@ StringLength(const wchar_t* str) return wcslen(str); } -inline bool +inline int StringCompare(const char16_t* s1, const char16_t* s2) { char16_t c1, c2; @@ -66,7 +65,7 @@ StringCompare(const char16_t* s1, const char16_t* s2) ++s1; ++s2; } - return uint16_t(c1) - uint16_t(c2); + return int(uint16_t(c1)) - int(uint16_t(c2)); } inline bool @@ -122,10 +121,10 @@ public: StringBuilderImpl() = default; ~StringBuilderImpl(); - StringBuilderImpl(const StringBuilderImpl&) = delete; - StringBuilderImpl(const StringBuilderImpl&&) = delete; + StringBuilderImpl(const StringBuilderImpl&) = delete; + 
StringBuilderImpl(StringBuilderImpl&&) = delete; const StringBuilderImpl& operator=(const StringBuilderImpl&) = delete; - const StringBuilderImpl& operator=(const StringBuilderImpl&&) = delete; + StringBuilderImpl& operator=(StringBuilderImpl&&) = delete; inline size_t AddUninitialized(size_t Count) { @@ -374,9 +373,9 @@ protected: [[noreturn]] void Fail(const char* FailReason); // note: throws exception - C* m_Base; - C* m_CurPos; - C* m_End; + C* m_Base = nullptr; + C* m_CurPos = nullptr; + C* m_End = nullptr; bool m_IsDynamic = false; bool m_IsExtendable = false; }; @@ -773,8 +772,9 @@ std::optional ParseInt(const std::string_view& Input) { T Out = 0; - const std::from_chars_result Result = std::from_chars(Input.data(), Input.data() + Input.size(), Out); - if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range) + const char* End = Input.data() + Input.size(); + const std::from_chars_result Result = std::from_chars(Input.data(), End, Out); + if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range || Result.ptr != End) { return std::nullopt; } diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h index a1c68b0b2..d0d710ee8 100644 --- a/src/zencore/include/zencore/thread.h +++ b/src/zencore/include/zencore/thread.h @@ -58,7 +58,7 @@ public: } private: - RwLock* m_Lock; + RwLock* m_Lock = nullptr; }; inline auto WithSharedLock(auto&& Fun) @@ -69,6 +69,16 @@ public: struct ExclusiveLockScope { + ExclusiveLockScope(const ExclusiveLockScope& Rhs) = delete; + ExclusiveLockScope(ExclusiveLockScope&& Rhs) : m_Lock(Rhs.m_Lock) { Rhs.m_Lock = nullptr; } + ExclusiveLockScope& operator=(ExclusiveLockScope&& Rhs) + { + ReleaseNow(); + m_Lock = Rhs.m_Lock; + Rhs.m_Lock = nullptr; + return *this; + } + ExclusiveLockScope& operator=(const ExclusiveLockScope& Rhs) = delete; ExclusiveLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireExclusive(); } ~ExclusiveLockScope() { 
ReleaseNow(); } @@ -82,7 +92,7 @@ public: } private: - RwLock* m_Lock; + RwLock* m_Lock = nullptr; }; inline auto WithExclusiveLock(auto&& Fun) @@ -195,7 +205,7 @@ public: // false positive completion results. void AddCount(std::ptrdiff_t Count) { - std::atomic_ptrdiff_t Old = Counter.fetch_add(Count); + std::ptrdiff_t Old = Counter.fetch_add(Count); ZEN_ASSERT(Old > 0); } diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h index fc55b513b..f79d39b61 100644 --- a/src/zencore/include/zencore/xxhash.h +++ b/src/zencore/include/zencore/xxhash.h @@ -87,7 +87,7 @@ struct XXH3_128Stream } private: - XXH3_state_s m_State; + XXH3_state_s m_State{}; }; struct XXH3_128Stream_deprecated diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h index 177a19fff..a31950b0b 100644 --- a/src/zencore/include/zencore/zencore.h +++ b/src/zencore/include/zencore/zencore.h @@ -70,26 +70,36 @@ protected: } // namespace zen -#define ZEN_ASSERT(x, ...) \ - do \ - { \ - if (x) [[unlikely]] \ - break; \ - zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ +#define ZEN_ASSERT(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \ } while (false) #ifndef NDEBUG -# define ZEN_ASSERT_SLOW(x, ...) \ - do \ - { \ - if (x) [[unlikely]] \ - break; \ - zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ +# define ZEN_ASSERT_SLOW(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, ZEN_ASSERT_MSG_(#x, ##__VA_ARGS__)); \ } while (false) #else # define ZEN_ASSERT_SLOW(x, ...) #endif +// Internal: select between "expr" and "expr: message" forms. 
+// With no extra args: ZEN_ASSERT_MSG_("expr") -> "expr" +// With a message arg: ZEN_ASSERT_MSG_("expr", "msg") -> "expr" ": " "msg" +// With fmt-style args: ZEN_ASSERT_MSG_("expr", "msg", args...) -> "expr" ": " "msg" +// The extra fmt args are silently discarded here — use ZEN_ASSERT_FORMAT for those. +#define ZEN_ASSERT_MSG_SELECT_(_1, _2, N, ...) N +#define ZEN_ASSERT_MSG_1_(expr) expr +#define ZEN_ASSERT_MSG_2_(expr, msg, ...) expr ": " msg +#define ZEN_ASSERT_MSG_(expr, ...) ZEN_ASSERT_MSG_SELECT_(unused, ##__VA_ARGS__, ZEN_ASSERT_MSG_2_, ZEN_ASSERT_MSG_1_)(expr, ##__VA_ARGS__) + ////////////////////////////////////////////////////////////////////////// #define ZEN_NOT_IMPLEMENTED(...) ZEN_ASSERT(false, __VA_ARGS__) diff --git a/src/zencore/jobqueue.cpp b/src/zencore/jobqueue.cpp index 35724b07a..d6a8a6479 100644 --- a/src/zencore/jobqueue.cpp +++ b/src/zencore/jobqueue.cpp @@ -90,7 +90,7 @@ public: uint64_t NewJobId = IdGenerator.fetch_add(1); if (NewJobId == 0) { - IdGenerator.fetch_add(1); + NewJobId = IdGenerator.fetch_add(1); } RefPtr NewJob(new Job()); NewJob->Queue = this; @@ -129,7 +129,7 @@ public: QueuedJobs.erase(It); } }); - ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. Reason: ''", NewJob->Id.Id, NewJob->Name, Ex.what()); + ZEN_ERROR("Failed to schedule job {}:'{}' to job queue. 
Reason: '{}'", NewJob->Id.Id, NewJob->Name, Ex.what()); throw; } } @@ -221,11 +221,11 @@ public: std::vector Jobs; QueueLock.WithSharedLock([&]() { - for (auto It : RunningJobs) + for (const auto& It : RunningJobs) { Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Running}); } - for (auto It : CompletedJobs) + for (const auto& It : CompletedJobs) { if (IsStale(It.second->EndTick)) { @@ -234,7 +234,7 @@ public: } Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Completed}); } - for (auto It : AbortedJobs) + for (const auto& It : AbortedJobs) { if (IsStale(It.second->EndTick)) { @@ -243,7 +243,7 @@ public: } Jobs.push_back({.Id = JobId{It.first}, .Status = JobStatus::Aborted}); } - for (auto It : QueuedJobs) + for (const auto& It : QueuedJobs) { Jobs.push_back({.Id = It->Id, .Status = JobStatus::Queued}); } @@ -337,7 +337,7 @@ public: std::atomic_bool InitializedFlag = false; RwLock QueueLock; std::deque> QueuedJobs; - std::unordered_map RunningJobs; + std::unordered_map> RunningJobs; std::unordered_map> CompletedJobs; std::unordered_map> AbortedJobs; @@ -429,20 +429,16 @@ JobQueue::ToString(JobStatus Status) { case JobQueue::JobStatus::Queued: return "Queued"sv; - break; case JobQueue::JobStatus::Running: return "Running"sv; - break; case JobQueue::JobStatus::Aborted: return "Aborted"sv; - break; case JobQueue::JobStatus::Completed: return "Completed"sv; - break; default: ZEN_ASSERT(false); + return ""sv; } - return ""sv; } std::unique_ptr diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index e960a2729..ebd68de09 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -303,7 +303,7 @@ GetLogLevel() LoggerRef Default() { - ZEN_ASSERT(TheDefaultLogger); + ZEN_ASSERT(TheDefaultLogger, "logging::InitializeLogging() must be called before using the logger"); return TheDefaultLogger; } diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp index 83ed53fc8..f8cfee3ac 100644 --- a/src/zencore/md5.cpp +++ 
b/src/zencore/md5.cpp @@ -56,9 +56,9 @@ struct MD5_CTX unsigned char digest[16]; /* actual digest after MD5Final call */ }; -void MD5Init(); -void MD5Update(); -void MD5Final(); +void MD5Init(MD5_CTX* mdContext); +void MD5Update(MD5_CTX* mdContext, unsigned char* inBuf, unsigned int inLen); +void MD5Final(MD5_CTX* mdContext); /* ********************************************************************** @@ -370,28 +370,32 @@ MD5 MD5::Zero; // Initialized to all zeroes MD5Stream::MD5Stream() { + static_assert(sizeof(MD5_CTX) <= sizeof(m_Context)); Reset(); } void MD5Stream::Reset() { + MD5Init(reinterpret_cast(m_Context)); } MD5Stream& MD5Stream::Append(const void* Data, size_t ByteCount) { - ZEN_UNUSED(Data); - ZEN_UNUSED(ByteCount); - + MD5Update(reinterpret_cast(m_Context), (unsigned char*)Data, (unsigned int)ByteCount); return *this; } MD5 MD5Stream::GetHash() { - MD5 md5{}; + MD5_CTX FinalCtx; + memcpy(&FinalCtx, m_Context, sizeof(MD5_CTX)); + MD5Final(&FinalCtx); + MD5 md5{}; + memcpy(md5.Hash, FinalCtx.digest, 16); return md5; } @@ -428,7 +432,7 @@ MD5::ToHexString(StringBuilderBase& outBuilder) const char str[41]; ToHexString(str); - outBuilder.AppendRange(str, &str[40]); + outBuilder.AppendRange(str, &str[StringLength]); return outBuilder; } @@ -470,11 +474,11 @@ TEST_CASE("MD5") MD5::String_t Buffer; Result.ToHexString(Buffer); - CHECK(Output.compare(Buffer)); + CHECK(Output.compare(Buffer) == 0); MD5 Reresult = MD5::FromHexString(Buffer); Reresult.ToHexString(Buffer); - CHECK(Output.compare(Buffer)); + CHECK(Output.compare(Buffer) == 0); } TEST_SUITE_END(); diff --git a/src/zencore/memtrack/tagtrace.cpp b/src/zencore/memtrack/tagtrace.cpp index 70a74365d..fca4a2ec3 100644 --- a/src/zencore/memtrack/tagtrace.cpp +++ b/src/zencore/memtrack/tagtrace.cpp @@ -186,7 +186,7 @@ FTagTrace::AnnounceSpecialTags() const { auto EmitTag = [](const char16_t* DisplayString, int32_t Tag, int32_t ParentTag) { const uint32_t DisplayLen = (uint32_t)StringLength(DisplayString); - 
UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(ANSICHAR)) + UE_TRACE_LOG(Memory, TagSpec, MemAllocChannel, DisplayLen * sizeof(char16_t)) << TagSpec.Tag(Tag) << TagSpec.Parent(ParentTag) << TagSpec.Display(DisplayString, DisplayLen); }; diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp index f749f1c90..bdd22e20c 100644 --- a/src/zencore/mpscqueue.cpp +++ b/src/zencore/mpscqueue.cpp @@ -7,7 +7,7 @@ namespace zen { -#if ZEN_WITH_TESTS && 0 +#if ZEN_WITH_TESTS TEST_SUITE_BEGIN("core.mpscqueue"); TEST_CASE("mpsc") { @@ -24,4 +24,4 @@ mpscqueue_forcelink() { } -} // namespace zen \ No newline at end of file +} // namespace zen diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 636e182b4..bfff114c3 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -81,8 +81,13 @@ sentry_sink::sink_it_(const spdlog::details::log_msg& msg) } try { - std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname); - sentry_value_t event = sentry_value_new_message_event( + auto MaybeNullString = [](const char* Ptr) { return Ptr ? Ptr : ""; }; + std::string Message = fmt::format("{}\n{}({}) [{}]", + msg.payload, + MaybeNullString(msg.source.filename), + msg.source.line, + MaybeNullString(msg.source.funcname)); + sentry_value_t event = sentry_value_new_message_event( /* level */ MapToSentryLevel[msg.level], /* logger */ nullptr, /* message */ Message.c_str()); diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index 3d0451e27..ed0ba6f46 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -268,6 +268,17 @@ namespace { /* kNicenumTime */ 1000}; } // namespace +uint64_t +IntPow(uint64_t Base, int Exp) +{ + uint64_t Result = 1; + for (int I = 0; I < Exp; ++I) + { + Result *= Base; + } + return Result; +} + /* * Convert a number to an appropriately human-readable output. 
*/ @@ -315,7 +326,7 @@ NiceNumGeneral(uint64_t Num, std::span Buffer, NicenumFormat Format) const char* u = UnitStrings[Format][Index]; - if ((Index == 0) || ((Num % (uint64_t)powl((int)KiloUnit[Format], Index)) == 0)) + if ((Index == 0) || ((Num % IntPow(KiloUnit[Format], Index)) == 0)) { /* * If this is an even multiple of the base, always display @@ -339,7 +350,7 @@ NiceNumGeneral(uint64_t Num, std::span Buffer, NicenumFormat Format) for (int i = 2; i >= 0; i--) { - double Value = (double)Num / (uint64_t)powl((int)KiloUnit[Format], Index); + double Value = (double)Num / IntPow(KiloUnit[Format], Index); /* * Don't print floating point values for time. Note, diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp index 9e3486e49..54459cbaa 100644 --- a/src/zencore/thread.cpp +++ b/src/zencore/thread.cpp @@ -133,7 +133,10 @@ SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) #elif ZEN_PLATFORM_MAC pthread_setname_np(ThreadNameZ.c_str()); #else - pthread_setname_np(pthread_self(), ThreadNameZ.c_str()); + // Linux pthread_setname_np has a 16-byte limit (15 chars + NUL) + StringBuilder<16> LinuxThreadName; + LinuxThreadName << LimitedThreadName.substr(0, 15); + pthread_setname_np(pthread_self(), LinuxThreadName.c_str()); #endif } // namespace zen @@ -233,12 +236,15 @@ Event::Close() #else std::atomic_thread_fence(std::memory_order_acquire); auto* Inner = (EventInner*)m_EventHandle.load(); + if (Inner) { - std::unique_lock Lock(Inner->Mutex); - Inner->bSet.store(true); - m_EventHandle = nullptr; + { + std::unique_lock Lock(Inner->Mutex); + Inner->bSet.store(true); + m_EventHandle = nullptr; + } + delete Inner; } - delete Inner; #endif } @@ -351,7 +357,7 @@ NamedEvent::NamedEvent(std::string_view EventName) intptr_t Packed; Packed = intptr_t(Sem) << 32; Packed |= intptr_t(Fd) & 0xffff'ffff; - m_EventHandle = (void*)Packed; + m_EventHandle = (void*)Packed; #endif ZEN_ASSERT(m_EventHandle != nullptr); } @@ -372,7 +378,9 @@ NamedEvent::Close() #if 
ZEN_PLATFORM_WINDOWS CloseHandle(m_EventHandle); #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - int Fd = int(intptr_t(m_EventHandle.load()) & 0xffff'ffff); + const intptr_t Handle = intptr_t(m_EventHandle.load()); + const int Fd = int(Handle & 0xffff'ffff); + const int Sem = int(Handle >> 32); if (flock(Fd, LOCK_EX | LOCK_NB) == 0) { @@ -388,11 +396,10 @@ NamedEvent::Close() } flock(Fd, LOCK_UN | LOCK_NB); - close(Fd); - - int Sem = int(intptr_t(m_EventHandle.load()) >> 32); semctl(Sem, 0, IPC_RMID); } + + close(Fd); #endif m_EventHandle = nullptr; @@ -481,9 +488,12 @@ NamedMutex::~NamedMutex() CloseHandle(m_MutexHandle); } #elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC - int Inner = int(intptr_t(m_MutexHandle)); - flock(Inner, LOCK_UN); - close(Inner); + if (m_MutexHandle) + { + int Inner = int(intptr_t(m_MutexHandle)); + flock(Inner, LOCK_UN); + close(Inner); + } #endif } @@ -516,7 +526,6 @@ NamedMutex::Create(std::string_view MutexName) if (flock(Inner, LOCK_EX) != 0) { close(Inner); - Inner = 0; return false; } @@ -583,6 +592,11 @@ GetCurrentThreadId() void Sleep(int ms) { + if (ms <= 0) + { + return; + } + #if ZEN_PLATFORM_WINDOWS ::Sleep(ms); #else diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp index a026974c0..7c195e69f 100644 --- a/src/zencore/trace.cpp +++ b/src/zencore/trace.cpp @@ -10,7 +10,16 @@ # define TRACE_IMPLEMENT 1 # undef _WINSOCK_DEPRECATED_NO_WARNINGS +// GCC false positives in thirdparty trace.h (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) +# if ZEN_COMPILER_GCC +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wstringop-overread" +# pragma GCC diagnostic ignored "-Wdangling-pointer" +# endif # include +# if ZEN_COMPILER_GCC +# pragma GCC diagnostic pop +# endif # include # include diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index ab842f6ed..2f81b7ec8 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -14,12 +14,7 @@ target('zencore') end) set_configdir("include/zencore") 
add_files("**.cpp") - if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then - -- GCC false positives in thirdparty trace.h (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100137) - add_files("trace.cpp", {unity_ignored = true, force = {cxxflags = {"-Wno-stringop-overread", "-Wno-dangling-pointer"}} }) - else - add_files("trace.cpp", {unity_ignored = true }) - end + add_files("trace.cpp", {unity_ignored = true }) add_files("testing.cpp", {unity_ignored = true }) if has_config("zenrpmalloc") then diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 38e7586ad..23bbc17e8 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -32,6 +32,25 @@ namespace details { using namespace std::literals; +static std::string +FormUrlEncode(std::string_view Input) +{ + std::string Result; + Result.reserve(Input.size()); + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' 
|| C == '~') + { + Result.push_back(C); + } + else + { + Result.append(fmt::format("%{:02X}", static_cast(C))); + } + } + return Result; +} + OidcClient::OidcClient(const OidcClient::Options& Options) { m_BaseUrl = std::string(Options.BaseUrl); @@ -67,6 +86,8 @@ OidcClient::Initialize() .TokenEndpoint = Json["token_endpoint"].string_value(), .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .EndSessionEndpoint = Json["end_session_endpoint"].string_value(), + .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(), .JwksUri = Json["jwks_uri"].string_value(), .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), @@ -81,7 +102,8 @@ OidcClient::Initialize() OidcClient::RefreshTokenResult OidcClient::RefreshToken(std::string_view RefreshToken) { - const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + const std::string Body = + fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId)); HttpClient Http{m_Config.TokenEndpoint}; diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 9ded23375..6f4c67dd0 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -142,7 +142,10 @@ namespace detail { DataSize -= CopySize; if (m_CacheBufferOffset == CacheBufferSize) { - AppendData(m_CacheBuffer, CacheBufferSize); + if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize)) + { + return Ec; + } if (DataSize > 0) { ZEN_ASSERT(DataSize < CacheBufferSize); diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 36a6f081b..9497dadb8 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ 
b/src/zenhttp/clients/httpwsclient.cpp @@ -351,9 +351,8 @@ struct HttpWsClient::Impl } // Echo masked close frame if we haven't sent one yet - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); EnqueueWrite(std::move(CloseFrame)); } @@ -479,9 +478,8 @@ struct HttpWsClient::Impl return; } - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); EnqueueWrite(std::move(CloseFrame)); } @@ -515,7 +513,7 @@ struct HttpWsClient::Impl bool m_IsWriting = false; std::atomic m_IsOpen{false}; - bool m_CloseSent = false; + std::atomic m_CloseSent{false}; }; ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index d87082d10..bec4984db 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -128,7 +128,7 @@ public: struct ErrorContext { - int ErrorCode; + int ErrorCode = 0; std::string ErrorMessage; /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ @@ -179,19 +179,19 @@ public: KeyValueMap Header; // The number of bytes sent as part of the request - int64_t UploadedBytes; + int64_t UploadedBytes = 0; // The number of bytes received as part of the response - int64_t DownloadedBytes; + int64_t DownloadedBytes = 0; // The elapsed time in seconds for the request to execute - double ElapsedSeconds; + double ElapsedSeconds = 0.0; struct MultipartBoundary { - uint64_t OffsetInPayload; - uint64_t RangeOffset; - uint64_t RangeLength; + uint64_t OffsetInPayload = 0; + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; HttpContentType ContentType; }; diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 02cccc540..c1152dc3e 100644 --- 
a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -440,7 +440,7 @@ public: ~HttpRpcHandler(); HttpRpcHandler(const HttpRpcHandler&) = delete; - HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + HttpRpcHandler& operator=(const HttpRpcHandler&) = delete; void AddRpc(std::string_view RpcId, std::function HandlerFunction); diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index c90b840da..1a5068580 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -68,7 +68,7 @@ struct CbAttachmentEntry struct CbAttachmentReferenceHeader { uint64_t PayloadByteOffset = 0; - uint64_t PayloadByteSize = ~0u; + uint64_t PayloadByteSize = ~uint64_t(0); uint16_t AbsolutePathLength = 0; // This header will be followed by UTF8 encoded absolute path to backing file diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 7a6fb33dd..bc3293282 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -22,7 +22,7 @@ enum class WebSocketOpcode : uint8_t struct WebSocketMessage { - WebSocketOpcode Opcode; + WebSocketOpcode Opcode = WebSocketOpcode::kText; IoBuffer Payload; uint16_t CloseCode = 0; }; diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 3b1229375..918b55dc6 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -245,13 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } - - // NOTE: this check is redundant given the enclosing if, - // need to verify the intent of this code - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } } else if (!NormalizedUrl.empty()) { @@ -389,6 +382,7 @@ HttpRequestParser::ResetState() m_UpgradeHeaderIndex = -1; 
m_SecWebSocketKeyHeaderIndex = -1; m_SecWebSocketVersionHeaderIndex = -1; + m_RequestVerb = HttpVerb::kGet; m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index d40a5aeb0..23ad9d8fb 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -93,7 +93,7 @@ private: int8_t m_UpgradeHeaderIndex; int8_t m_SecWebSocketKeyHeaderIndex; int8_t m_SecWebSocketVersionHeaderIndex; - HttpVerb m_RequestVerb; + HttpVerb m_RequestVerb = HttpVerb::kGet; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; int m_RequestId = -1; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 8564826d6..850dafdca 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -123,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer bool m_IsRequestLoggingEnabled = false; LoggerRef m_RequestLog; std::atomic_uint32_t m_ConnectionIdCounter{0}; - int m_BasePort; + int m_BasePort = 0; HttpServerTracer m_RequestTracer; @@ -294,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug ConnectionName = "anonymous"; } - ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName); + ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName); } uint32_t diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h index da72c79c9..a9a45f162 100644 --- a/src/zenhttp/servers/httptracer.h +++ b/src/zenhttp/servers/httptracer.h @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
-#include - #pragma once +#include + namespace zen { /** Helper class for HTTP server implementations diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index dfc1eac38..3e31b58bc 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -140,9 +140,8 @@ WsAsioConnection::ProcessReceivedData() } // Echo close frame back if we haven't sent one yet - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code); EnqueueWrite(std::move(CloseFrame)); } @@ -208,9 +207,8 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) return; } - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; std::vector CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); EnqueueWrite(std::move(CloseFrame)); } diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h index a638ea836..d8ffdc00a 100644 --- a/src/zenhttp/servers/wsasio.h +++ b/src/zenhttp/servers/wsasio.h @@ -65,7 +65,7 @@ private: bool m_IsWriting = false; std::atomic m_IsOpen{true}; - bool m_CloseSent = false; + std::atomic m_CloseSent{false}; }; } // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp index a4c5e0f16..e452141fe 100644 --- a/src/zenhttp/servers/wsframecodec.cpp +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -51,6 +51,13 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) HeaderSize = 10; } + // Reject frames with unreasonable payload sizes to prevent OOM + static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB + if (PayloadLen > kMaxPayloadSize) + { + return {}; + } + const size_t MaskSize = Masked ? 
4 : 0; const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index 3f0f0b447..3408b64b3 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -217,9 +217,8 @@ WsHttpSysConnection::ProcessReceivedData() bool ShouldSendClose = false; { RwLock::ExclusiveLockScope _(m_WriteLock); - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; ShouldSendClose = true; } } @@ -412,9 +411,8 @@ WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) bool ShouldSendClose = false; { RwLock::ExclusiveLockScope _(m_WriteLock); - if (!m_CloseSent) + if (!m_CloseSent.exchange(true)) { - m_CloseSent = true; ShouldSendClose = true; } } diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h index ab0ca381a..d854289e0 100644 --- a/src/zenhttp/servers/wshttpsys.h +++ b/src/zenhttp/servers/wshttpsys.h @@ -96,7 +96,7 @@ private: Ref m_SelfRef; std::atomic m_ShutdownRequested{false}; std::atomic m_IsOpen{true}; - bool m_CloseSent = false; + std::atomic m_CloseSent{false}; }; } // namespace zen diff --git a/src/zenserver/hub/hubservice.cpp b/src/zenserver/hub/hubservice.cpp index a757cd594..7b999ae20 100644 --- a/src/zenserver/hub/hubservice.cpp +++ b/src/zenserver/hub/hubservice.cpp @@ -4,6 +4,7 @@ #include "hydration.h" +#include #include #include #include @@ -195,7 +196,7 @@ StorageServerInstance::~StorageServerInstance() void StorageServerInstance::SpawnServerProcess() { - ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); + ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); m_ServerInstance.SetServerExecutablePath(GetRunningExecutablePath()); m_ServerInstance.SetDataDir(m_BaseDir); @@ -322,7 +323,7 @@ StorageServerInstance::WakeLocked() return; } - 
ZEN_ASSERT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); + ZEN_ASSERT_FORMAT(!m_ServerInstance.IsRunning(), "Storage server instance for module '{}' is already running", m_ModuleId); try { -- cgit v1.2.3 From 5115b419cefd41e8d5cc465c8c7ae5140cde71d4 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Fri, 6 Mar 2026 12:39:06 +0100 Subject: zenstore bug-fixes from static analysis pass (#815) **Bug fixes across zenstore, zenremotestore, and related subsystems, primarily surfaced by static analysis.** ## Cache subsystem (cachedisklayer.cpp) - Fixed tombstone scoping bug: tombstone flag and missing entry were recorded outside the block where data was removed, causing non-missing entries to be incorrectly tombstoned - Fixed use-after-overwrite: `RemoveMemCachedData`/`RemoveMetaData` were called after `Payload` was overwritten on cache put, leaking stale data - Fixed incorrect retry sleep formula (`100 - (3 - RetriesLeft) * 100` always produced the same or negative value; corrected to `(3 - RetriesLeft) * 100`) - Fixed broken `break` missing from sidecar file read loop, causing reads past valid data - Fixed missing format argument in three `ZEN_WARN`/`ZEN_ERROR` log calls (format string had `{}` placeholders with no corresponding argument, or vice versa) - Fixed elapsed timer being accumulated inside the wrong scope in `HandleRpcGetCacheRecords` - Fixed test asserting against unserialized `RecordPolicy` instead of the deserialized `Loaded` copy - Initialized `AbortFlag`/`PauseFlag` atomics at declaration (UB if read before first write) ## Build store (buildstore.cpp / buildstore.h) - Fixed wrong variable used in warning log: used loop index `ResultIndex` instead of `Index`/`MetaLocationResultIndexes[Index]`, logging wrong hash values - Fixed `sizeof(AccessTimesHeader)` used instead of `sizeof(AccessTimeRecord)` when advancing write offset, corrupting the access times file if the sizes differ - Initialized 
`m_LastAccessTimeUpdateCount` atomic member (was uninitialized) - Changed map iteration loops to use `const auto&` to avoid unnecessary copies ## Project store (projectstore.cpp / projectstore.h) - Fixed wrong iterator dereferenced in `IterateChunks`: used `ChunkIt->second` (from a different map lookup) instead of `MetaIt->second` - Fixed wrong assert variable: `Sizes[Index]` should be `RawSizes[Index]` - Fixed `MakeTombstone`/`IsTombstone` inconsistency: `MakeTombstone` was zeroing `OpLsn` but `IsTombstone` checks `OpLsn.Number != 0`; tombstone creation now preserves `OpLsn` - Fixed uninitialized `InvalidEntries` counter - Fixed format string mismatch in warning log - Initialized `AbortFlag`/`PauseFlag` atomics; changed map iteration to `const auto&` ## Workspaces (workspaces.cpp) - Fixed missing alias registration when a workspace share is updated: alias was deleted but never re-inserted - Fixed integer overflow in range clamping: `(RequestedOffset + RequestedSize) > Size` could wrap; corrected to `RequestedSize > Size - RequestedOffset` - Changed map iteration loops to `const auto&` ## CAS subsystem (cas.cpp, caslog.cpp, compactcas.cpp, filecas.cpp) - Fixed `IterateChunks` passing original `Payload` buffer instead of the modified `Chunk` buffer (content type was set on the copy but the original was sent to the callback) - Fixed invalid `std::future::get()` call on default-constructed futures - Fixed sign-comparison in `CasLogFile::Replay` loop (`int i` vs `size_t`) - Changed `CasLogFile::IsValid` and `Open` to take `const std::filesystem::path&` instead of by value - Fixed format string in `~CasContainerStrategy` error log ## Remote store (zenremotestore) - Fixed `FolderContent::operator==` skipping the per-path comparison: loop bound `PathCount` was initialized to 0 instead of `Paths.size()`, so the comparison loop never ran - Fixed `GetChunkIndexForRawHash` looking up from wrong map (`RawHashToSequenceIndex` instead of `ChunkHashToChunkIndex`) - Fixed double-counted `UniqueSequencesFound` stat 
(incremented in both branches of an if/else) - Fixed `RawSize` sentinel value truncation: `(uint32_t)-1` assigned to a `uint64_t` field; corrected to `(uint64_t)-1` - Initialized uninitialized atomic and struct members across `buildstorageoperations.h`, `chunkblock.h`, and `remoteprojectstore.h` --- src/zenremotestore/chunking/chunkedcontent.cpp | 4 +-- .../zenremotestore/builds/buildstorageoperations.h | 6 ++-- .../include/zenremotestore/chunking/chunkblock.h | 2 +- .../zenremotestore/chunking/chunkedcontent.h | 2 +- .../projectstore/remoteprojectstore.h | 20 ++++++------- src/zenstore/buildstore/buildstore.cpp | 16 +++++----- src/zenstore/cache/cachedisklayer.cpp | 32 ++++++++++---------- src/zenstore/cache/cachepolicy.cpp | 14 ++++----- src/zenstore/cache/cacherpc.cpp | 4 +-- src/zenstore/cache/structuredcachestore.cpp | 8 ++--- src/zenstore/cas.cpp | 11 ++++--- src/zenstore/caslog.cpp | 6 ++-- src/zenstore/cidstore.cpp | 3 +- src/zenstore/compactcas.cpp | 16 +++++----- src/zenstore/filecas.cpp | 14 ++++----- src/zenstore/filecas.h | 2 +- .../include/zenstore/buildstore/buildstore.h | 2 +- .../include/zenstore/cache/cachedisklayer.h | 34 +++++++++++----------- src/zenstore/include/zenstore/cache/cacheshared.h | 6 ++-- .../include/zenstore/cache/structuredcachestore.h | 12 ++++---- src/zenstore/include/zenstore/caslog.h | 10 +++---- src/zenstore/include/zenstore/gc.h | 6 ++-- src/zenstore/include/zenstore/projectstore.h | 10 ++----- src/zenstore/projectstore.cpp | 26 ++++++++--------- src/zenstore/workspaces.cpp | 16 ++++++---- 25 files changed, 144 insertions(+), 138 deletions(-) (limited to 'src') diff --git a/src/zenremotestore/chunking/chunkedcontent.cpp b/src/zenremotestore/chunking/chunkedcontent.cpp index 62c927508..c09ab9d3a 100644 --- a/src/zenremotestore/chunking/chunkedcontent.cpp +++ b/src/zenremotestore/chunking/chunkedcontent.cpp @@ -166,7 +166,6 @@ namespace { if (Chunked.Info.ChunkSequence.empty()) { AddChunkSequence(Stats, 
OutChunkedContent.ChunkedContent, ChunkHashToChunkIndex, Chunked.Info.RawHash, RawSize); - Stats.UniqueSequencesFound++; } else { @@ -186,7 +185,6 @@ namespace { Chunked.Info.ChunkHashes, ChunkSizes); } - Stats.UniqueSequencesFound++; } }); Stats.FilesChunked++; @@ -253,7 +251,7 @@ FolderContent::operator==(const FolderContent& Rhs) const if ((Platform == Rhs.Platform) && (RawSizes == Rhs.RawSizes) && (Attributes == Rhs.Attributes) && (ModificationTicks == Rhs.ModificationTicks) && (Paths.size() == Rhs.Paths.size())) { - size_t PathCount = 0; + size_t PathCount = Paths.size(); for (size_t PathIndex = 0; PathIndex < PathCount; PathIndex++) { if (Paths[PathIndex].generic_string() != Rhs.Paths[PathIndex].generic_string()) diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h index 875b8593b..0d2eded58 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageoperations.h @@ -161,7 +161,7 @@ public: DownloadStatistics m_DownloadStats; WriteChunkStatistics m_WriteChunkStats; RebuildFolderStateStatistics m_RebuildFolderStateStats; - std::atomic m_WrittenChunkByteCount; + std::atomic m_WrittenChunkByteCount = 0; private: struct BlockWriteOps @@ -186,7 +186,7 @@ private: uint32_t ScavengedContentIndex = (uint32_t)-1; uint32_t ScavengedPathIndex = (uint32_t)-1; uint32_t RemoteSequenceIndex = (uint32_t)-1; - uint64_t RawSize = (uint32_t)-1; + uint64_t RawSize = (uint64_t)-1; }; struct CopyChunkData @@ -362,7 +362,7 @@ private: const std::filesystem::path m_TempDownloadFolderPath; const std::filesystem::path m_TempBlockFolderPath; - std::atomic m_ValidatedChunkByteCount; + std::atomic m_ValidatedChunkByteCount = 0; }; struct FindBlocksStatistics diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h 
b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index 7aae1442e..20b6fd371 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -24,7 +24,7 @@ struct ThinChunkBlockDescription struct ChunkBlockDescription : public ThinChunkBlockDescription { - uint64_t HeaderSize; + uint64_t HeaderSize = 0; std::vector ChunkRawLengths; std::vector ChunkCompressedLengths; }; diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h index d402bd3f0..f44381e42 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkedcontent.h @@ -231,7 +231,7 @@ GetSequenceIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& Raw inline uint32_t GetChunkIndexForRawHash(const ChunkedContentLookup& Lookup, const IoHash& RawHash) { - return Lookup.RawHashToSequenceIndex.at(RawHash); + return Lookup.ChunkHashToChunkIndex.at(RawHash); } inline uint32_t diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 2cf10c664..42786d0f2 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -92,22 +92,22 @@ public: struct RemoteStoreInfo { - bool CreateBlocks; - bool UseTempBlockFiles; - bool AllowChunking; + bool CreateBlocks = false; + bool UseTempBlockFiles = false; + bool AllowChunking = false; std::string ContainerName; std::string Description; }; struct Stats { - std::uint64_t m_SentBytes; - std::uint64_t m_ReceivedBytes; - std::uint64_t m_RequestTimeNS; - std::uint64_t m_RequestCount; - std::uint64_t m_PeakSentBytes; - std::uint64_t m_PeakReceivedBytes; - std::uint64_t 
m_PeakBytesPerSec; + std::uint64_t m_SentBytes = 0; + std::uint64_t m_ReceivedBytes = 0; + std::uint64_t m_RequestTimeNS = 0; + std::uint64_t m_RequestCount = 0; + std::uint64_t m_PeakSentBytes = 0; + std::uint64_t m_PeakReceivedBytes = 0; + std::uint64_t m_PeakBytesPerSec = 0; }; struct ExtendedStats diff --git a/src/zenstore/buildstore/buildstore.cpp b/src/zenstore/buildstore/buildstore.cpp index 49ed7cdd2..dff1c3c61 100644 --- a/src/zenstore/buildstore/buildstore.cpp +++ b/src/zenstore/buildstore/buildstore.cpp @@ -373,8 +373,8 @@ BuildStore::PutMetadatas(std::span BlobHashes, std::span AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); for (size_t Index = 0; Index < Metadatas.size(); Index++) { @@ -505,8 +505,8 @@ BuildStore::GetMetadatas(std::span BlobHashes, WorkerThreadPool* O else { ZEN_WARN("Metadata {} for blob {} is malformed (not a compressed binary format)", - MetadataHashes[ResultIndex], - BlobHashes[ResultIndex]); + MetadataHashes[Index], + BlobHashes[MetaLocationResultIndexes[Index]]); } } } @@ -561,7 +561,7 @@ BuildStore::GetStorageStats() const RwLock::SharedLockScope _(m_Lock); Result.EntryCount = m_BlobLookup.size(); - for (auto LookupIt : m_BlobLookup) + for (const auto& LookupIt : m_BlobLookup) { const BlobIndex ReadBlobIndex = LookupIt.second; const BlobEntry& ReadBlobEntry = m_BlobEntries[ReadBlobIndex]; @@ -634,7 +634,7 @@ BuildStore::CompactState() const size_t MetadataCount = m_MetadataEntries.size(); MetadataEntries.reserve(MetadataCount); - for (auto LookupIt : m_BlobLookup) + for (const auto& LookupIt : m_BlobLookup) { const IoHash& BlobHash = LookupIt.first; const BlobIndex ReadBlobIndex = LookupIt.second; @@ -955,7 +955,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files std::vector AccessRecords; AccessRecords.reserve(Header.AccessTimeCount); - for (auto It : 
m_BlobLookup) + for (const auto& It : m_BlobLookup) { const IoHash& Key = It.first; const BlobIndex Index = It.second; @@ -965,7 +965,7 @@ BuildStore::WriteAccessTimes(const RwLock::ExclusiveLockScope&, const std::files } uint64_t RecordsSize = sizeof(AccessTimeRecord) * Header.AccessTimeCount; TempFile.Write(AccessRecords.data(), RecordsSize, Offset); - Offset += sizeof(AccessTimesHeader) * Header.AccessTimeCount; + Offset += sizeof(AccessTimeRecord) * Header.AccessTimeCount; } if (TempFile.MoveTemporaryIntoPlace(AccessTimesPath, Ec); Ec) { diff --git a/src/zenstore/cache/cachedisklayer.cpp b/src/zenstore/cache/cachedisklayer.cpp index b73b3e6fb..d53f9f369 100644 --- a/src/zenstore/cache/cachedisklayer.cpp +++ b/src/zenstore/cache/cachedisklayer.cpp @@ -602,7 +602,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B if (FileSize < sizeof(BucketMetaHeader)) { - ZEN_WARN("Failed to read sidecar file '{}'. Minimum size {} expected, actual size: ", + ZEN_WARN("Failed to read sidecar file '{}'. 
Minimum size {} expected, actual size: {}", SidecarPath, sizeof(BucketMetaHeader), FileSize); @@ -654,6 +654,7 @@ BucketManifestSerializer::ReadSidecarFile(RwLock::ExclusiveLockScope& B SidecarPath, sizeof(ManifestData), CurrentReadOffset); + break; } CurrentReadOffset += sizeof(ManifestData); @@ -1011,7 +1012,7 @@ ZenCacheDiskLayer::CacheBucket::WriteIndexSnapshotLocked(uint64_t LogPosi { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1267,10 +1268,10 @@ ZenCacheDiskLayer::CacheBucket::InitializeIndexFromDisk(RwLock::ExclusiveLockSco { RemoveMemCachedData(IndexLock, Payload); RemoveMetaData(IndexLock, Payload); + Location.Flags |= DiskLocation::kTombStone; + MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location}); } } - Location.Flags |= DiskLocation::kTombStone; - MissingEntries.push_back(DiskIndexEntry{.Key = It.first, .Location = Location}); } ZEN_ASSERT(!MissingEntries.empty()); @@ -2812,7 +2813,7 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c m_BucketDir, Ec.message(), RetriesLeft); - Sleep(100 - (3 - RetriesLeft) * 100); // Total 600 ms + Sleep((3 - RetriesLeft) * 100); // Total 600 ms Ec.clear(); DataFile.MoveTemporaryIntoPlace(FsPath, Ec); RetriesLeft--; @@ -2866,11 +2867,12 @@ ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, c { EntryIndex = It.value(); ZEN_ASSERT_SLOW(EntryIndex < PayloadIndex(m_AccessTimes.size())); - BucketPayload& Payload = m_Payloads[EntryIndex]; - uint64_t OldSize = Payload.Location.Size(); + BucketPayload& Payload = m_Payloads[EntryIndex]; + uint64_t OldSize = Payload.Location.Size(); + 
RemoveMemCachedData(IndexLock, Payload); + RemoveMetaData(IndexLock, Payload); Payload = BucketPayload{.Location = Loc}; m_AccessTimes[EntryIndex] = GcClock::TickCount(); - RemoveMemCachedData(IndexLock, Payload); m_StandaloneSize.fetch_sub(OldSize, std::memory_order::relaxed); } if ((Value.RawSize != 0 || Value.RawHash != IoHash::Zero) && Value.RawSize <= std::numeric_limits::max()) @@ -3521,7 +3523,7 @@ ZenCacheDiskLayer::CacheBucket::GetReferences(const LoggerRef& Logger, } else { - ZEN_WARN("Cache record {} payload is malformed. Reason: ", RawHash, ToString(Error)); + ZEN_WARN("Cache record {} payload is malformed. Reason: {}", RawHash, ToString(Error)); } return false; }; @@ -4282,8 +4284,8 @@ ZenCacheDiskLayer::DiscoverBuckets() RwLock SyncLock; WorkerThreadPool& Pool = GetLargeWorkerPool(EWorkloadType::Burst); - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4454,8 +4456,8 @@ ZenCacheDiskLayer::Flush() } { WorkerThreadPool& Pool = GetMediumWorkerPool(EWorkloadType::Burst); - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4496,8 +4498,8 @@ ZenCacheDiskLayer::Scrub(ScrubContext& Ctx) RwLock::SharedLockScope _(m_Lock); - std::atomic Abort; - std::atomic Pause; + std::atomic Abort{false}; + std::atomic Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try diff --git a/src/zenstore/cache/cachepolicy.cpp b/src/zenstore/cache/cachepolicy.cpp index ce6a14bd9..c1e7dc5b3 100644 --- a/src/zenstore/cache/cachepolicy.cpp +++ b/src/zenstore/cache/cachepolicy.cpp @@ -403,13 +403,13 @@ TEST_CASE("cacherecordpolicy") RecordPolicy.Save(Writer); CbObject Saved = Writer.Save()->AsObject(); CacheRecordPolicy Loaded = 
CacheRecordPolicy::Load(Saved).Get(); - CHECK(!RecordPolicy.IsUniform()); - CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); - CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); - CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); - CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); - CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); - CHECK(RecordPolicy.GetValuePolicies().size() == 2); + CHECK(!Loaded.IsUniform()); + CHECK(Loaded.GetRecordPolicy() == UnionPolicy); + CHECK(Loaded.GetBasePolicy() == DefaultPolicy); + CHECK(Loaded.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(Loaded.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(Loaded.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(Loaded.GetValuePolicies().size() == 2); } } diff --git a/src/zenstore/cache/cacherpc.cpp b/src/zenstore/cache/cacherpc.cpp index e1fd0a3e6..90c5a5e60 100644 --- a/src/zenstore/cache/cacherpc.cpp +++ b/src/zenstore/cache/cacherpc.cpp @@ -866,8 +866,8 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb Request.Complete = false; } } - Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); } + Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); }; m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete)); @@ -934,7 +934,7 @@ CacheRpcHandler::HandleRpcGetCacheRecords(const CacheRequestContext& Context, Cb *Namespace, Key.Bucket, Key.Hash, - Request.RecordObject ? ""sv : " (PARTIAL)"sv, + Request.RecordObject ? " (PARTIAL)"sv : ""sv, Request.Source ? 
Request.Source->Url : "LOCAL"sv, NiceLatencyNs(Request.ElapsedTimeUs * 1000)); m_CacheStats.MissCount++; diff --git a/src/zenstore/cache/structuredcachestore.cpp b/src/zenstore/cache/structuredcachestore.cpp index 18023e2d6..cff0e9a35 100644 --- a/src/zenstore/cache/structuredcachestore.cpp +++ b/src/zenstore/cache/structuredcachestore.cpp @@ -686,8 +686,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context, return false; } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -722,8 +722,8 @@ ZenCacheStore::Get(const CacheRequestContext& Context, } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get [{}], bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -790,8 +790,8 @@ ZenCacheStore::Put(const CacheRequestContext& Context, } ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put [{}] bucket '{}', key '{}'", - Context, Namespace, + Context, Bucket, HashKey.ToHexString()); @@ -816,7 +816,7 @@ ZenCacheStore::DropNamespace(std::string_view InNamespace) { std::function PostDropOp; { - RwLock::SharedLockScope _(m_NamespacesLock); + RwLock::ExclusiveLockScope _(m_NamespacesLock); if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end()) { ZenCacheNamespace& Namespace = *It->second; diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp index f80952322..8855c87d8 100644 --- a/src/zenstore/cas.cpp +++ b/src/zenstore/cas.cpp @@ -153,7 +153,10 @@ CasImpl::Initialize(const CidStoreConfiguration& InConfig) } for (std::future& Result : Work) { - Result.get(); + if (Result.valid()) + { + Result.get(); + } } } } @@ -426,7 +429,7 @@ CasImpl::IterateChunks(std::span DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return 
AsyncCallback(Index, Chunk); }, OptionalWorkerPool, LargeSizeLimit == 0 ? m_Config.HugeValueThreshold : Min(LargeSizeLimit, m_Config.HugeValueThreshold))) @@ -439,7 +442,7 @@ CasImpl::IterateChunks(std::span DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return AsyncCallback(Index, Chunk); }, OptionalWorkerPool, LargeSizeLimit == 0 ? m_Config.TinyValueThreshold : Min(LargeSizeLimit, m_Config.TinyValueThreshold))) @@ -452,7 +455,7 @@ CasImpl::IterateChunks(std::span DecompressedIds, [&](size_t Index, const IoBuffer& Payload) { IoBuffer Chunk(Payload); Chunk.SetContentType(ZenContentType::kCompressedBinary); - return AsyncCallback(Index, Payload); + return AsyncCallback(Index, Chunk); }, OptionalWorkerPool)) { diff --git a/src/zenstore/caslog.cpp b/src/zenstore/caslog.cpp index 492ce9317..44664dac2 100644 --- a/src/zenstore/caslog.cpp +++ b/src/zenstore/caslog.cpp @@ -35,7 +35,7 @@ CasLogFile::~CasLogFile() } bool -CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize) +CasLogFile::IsValid(const std::filesystem::path& FileName, size_t RecordSize) { if (!IsFile(FileName)) { @@ -71,7 +71,7 @@ CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize) } void -CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode) +CasLogFile::Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode) { m_RecordSize = RecordSize; @@ -205,7 +205,7 @@ CasLogFile::Replay(std::function&& Handler, uint64_t SkipEntr m_File.Read(ReadBuffer.data(), BytesToRead, LogBaseOffset + ReadOffset); - for (int i = 0; i < int(EntriesToRead); ++i) + for (size_t i = 0; i < EntriesToRead; ++i) { Handler(ReadBuffer.data() + (i * m_RecordSize)); } diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp index bedf91287..b20d8f565 100644 --- a/src/zenstore/cidstore.cpp +++ b/src/zenstore/cidstore.cpp 
@@ -48,13 +48,13 @@ struct CidStore::Impl std::vector AddChunks(std::span ChunkDatas, std::span RawHashes, CidStore::InsertMode Mode) { + ZEN_ASSERT(ChunkDatas.size() == RawHashes.size()); if (ChunkDatas.size() == 1) { std::vector Result(1); Result[0] = AddChunk(ChunkDatas[0], RawHashes[0], Mode); return Result; } - ZEN_ASSERT(ChunkDatas.size() == RawHashes.size()); std::vector Chunks; Chunks.reserve(ChunkDatas.size()); #if ZEN_BUILD_DEBUG @@ -81,6 +81,7 @@ struct CidStore::Impl m_CasStore.InsertChunks(Chunks, RawHashes, static_cast(Mode)); ZEN_ASSERT(CasResults.size() == ChunkDatas.size()); std::vector Result; + Result.reserve(CasResults.size()); for (const CasStore::InsertResult& CasResult : CasResults) { if (CasResult.New) diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp index 21411df59..b09892687 100644 --- a/src/zenstore/compactcas.cpp +++ b/src/zenstore/compactcas.cpp @@ -153,7 +153,7 @@ CasContainerStrategy::~CasContainerStrategy() } catch (const std::exception& Ex) { - ZEN_ERROR("~CasContainerStrategy failed with: ", Ex.what()); + ZEN_ERROR("~CasContainerStrategy failed with: {}", Ex.what()); } m_Gc.RemoveGcReferenceStore(*this); m_Gc.RemoveGcStorage(this); @@ -440,9 +440,9 @@ CasContainerStrategy::IterateChunks(std::span ChunkHas return true; } - std::atomic AbortFlag; + std::atomic AbortFlag{false}; { - std::atomic PauseFlag; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -559,8 +559,8 @@ CasContainerStrategy::ScrubStorage(ScrubContext& Ctx) std::vector ChunkLocations; std::vector ChunkIndexToChunkHash; - std::atomic Abort; - std::atomic Pause; + std::atomic Abort{false}; + std::atomic Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -1007,7 +1007,7 @@ CasContainerStrategy::CompactIndex(RwLock::ExclusiveLockScope&) std::vector Locations; Locations.reserve(EntryCount); LocationMap.reserve(EntryCount); - for (auto 
It : m_LocationMap) + for (const auto& It : m_LocationMap) { size_t EntryIndex = Locations.size(); Locations.push_back(m_Locations[It.second]); @@ -1106,7 +1106,7 @@ CasContainerStrategy::MakeIndexSnapshot(bool ResetLog) { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1136,7 +1136,7 @@ CasContainerStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint uint64_t Size = ObjectIndexFile.FileSize(); if (Size >= sizeof(CasDiskIndexHeader)) { - uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CasDiskIndexHeader))) / sizeof(CasDiskIndexEntry); + uint64_t ExpectedEntryCount = (Size - sizeof(CasDiskIndexHeader)) / sizeof(CasDiskIndexEntry); CasDiskIndexHeader Header; ObjectIndexFile.Read(&Header, sizeof(Header), 0); if ((Header.Magic == CasDiskIndexHeader::ExpectedMagic) && (Header.Version == CasDiskIndexHeader::CurrentVersion) && diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp index 295451818..0088afe6e 100644 --- a/src/zenstore/filecas.cpp +++ b/src/zenstore/filecas.cpp @@ -383,7 +383,7 @@ FileCasStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, CasStore:: HRESULT WriteRes = PayloadFile.Write(Cursor, Size); if (FAILED(WriteRes)) { - ThrowSystemException(hRes, fmt::format("failed to write {} bytes to shard file '{}'", ChunkSize, ChunkPath)); + ThrowSystemException(WriteRes, fmt::format("failed to write {} bytes to shard file '{}'", ChunkSize, ChunkPath)); } }; #else @@ -669,8 +669,8 @@ FileCasStrategy::IterateChunks(std::span ChunkHashes, return true; }; - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; 
ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -823,8 +823,8 @@ FileCasStrategy::ScrubStorage(ScrubContext& Ctx) ZEN_INFO("discovered {} files @ '{}' ({} not in index), scrubbing", m_Index.size(), m_RootDirectory, DiscoveredFilesNotInIndex); - std::atomic Abort; - std::atomic Pause; + std::atomic Abort{false}; + std::atomic Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -1016,7 +1016,7 @@ FileCasStrategy::MakeIndexSnapshot(bool ResetLog) { // This is non-critical, it only means that we will replay the events of the log over the snapshot - inefficent but in // the end it will be the same result - ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, IndexPath, Ec.message()); + ZEN_WARN("Snapshot failed to clean log file '{}', reason: '{}'", LogPath, Ec.message()); } m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); } @@ -1052,7 +1052,7 @@ FileCasStrategy::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& uint64_t Size = ObjectIndexFile.FileSize(); if (Size >= sizeof(FileCasIndexHeader)) { - uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(FileCasIndexHeader))) / sizeof(FileCasIndexEntry); + uint64_t ExpectedEntryCount = (Size - sizeof(FileCasIndexHeader)) / sizeof(FileCasIndexEntry); FileCasIndexHeader Header; ObjectIndexFile.Read(&Header, sizeof(Header), 0); if ((Header.Magic == FileCasIndexHeader::ExpectedMagic) && (Header.Version == FileCasIndexHeader::CurrentVersion) && diff --git a/src/zenstore/filecas.h b/src/zenstore/filecas.h index e93356927..41756b65f 100644 --- a/src/zenstore/filecas.h +++ b/src/zenstore/filecas.h @@ -74,7 +74,7 @@ private: { static const uint32_t kTombStone = 0x0000'0001; - bool IsFlagSet(const uint32_t Flag) const { return (Flags & kTombStone) == Flag; } + bool IsFlagSet(const uint32_t Flag) const { return (Flags & Flag) == Flag; } IoHash Key; uint32_t Flags = 0; diff --git 
a/src/zenstore/include/zenstore/buildstore/buildstore.h b/src/zenstore/include/zenstore/buildstore/buildstore.h index bfc83ba0d..ea2ef7f89 100644 --- a/src/zenstore/include/zenstore/buildstore/buildstore.h +++ b/src/zenstore/include/zenstore/buildstore/buildstore.h @@ -223,7 +223,7 @@ private: uint64_t m_MetaLogFlushPosition = 0; std::unique_ptr> m_TrackedBlobKeys; - std::atomic m_LastAccessTimeUpdateCount; + std::atomic m_LastAccessTimeUpdateCount{0}; friend class BuildStoreGcReferenceChecker; friend class BuildStoreGcReferencePruner; diff --git a/src/zenstore/include/zenstore/cache/cachedisklayer.h b/src/zenstore/include/zenstore/cache/cachedisklayer.h index 3d684587d..393e289ac 100644 --- a/src/zenstore/include/zenstore/cache/cachedisklayer.h +++ b/src/zenstore/include/zenstore/cache/cachedisklayer.h @@ -153,14 +153,14 @@ public: struct BucketStats { - uint64_t DiskSize; - uint64_t MemorySize; - uint64_t DiskHitCount; - uint64_t DiskMissCount; - uint64_t DiskWriteCount; - uint64_t MemoryHitCount; - uint64_t MemoryMissCount; - uint64_t MemoryWriteCount; + uint64_t DiskSize = 0; + uint64_t MemorySize = 0; + uint64_t DiskHitCount = 0; + uint64_t DiskMissCount = 0; + uint64_t DiskWriteCount = 0; + uint64_t MemoryHitCount = 0; + uint64_t MemoryMissCount = 0; + uint64_t MemoryWriteCount = 0; metrics::RequestStatsSnapshot PutOps; metrics::RequestStatsSnapshot GetOps; }; @@ -174,8 +174,8 @@ public: struct DiskStats { std::vector BucketStats; - uint64_t DiskSize; - uint64_t MemorySize; + uint64_t DiskSize = 0; + uint64_t MemorySize = 0; }; struct PutResult @@ -395,12 +395,12 @@ public: TCasLogFile m_SlogFile; uint64_t m_LogFlushPosition = 0; - std::atomic m_DiskHitCount; - std::atomic m_DiskMissCount; - std::atomic m_DiskWriteCount; - std::atomic m_MemoryHitCount; - std::atomic m_MemoryMissCount; - std::atomic m_MemoryWriteCount; + std::atomic m_DiskHitCount{0}; + std::atomic m_DiskMissCount{0}; + std::atomic m_DiskWriteCount{0}; + std::atomic m_MemoryHitCount{0}; + 
std::atomic m_MemoryMissCount{0}; + std::atomic m_MemoryWriteCount{0}; metrics::RequestStats m_PutOps; metrics::RequestStats m_GetOps; @@ -540,7 +540,7 @@ private: Configuration m_Configuration; std::atomic_uint64_t m_TotalMemCachedSize{}; std::atomic_bool m_IsMemCacheTrimming = false; - std::atomic m_NextAllowedTrimTick; + std::atomic m_NextAllowedTrimTick{}; mutable RwLock m_Lock; BucketMap_t m_Buckets; std::vector> m_DroppedBuckets; diff --git a/src/zenstore/include/zenstore/cache/cacheshared.h b/src/zenstore/include/zenstore/cache/cacheshared.h index 791720589..8e9cd7fd7 100644 --- a/src/zenstore/include/zenstore/cache/cacheshared.h +++ b/src/zenstore/include/zenstore/cache/cacheshared.h @@ -40,12 +40,12 @@ struct CacheValueDetails { struct ValueDetails { - uint64_t Size; - uint64_t RawSize; + uint64_t Size = 0; + uint64_t RawSize = 0; IoHash RawHash; GcClock::Tick LastAccess{}; std::vector Attachments; - ZenContentType ContentType; + ZenContentType ContentType = ZenContentType::kBinary; }; struct BucketDetails diff --git a/src/zenstore/include/zenstore/cache/structuredcachestore.h b/src/zenstore/include/zenstore/cache/structuredcachestore.h index 5a0a8b069..3722a0d31 100644 --- a/src/zenstore/include/zenstore/cache/structuredcachestore.h +++ b/src/zenstore/include/zenstore/cache/structuredcachestore.h @@ -70,9 +70,9 @@ public: struct NamespaceStats { - uint64_t HitCount; - uint64_t MissCount; - uint64_t WriteCount; + uint64_t HitCount = 0; + uint64_t MissCount = 0; + uint64_t WriteCount = 0; metrics::RequestStatsSnapshot PutOps; metrics::RequestStatsSnapshot GetOps; ZenCacheDiskLayer::DiskStats DiskStats; @@ -342,11 +342,11 @@ private: void LogWorker(); RwLock m_LogQueueLock; std::vector m_LogQueue; - std::atomic_bool m_ExitLogging; + std::atomic_bool m_ExitLogging{false}; Event m_LogEvent; std::thread m_AsyncLoggingThread; - std::atomic_bool m_WriteLogEnabled; - std::atomic_bool m_AccessLogEnabled; + std::atomic_bool m_WriteLogEnabled{false}; + 
std::atomic_bool m_AccessLogEnabled{false}; friend class CacheStoreReferenceChecker; }; diff --git a/src/zenstore/include/zenstore/caslog.h b/src/zenstore/include/zenstore/caslog.h index f3dd32fb1..7967d9dae 100644 --- a/src/zenstore/include/zenstore/caslog.h +++ b/src/zenstore/include/zenstore/caslog.h @@ -20,8 +20,8 @@ public: kTruncate }; - static bool IsValid(std::filesystem::path FileName, size_t RecordSize); - void Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode); + static bool IsValid(const std::filesystem::path& FileName, size_t RecordSize); + void Open(const std::filesystem::path& FileName, size_t RecordSize, Mode Mode); void Append(const void* DataPointer, uint64_t DataSize); void Replay(std::function&& Handler, uint64_t SkipEntryCount); void Flush(); @@ -48,7 +48,7 @@ private: static_assert(sizeof(FileHeader) == 64); private: - void Open(std::filesystem::path FileName, size_t RecordSize, BasicFile::Mode Mode); + void Open(const std::filesystem::path& FileName, size_t RecordSize, BasicFile::Mode Mode); BasicFile m_File; FileHeader m_Header; @@ -60,8 +60,8 @@ template class TCasLogFile : public CasLogFile { public: - static bool IsValid(std::filesystem::path FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); } - void Open(std::filesystem::path FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); } + static bool IsValid(const std::filesystem::path& FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); } + void Open(const std::filesystem::path& FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); } // This should be called before the Replay() is called to do some basic sanity checking bool Initialize() { return true; } diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h index 794f50d96..67cf852f9 100644 --- a/src/zenstore/include/zenstore/gc.h +++ b/src/zenstore/include/zenstore/gc.h @@ -443,8 +443,8 @@ struct GcSchedulerState uint64_t DiskFree = 0; 
GcClock::TimePoint LastFullGcTime{}; GcClock::TimePoint LastLightweightGcTime{}; - std::chrono::seconds RemainingTimeUntilLightweightGc; - std::chrono::seconds RemainingTimeUntilFullGc; + std::chrono::seconds RemainingTimeUntilLightweightGc{}; + std::chrono::seconds RemainingTimeUntilFullGc{}; uint64_t RemainingSpaceUntilFullGC = 0; std::chrono::milliseconds LastFullGcDuration{}; @@ -562,7 +562,7 @@ private: GcClock::TimePoint m_LastGcExpireTime{}; IoHash m_LastFullAttachmentRangeMin = IoHash::Zero; IoHash m_LastFullAttachmentRangeMax = IoHash::Max; - uint8_t m_AttachmentPassIndex; + uint8_t m_AttachmentPassIndex = 0; std::chrono::milliseconds m_LastFullGcDuration{}; GcStorageSize m_LastFullGCDiff; diff --git a/src/zenstore/include/zenstore/projectstore.h b/src/zenstore/include/zenstore/projectstore.h index 33ef996db..6f49cd024 100644 --- a/src/zenstore/include/zenstore/projectstore.h +++ b/src/zenstore/include/zenstore/projectstore.h @@ -67,8 +67,8 @@ public: struct OplogEntryAddress { - uint32_t Offset; // note: Multiple of m_OpsAlign! - uint32_t Size; + uint32_t Offset = 0; // note: Multiple of m_OpsAlign! 
+ uint32_t Size = 0; }; struct OplogEntry @@ -80,11 +80,7 @@ public: uint32_t Reserved; inline bool IsTombstone() const { return OpCoreAddress.Offset == 0 && OpCoreAddress.Size == 0 && OpLsn.Number; } - inline void MakeTombstone() - { - OpLsn = {}; - OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0; - } + inline void MakeTombstone() { OpCoreAddress.Offset = OpCoreAddress.Size = OpCoreHash = Reserved = 0; } }; static_assert(IsPow2(sizeof(OplogEntry))); diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp index 217336eec..3f705d12c 100644 --- a/src/zenstore/projectstore.cpp +++ b/src/zenstore/projectstore.cpp @@ -1488,7 +1488,7 @@ ProjectStore::Oplog::Read() else { std::vector OpLogEntries; - uint64_t InvalidEntries; + uint64_t InvalidEntries = 0; m_Storage->ReadOplogEntriesFromLog(OpLogEntries, InvalidEntries, m_LogFlushPosition); for (const OplogEntry& OpEntry : OpLogEntries) { @@ -1750,8 +1750,8 @@ ProjectStore::Oplog::Validate(const std::filesystem::path& ProjectRootDir, } }; - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -2373,7 +2373,7 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P else if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end()) { CidChunkIndexes.push_back(ChunkIndex); - CidChunkHashes.push_back(ChunkIt->second); + CidChunkHashes.push_back(MetaIt->second); } else if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end()) { @@ -2384,8 +2384,8 @@ ProjectStore::Oplog::IterateChunks(const std::filesystem::path& P } if (OptionalWorkerPool) { - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -3817,7 +3817,7 @@ ProjectStore::Project::OpenOplog(std::string_view 
OplogId, bool AllowCompact, bo std::filesystem::path DeletePath; if (!RemoveOplog(OplogId, DeletePath)) { - ZEN_WARN("Failed to clean up deleted oplog {}/{}", Identifier, OplogId, OplogBasePath); + ZEN_WARN("Failed to clean up deleted oplog {}/{} at '{}'", Identifier, OplogId, OplogBasePath); } ReOpen = true; @@ -4053,8 +4053,8 @@ ProjectStore::Project::Scrub(ScrubContext& Ctx) RwLock::SharedLockScope _(m_ProjectLock); - std::atomic Abort; - std::atomic Pause; + std::atomic Abort{false}; + std::atomic Pause{false}; ParallelWork Work(Abort, Pause, WorkerThreadPool::EMode::DisableBacklog); try @@ -4433,8 +4433,8 @@ ProjectStore::Flush() } WorkerThreadPool& WorkerPool = GetSmallWorkerPool(EWorkloadType::Burst); - std::atomic AbortFlag; - std::atomic PauseFlag; + std::atomic AbortFlag{false}; + std::atomic PauseFlag{false}; ParallelWork Work(AbortFlag, PauseFlag, WorkerThreadPool::EMode::DisableBacklog); try { @@ -4974,7 +4974,7 @@ ProjectStore::GetProjectChunkInfos(LoggerRef InLog, Project& Project, Oplog& Opl } if (WantsRawSizeField) { - ZEN_ASSERT_SLOW(Sizes[Index] == (uint64_t)-1); + ZEN_ASSERT_SLOW(RawSizes[Index] == (uint64_t)-1); RawSizes[Index] = Payload.GetSize(); } } @@ -5762,7 +5762,7 @@ public: } } - for (auto ProjectIt : m_ProjectStore.m_Projects) + for (const auto& ProjectIt : m_ProjectStore.m_Projects) { Ref Project = ProjectIt.second; std::vector OplogsToCompact = Project->GetOplogsToCompact(); diff --git a/src/zenstore/workspaces.cpp b/src/zenstore/workspaces.cpp index df3cd31ef..ad21bbc68 100644 --- a/src/zenstore/workspaces.cpp +++ b/src/zenstore/workspaces.cpp @@ -383,7 +383,7 @@ Workspace::GetShares() const { std::vector> Shares; Shares.reserve(m_Shares.size()); - for (auto It : m_Shares) + for (const auto& It : m_Shares) { Shares.push_back(It.second); } @@ -435,7 +435,7 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId) Workspace = FindWorkspace(Lock, WorkspaceId); if (Workspace) { - for (auto Share : Workspace->GetShares()) + for 
(const auto& Share : Workspace->GetShares()) { DeletedShares.insert(Share->GetConfig().Id); } @@ -482,6 +482,12 @@ Workspaces::RefreshWorkspaceShares(const Oid& WorkspaceId) m_ShareAliases.erase(Share->GetConfig().Alias); } Workspace->SetShare(Configuration.Id, std::move(NewShare)); + if (!Configuration.Alias.empty()) + { + m_ShareAliases.insert_or_assign( + Configuration.Alias, + ShareAlias{.WorkspaceId = WorkspaceId, .ShareId = Configuration.Id}); + } } } else @@ -602,7 +608,7 @@ Workspaces::GetWorkspaceShareChunks(const Oid& WorkspaceId, { RequestedOffset = Size; } - if ((RequestedOffset + RequestedSize) > Size) + if (RequestedSize > Size - RequestedOffset) { RequestedSize = Size - RequestedOffset; } @@ -649,7 +655,7 @@ Workspaces::GetWorkspaces() const { std::vector Workspaces; RwLock::SharedLockScope Lock(m_Lock); - for (auto It : m_Workspaces) + for (const auto& It : m_Workspaces) { Workspaces.push_back(It.first); } @@ -679,7 +685,7 @@ Workspaces::GetWorkspaceShares(const Oid& WorkspaceId) const if (Workspace) { std::vector Shares; - for (auto Share : Workspace->GetShares()) + for (const auto& Share : Workspace->GetShares()) { Shares.push_back(Share->GetConfig().Id); } -- cgit v1.2.3 From 19a117889c2db6b817af9458c04c04f324162e75 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 9 Mar 2026 10:50:47 +0100 Subject: Eliminate spdlog dependency (#773) Removes the vendored spdlog library (~12,000 lines) and replaces it with a purpose-built logging system in zencore (~1,800 lines). The new implementation provides the same functionality with fewer abstractions, no shared_ptr overhead, and full control over the logging pipeline. 
### What changed **New logging core in zencore/logging/:** - LogMessage, Formatter, Sink, Logger, Registry - core abstractions matching spdlog's model but simplified - AnsiColorStdoutSink - ANSI color console output (replaces spdlog stdout_color_sink) - MsvcSink - OutputDebugString on Windows (replaces spdlog msvc_sink) - AsyncSink - async logging via BlockingQueue worker thread (replaces spdlog async_logger) - NullSink, MessageOnlyFormatter - utility types - Thread-safe timestamp caching in formatters using RwLock **Moved to zenutil/logging/:** - FullFormatter - full log formatting with timestamp, logger name, level, source location, multiline alignment - JsonFormatter - structured JSON log output - RotatingFileSink - rotating file sink with atomic size tracking **API changes:** - Log levels are now an enum (LogLevel) instead of int, eliminating the zen::logging::level namespace - LoggerRef no longer wraps shared_ptr - it holds a raw pointer with the registry owning lifetime - Logger error handler is wired through Registry and propagated to all loggers on registration - Logger::Log() now populates ThreadId on every message **Cleanup:** - Deleted thirdparty/spdlog/ entirely (110+ files) - Deleted full_test_formatter (was ~80% duplicate of FullFormatter) - Renamed snake_case classes to PascalCase (full_formatter -> FullFormatter, json_formatter -> JsonFormatter, sentry_sink -> SentrySink) - Removed spdlog from xmake dependency graph ### Build / test impact - zencore no longer depends on spdlog - zenutil and zenvfs xmake.lua updated to drop spdlog dep - zentelemetry xmake.lua updated to drop spdlog dep - All existing tests pass, no test changes required beyond formatter class renames --- src/zen/cmds/builds_cmd.cpp | 8 +- src/zen/cmds/wipe_cmd.cpp | 8 +- src/zen/progressbar.cpp | 12 +- src/zen/zen.cpp | 1 - src/zencore/include/zencore/blockingqueue.h | 2 + src/zencore/include/zencore/logbase.h | 113 ++++--- src/zencore/include/zencore/logging.h | 214 ++++++------- 
.../include/zencore/logging/ansicolorsink.h | 26 ++ src/zencore/include/zencore/logging/asyncsink.h | 30 ++ src/zencore/include/zencore/logging/formatter.h | 20 ++ src/zencore/include/zencore/logging/helpers.h | 122 ++++++++ src/zencore/include/zencore/logging/logger.h | 63 ++++ src/zencore/include/zencore/logging/logmsg.h | 66 +++++ src/zencore/include/zencore/logging/memorybuffer.h | 11 + .../include/zencore/logging/messageonlyformatter.h | 22 ++ src/zencore/include/zencore/logging/msvcsink.h | 30 ++ src/zencore/include/zencore/logging/nullsink.h | 17 ++ src/zencore/include/zencore/logging/registry.h | 70 +++++ src/zencore/include/zencore/logging/sink.h | 34 +++ src/zencore/include/zencore/logging/tracesink.h | 23 ++ src/zencore/include/zencore/sentryintegration.h | 8 +- src/zencore/logging.cpp | 328 +++++++++----------- src/zencore/logging/ansicolorsink.cpp | 178 +++++++++++ src/zencore/logging/asyncsink.cpp | 212 +++++++++++++ src/zencore/logging/logger.cpp | 142 +++++++++ src/zencore/logging/msvcsink.cpp | 80 +++++ src/zencore/logging/registry.cpp | 330 +++++++++++++++++++++ src/zencore/logging/tracesink.cpp | 88 ++++++ src/zencore/sentryintegration.cpp | 128 ++++---- src/zencore/testing.cpp | 4 +- src/zencore/xmake.lua | 1 - src/zencore/zencore.cpp | 2 +- src/zenhttp/servers/httpasio.cpp | 4 +- src/zenhttp/servers/httpplugin.cpp | 4 +- src/zenhttp/transports/dlltransport.cpp | 38 ++- .../include/zenremotestore/operationlogoutput.h | 24 +- src/zenremotestore/operationlogoutput.cpp | 14 +- .../projectstore/remoteprojectstore.cpp | 5 +- src/zenserver-test/logging-tests.cpp | 2 +- src/zenserver-test/zenserver-test.cpp | 8 +- src/zenserver/diag/diagsvcs.cpp | 6 +- src/zenserver/diag/logging.cpp | 51 ++-- src/zenserver/diag/otlphttp.cpp | 4 +- src/zenserver/diag/otlphttp.h | 15 +- src/zenserver/main.cpp | 2 +- src/zenserver/storage/admin/admin.cpp | 6 +- src/zenstore/projectstore.cpp | 2 +- .../include/zentelemetry/otlpencoder.h | 8 +- 
src/zentelemetry/otlpencoder.cpp | 44 +-- src/zentelemetry/xmake.lua | 2 +- src/zenutil/config/commandlineoptions.cpp | 1 + src/zenutil/config/loggingconfig.cpp | 22 +- src/zenutil/include/zenutil/config/loggingconfig.h | 2 +- src/zenutil/include/zenutil/logging.h | 11 +- .../include/zenutil/logging/fullformatter.h | 89 +++--- .../include/zenutil/logging/jsonformatter.h | 168 +++++------ .../include/zenutil/logging/rotatingfilesink.h | 89 +++--- .../include/zenutil/logging/testformatter.h | 160 ---------- src/zenutil/logging.cpp | 144 +++++---- src/zenutil/xmake.lua | 2 +- src/zenvfs/xmake.lua | 2 +- 61 files changed, 2297 insertions(+), 1025 deletions(-) create mode 100644 src/zencore/include/zencore/logging/ansicolorsink.h create mode 100644 src/zencore/include/zencore/logging/asyncsink.h create mode 100644 src/zencore/include/zencore/logging/formatter.h create mode 100644 src/zencore/include/zencore/logging/helpers.h create mode 100644 src/zencore/include/zencore/logging/logger.h create mode 100644 src/zencore/include/zencore/logging/logmsg.h create mode 100644 src/zencore/include/zencore/logging/memorybuffer.h create mode 100644 src/zencore/include/zencore/logging/messageonlyformatter.h create mode 100644 src/zencore/include/zencore/logging/msvcsink.h create mode 100644 src/zencore/include/zencore/logging/nullsink.h create mode 100644 src/zencore/include/zencore/logging/registry.h create mode 100644 src/zencore/include/zencore/logging/sink.h create mode 100644 src/zencore/include/zencore/logging/tracesink.h create mode 100644 src/zencore/logging/ansicolorsink.cpp create mode 100644 src/zencore/logging/asyncsink.cpp create mode 100644 src/zencore/logging/logger.cpp create mode 100644 src/zencore/logging/msvcsink.cpp create mode 100644 src/zencore/logging/registry.cpp create mode 100644 src/zencore/logging/tracesink.cpp delete mode 100644 src/zenutil/include/zenutil/logging/testformatter.h (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp 
b/src/zen/cmds/builds_cmd.cpp index 0722e9714..e5cbafbea 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -269,10 +269,10 @@ namespace builds_impl { static ProgressBar::Mode ProgressMode = ProgressBar::Mode::Pretty; #undef ZEN_CONSOLE_VERBOSE -#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ - if (IsVerbose) \ - { \ - ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \ +#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ + if (IsVerbose) \ + { \ + ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \ } const std::string DefaultAccessTokenEnvVariableName( diff --git a/src/zen/cmds/wipe_cmd.cpp b/src/zen/cmds/wipe_cmd.cpp index fd9e28a80..10f5ad8e1 100644 --- a/src/zen/cmds/wipe_cmd.cpp +++ b/src/zen/cmds/wipe_cmd.cpp @@ -50,10 +50,10 @@ namespace wipe_impl { } #undef ZEN_CONSOLE_VERBOSE -#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) \ - if (IsVerbose) \ - { \ - ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__); \ +#define ZEN_CONSOLE_VERBOSE(fmtstr, ...) 
\ + if (IsVerbose) \ + { \ + ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__); \ } static void SignalCallbackHandler(int SigNum) diff --git a/src/zen/progressbar.cpp b/src/zen/progressbar.cpp index 9467ed60d..b758c061b 100644 --- a/src/zen/progressbar.cpp +++ b/src/zen/progressbar.cpp @@ -390,19 +390,19 @@ class ConsoleOpLogOutput : public OperationLogOutput { public: ConsoleOpLogOutput(zen::ProgressBar::Mode InMode) : m_Mode(InMode) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override { - logging::EmitConsoleLogMessage(LogLevel, Format, Args); + logging::EmitConsoleLogMessage(Point, Args); } - virtual void SetLogOperationName(std::string_view Name) { zen::ProgressBar::SetLogOperationName(m_Mode, Name); } - virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) + virtual void SetLogOperationName(std::string_view Name) override { zen::ProgressBar::SetLogOperationName(m_Mode, Name); } + virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { zen::ProgressBar::SetLogOperationProgress(m_Mode, StepIndex, StepCount); } - virtual uint32_t GetProgressUpdateDelayMS() { return GetUpdateDelayMS(m_Mode); } + virtual uint32_t GetProgressUpdateDelayMS() override { return GetUpdateDelayMS(m_Mode); } - virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); } + virtual ProgressBar* CreateProgressBar(std::string_view InSubTask) override { return new ConsoleOpLogProgressBar(m_Mode, InSubTask); } private: zen::ProgressBar::Mode m_Mode; diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index ba8a76bc3..7f7afa322 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -689,7 +689,6 @@ main(int argc, char** argv) const LoggingOptions LogOptions = {.IsDebug = GlobalOptions.IsDebug, .IsVerbose = GlobalOptions.IsVerbose, .IsTest = 
false, - .AllowAsync = false, .NoConsoleOutput = GlobalOptions.LoggingConfig.NoConsoleOutput, .QuietConsole = GlobalOptions.LoggingConfig.QuietConsole, .AbsLogFile = GlobalOptions.LoggingConfig.AbsLogFile, diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h index e91fdc659..b6c93e937 100644 --- a/src/zencore/include/zencore/blockingqueue.h +++ b/src/zencore/include/zencore/blockingqueue.h @@ -2,6 +2,8 @@ #pragma once +#include // For ZEN_ASSERT + #include #include #include diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h index 00af68b0a..ece17a85e 100644 --- a/src/zencore/include/zencore/logbase.h +++ b/src/zencore/include/zencore/logbase.h @@ -4,96 +4,85 @@ #include -#define ZEN_LOG_LEVEL_TRACE 0 -#define ZEN_LOG_LEVEL_DEBUG 1 -#define ZEN_LOG_LEVEL_INFO 2 -#define ZEN_LOG_LEVEL_WARN 3 -#define ZEN_LOG_LEVEL_ERROR 4 -#define ZEN_LOG_LEVEL_CRITICAL 5 -#define ZEN_LOG_LEVEL_OFF 6 - -#define ZEN_LEVEL_NAME_TRACE std::string_view("trace", 5) -#define ZEN_LEVEL_NAME_DEBUG std::string_view("debug", 5) -#define ZEN_LEVEL_NAME_INFO std::string_view("info", 4) -#define ZEN_LEVEL_NAME_WARNING std::string_view("warning", 7) -#define ZEN_LEVEL_NAME_ERROR std::string_view("error", 5) -#define ZEN_LEVEL_NAME_CRITICAL std::string_view("critical", 8) -#define ZEN_LEVEL_NAME_OFF std::string_view("off", 3) - -namespace zen::logging::level { +namespace zen::logging { enum LogLevel : int { - Trace = ZEN_LOG_LEVEL_TRACE, - Debug = ZEN_LOG_LEVEL_DEBUG, - Info = ZEN_LOG_LEVEL_INFO, - Warn = ZEN_LOG_LEVEL_WARN, - Err = ZEN_LOG_LEVEL_ERROR, - Critical = ZEN_LOG_LEVEL_CRITICAL, - Off = ZEN_LOG_LEVEL_OFF, + Trace, + Debug, + Info, + Warn, + Err, + Critical, + Off, LogLevelCount }; LogLevel ParseLogLevelString(std::string_view String); std::string_view ToStringView(LogLevel Level); -} // namespace zen::logging::level - -namespace zen::logging { - -void SetLogLevel(level::LogLevel NewLogLevel); 
-level::LogLevel GetLogLevel(); +void SetLogLevel(LogLevel NewLogLevel); +LogLevel GetLogLevel(); -} // namespace zen::logging +struct SourceLocation +{ + constexpr SourceLocation() = default; + constexpr SourceLocation(const char* InFilename, int InLine) : Filename(InFilename), Line(InLine) {} -namespace spdlog { -class logger; -} + constexpr operator bool() const noexcept { return Line != 0; } -namespace zen::logging { + const char* Filename{nullptr}; + int Line{0}; +}; -struct SourceLocation +/** This encodes the constant parts of a log message which can be emitted once + * and then referred to by log events. + * + * It's *critical* that instances of this struct are permanent and never + * destroyed, as log messages will refer to them by pointer. The easiest way + * to ensure this is to create them as function-local statics. + * + * The logging macros already do this for you so this should not be something + * you normally would need to worry about. + */ +struct LogPoint { - constexpr SourceLocation() = default; - constexpr SourceLocation(const char* filename_in, int line_in, const char* funcname_in) - : filename(filename_in) - , line(line_in) - , funcname(funcname_in) - { - } - - constexpr bool empty() const noexcept { return line == 0; } - - // IMPORTANT NOTE: the layout of this class must match the spdlog::source_loc class - // since we currently pass a pointer to it into spdlog after casting it to - // spdlog::source_loc* - // - // This is intended to be an intermediate state, before we (probably) transition off - // spdlog entirely - - const char* filename{nullptr}; - int line{0}; - const char* funcname{nullptr}; + SourceLocation Location; + LogLevel Level; + std::string_view FormatString; }; +class Logger; + } // namespace zen::logging namespace zen { +// Lightweight non-owning handle to a Logger. 
Loggers are owned by the Registry +// via Ref; LoggerRef exists as a cheap (raw pointer) handle that can be +// stored in members and passed through logging macros without requiring the +// complete Logger type or incurring refcount overhead on every log call. struct LoggerRef { LoggerRef() = default; - LoggerRef(spdlog::logger& InLogger) : SpdLogger(&InLogger) {} + LoggerRef(logging::Logger& InLogger) : m_Logger(&InLogger) {} + // This exists so that logging macros can pass LoggerRef or LogCategory + // to ZEN_LOG without needing to know which one it is LoggerRef Logger() { return *this; } - bool ShouldLog(int Level) const; - inline operator bool() const { return SpdLogger != nullptr; } + bool ShouldLog(logging::LogLevel Level) const; + inline operator bool() const { return m_Logger != nullptr; } + + inline logging::Logger* operator->() const { return m_Logger; } + inline logging::Logger& operator*() const { return *m_Logger; } - void SetLogLevel(logging::level::LogLevel NewLogLevel); - logging::level::LogLevel GetLogLevel(); + void SetLogLevel(logging::LogLevel NewLogLevel); + logging::LogLevel GetLogLevel(); + void Flush(); - spdlog::logger* SpdLogger = nullptr; +private: + logging::Logger* m_Logger = nullptr; }; } // namespace zen diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h index 74a44d028..4b593c19e 100644 --- a/src/zencore/include/zencore/logging.h +++ b/src/zencore/include/zencore/logging.h @@ -9,16 +9,9 @@ #if ZEN_PLATFORM_WINDOWS # define ZEN_LOG_SECTION(Id) ZEN_DATA_SECTION(Id) -# pragma section(".zlog$f", read) # pragma section(".zlog$l", read) -# pragma section(".zlog$m", read) -# pragma section(".zlog$s", read) -# define ZEN_DECLARE_FUNCTION static constinit ZEN_LOG_SECTION(".zlog$f") char FuncName[] = __FUNCTION__; -# define ZEN_LOG_FUNCNAME FuncName #else # define ZEN_LOG_SECTION(Id) -# define ZEN_DECLARE_FUNCTION -# define ZEN_LOG_FUNCNAME static_cast(__func__) #endif namespace zen::logging { @@ -37,34 
+30,29 @@ LoggerRef ErrorLog(); void SetErrorLog(std::string_view LoggerId); LoggerRef Get(std::string_view Name); -void ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers); +void ConfigureLogLevels(LogLevel Level, std::string_view Loggers); void RefreshLogLevels(); -void RefreshLogLevels(level::LogLevel DefaultLevel); - +void RefreshLogLevels(LogLevel DefaultLevel); + +/** LogCategory allows for the creation of log categories that can be used with + * the logging macros just like a logger reference. The main purpose of this is + * to allow for static log categories in global scope where we can't actually + * go ahead and instantiate a logger immediately because the logging system may + * not be initialized yet. + */ struct LogCategory { - inline LogCategory(std::string_view InCategory) : CategoryName(InCategory) {} - - inline zen::LoggerRef Logger() - { - if (LoggerRef) - { - return LoggerRef; - } + inline LogCategory(std::string_view InCategory) : m_CategoryName(InCategory) {} - LoggerRef = zen::logging::Get(CategoryName); - return LoggerRef; - } + LoggerRef Logger(); - std::string CategoryName; - zen::LoggerRef LoggerRef; +private: + std::string m_CategoryName; + LoggerRef m_LoggerRef; }; -void EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args); -void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Message); -void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Message); -void EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args); -void EmitLogMessage(LoggerRef& Logger, const SourceLocation& Location, int LogLevel, std::string_view Format, fmt::format_args Args); +void EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args); +void EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args); template auto @@ -79,15 +67,14 @@ namespace zen { extern LoggerRef TheDefaultLogger; -inline 
LoggerRef -Log() -{ - if (TheDefaultLogger) - { - return TheDefaultLogger; - } - return zen::logging::ConsoleLog(); -} +/** + * This is the default logger, which any ZEN_INFO et al will get if there's + * no Log() function declared in the current scope. + * + * Typically, classes which want to log to their own channel will declare a Log() + * member function which returns a LoggerRef created at construction time. + */ +LoggerRef Log(); using logging::ConsoleLog; using logging::ErrorLog; @@ -98,12 +85,6 @@ using zen::ConsoleLog; using zen::ErrorLog; using zen::Log; -inline consteval bool -LogIsErrorLevel(int LogLevel) -{ - return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical); -}; - #if ZEN_BUILD_DEBUG # define ZEN_CHECK_FORMAT_STRING(fmtstr, ...) \ while (false) \ @@ -117,75 +98,66 @@ LogIsErrorLevel(int LogLevel) } #endif -#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - ZEN_DECLARE_FUNCTION \ - static constinit ZEN_LOG_SECTION(".zlog$s") char FileName[] = __FILE__; \ - static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \ - static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::SourceLocation Location{FileName, __LINE__, ZEN_LOG_FUNCNAME}; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, \ - Location, \ - InLevel, \ - std::string_view(FormatString, sizeof FormatString - 1), \ - zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG_WITH_LOCATION(InLogger, InLevel, fmtstr, ...) 
\ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") \ + zen::logging::LogPoint LogPoint{zen::logging::SourceLocation{__FILE__, __LINE__}, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); -#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - static constinit ZEN_LOG_SECTION(".zlog$m") char FormatString[] = fmtstr; \ - zen::LoggerRef Logger = InLogger; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - if (Logger.ShouldLog(InLevel)) \ - { \ - zen::logging::EmitLogMessage(Logger, \ - InLevel, \ - std::string_view(FormatString, sizeof FormatString - 1), \ - zen::logging::LogCaptureArguments(__VA_ARGS__)); \ - } \ +#define ZEN_LOG(InLogger, InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + zen::LoggerRef Logger = InLogger; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + if (Logger.ShouldLog(InLevel)) \ + { \ + zen::logging::EmitLogMessage(Logger, LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ + } \ } while (false); #define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ static zen::logging::LogCategory Category { Name } -#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_WARN(Category, fmtstr, ...) 
ZEN_LOG(Category.Logger(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \ - ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) - -#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) - -#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ - do \ - { \ - using namespace std::literals; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - zen::logging::EmitConsoleLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ +#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_ERROR(Category, fmtstr, ...) ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) 
ZEN_LOG_WITH_LOCATION(Category.Logger(), zen::logging::Critical, fmtstr, ##__VA_ARGS__) + +#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_ERROR(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_CRITICAL(fmtstr, ...) ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr, ##__VA_ARGS__) + +#define ZEN_CONSOLE_LOG(InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit ZEN_LOG_SECTION(".zlog$l") zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + zen::logging::EmitConsoleLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ } while (false) -#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_TRACE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Trace, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Err, fmtstr, ##__VA_ARGS__) -#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::level::Critical, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_TRACE(fmtstr, ...) 
ZEN_CONSOLE_LOG(zen::logging::Trace, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_DEBUG(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_INFO(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_WARN(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_ERROR(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Err, fmtstr, ##__VA_ARGS__) +#define ZEN_CONSOLE_CRITICAL(fmtstr, ...) ZEN_CONSOLE_LOG(zen::logging::Critical, fmtstr, ##__VA_ARGS__) ////////////////////////////////////////////////////////////////////////// @@ -240,28 +212,28 @@ std::string_view EmitActivitiesForLogging(StringBuilderBase& OutString); #define ZEN_LOG_SCOPE(...) ScopedLazyActivity $Activity##__LINE__([&](StringBuilderBase& Out) { Out << fmt::format(__VA_ARGS__); }) -#define ZEN_SCOPED_WARN(fmtstr, ...) \ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG(Log(), zen::logging::level::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_WARN(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG(Log(), zen::logging::Warn, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) -#define ZEN_SCOPED_ERROR(fmtstr, ...) \ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_ERROR(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Err, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) -#define ZEN_SCOPED_CRITICAL(fmtstr, ...) 
\ - do \ - { \ - ExtendableStringBuilder<256> ScopeString; \ - const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ - ZEN_LOG_WITH_LOCATION(Log(), zen::logging::level::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \ +#define ZEN_SCOPED_CRITICAL(fmtstr, ...) \ + do \ + { \ + ExtendableStringBuilder<256> ScopeString; \ + const std::string_view Scopes = EmitActivitiesForLogging(ScopeString); \ + ZEN_LOG_WITH_LOCATION(Log(), zen::logging::Critical, fmtstr "{}", ##__VA_ARGS__, Scopes); \ } while (false) ScopedActivityBase* GetThreadActivity(); diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h new file mode 100644 index 000000000..9f859e8d7 --- /dev/null +++ b/src/zencore/include/zencore/logging/ansicolorsink.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include + +namespace zen::logging { + +class AnsiColorStdoutSink : public Sink +{ +public: + AnsiColorStdoutSink(); + ~AnsiColorStdoutSink() override; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr InFormatter) override; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/asyncsink.h b/src/zencore/include/zencore/logging/asyncsink.h new file mode 100644 index 000000000..c49a1ccce --- /dev/null +++ b/src/zencore/include/zencore/logging/asyncsink.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include + +#include +#include + +namespace zen::logging { + +class AsyncSink : public Sink +{ +public: + explicit AsyncSink(std::vector InSinks); + ~AsyncSink() override; + + AsyncSink(const AsyncSink&) = delete; + AsyncSink& operator=(const AsyncSink&) = delete; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr InFormatter) override; + +private: + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h new file mode 100644 index 000000000..11904d71d --- /dev/null +++ b/src/zencore/include/zencore/logging/formatter.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +#include + +namespace zen::logging { + +class Formatter +{ +public: + virtual ~Formatter() = default; + virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0; + virtual std::unique_ptr Clone() const = 0; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h new file mode 100644 index 000000000..ce021e1a5 --- /dev/null +++ b/src/zencore/include/zencore/logging/helpers.h @@ -0,0 +1,122 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include +#include + +#include +#include +#include + +namespace zen::logging::helpers { + +inline void +AppendStringView(std::string_view Sv, MemoryBuffer& Dest) +{ + Dest.append(Sv.data(), Sv.data() + Sv.size()); +} + +inline void +AppendInt(int N, MemoryBuffer& Dest) +{ + fmt::format_int Formatted(N); + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); +} + +inline void +Pad2(int N, MemoryBuffer& Dest) +{ + if (N >= 0 && N < 100) + { + Dest.push_back(static_cast('0' + N / 10)); + Dest.push_back(static_cast('0' + N % 10)); + } + else + { + fmt::format_int Formatted(N); + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); + } +} + +inline void +Pad3(uint32_t N, MemoryBuffer& Dest) +{ + if (N < 1000) + { + Dest.push_back(static_cast('0' + N / 100)); + Dest.push_back(static_cast('0' + (N / 10) % 10)); + Dest.push_back(static_cast('0' + N % 10)); + } + else + { + AppendInt(static_cast(N), Dest); + } +} + +inline void +PadUint(size_t N, unsigned int Width, MemoryBuffer& Dest) +{ + fmt::format_int Formatted(N); + auto StrLen = static_cast(Formatted.size()); + if (Width > StrLen) + { + for (unsigned int Pad = 0; Pad < Width - StrLen; ++Pad) + { + Dest.push_back('0'); + } + } + Dest.append(Formatted.data(), Formatted.data() + Formatted.size()); +} + +template +inline ToDuration +TimeFraction(std::chrono::system_clock::time_point Tp) +{ + using std::chrono::duration_cast; + using std::chrono::seconds; + auto Duration = Tp.time_since_epoch(); + auto Secs = duration_cast(Duration); + return duration_cast(Duration) - duration_cast(Secs); +} + +inline std::tm +SafeLocaltime(std::time_t Time) +{ + std::tm Result{}; +#if defined(_WIN32) + localtime_s(&Result, &Time); +#else + localtime_r(&Time, &Result); +#endif + return Result; +} + +inline const char* +ShortFilename(const char* Path) +{ + if (Path == nullptr) + { + return Path; + } + + const char* It = Path; + const char* LastSep = Path; + while (*It) + { + if (*It == '/' 
|| *It == '\\') + { + LastSep = It + 1; + } + ++It; + } + return LastSep; +} + +inline std::string_view +LevelToShortString(LogLevel Level) +{ + return ToStringView(Level); +} + +} // namespace zen::logging::helpers diff --git a/src/zencore/include/zencore/logging/logger.h b/src/zencore/include/zencore/logging/logger.h new file mode 100644 index 000000000..39d1139a5 --- /dev/null +++ b/src/zencore/include/zencore/logging/logger.h @@ -0,0 +1,63 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include +#include +#include + +namespace zen::logging { + +class ErrorHandler +{ +public: + virtual ~ErrorHandler() = default; + virtual void HandleError(const std::string_view& Msg) = 0; +}; + +class Logger : public RefCounted +{ +public: + Logger(std::string_view InName, SinkPtr InSink); + Logger(std::string_view InName, std::span InSinks); + ~Logger(); + + Logger(const Logger&) = delete; + Logger& operator=(const Logger&) = delete; + + void Log(const LogPoint& Point, fmt::format_args Args); + + bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); } + + void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); } + LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); } + + void SetFlushLevel(LogLevel InLevel) { m_FlushLevel.store(InLevel, std::memory_order_relaxed); } + LogLevel GetFlushLevel() const { return m_FlushLevel.load(std::memory_order_relaxed); } + + std::string_view Name() const; + + void SetSinks(std::vector InSinks); + void AddSink(SinkPtr InSink); + + void SetFormatter(std::unique_ptr InFormatter); + + void SetErrorHandler(ErrorHandler* Handler); + + void Flush(); + + Ref Clone(std::string_view NewName) const; + +private: + void SinkIt(const LogMessage& Msg); + void FlushIfNeeded(LogLevel InLevel); + + struct Impl; + std::unique_ptr m_Impl; + std::atomic m_Level{Info}; + std::atomic m_FlushLevel{Off}; +}; + +} // namespace 
zen::logging diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h new file mode 100644 index 000000000..1d8b6b1b7 --- /dev/null +++ b/src/zencore/include/zencore/logging/logmsg.h @@ -0,0 +1,66 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#include +#include + +namespace zen::logging { + +using LogClock = std::chrono::system_clock; + +struct LogMessage +{ + LogMessage() = default; + + LogMessage(const LogPoint& InPoint, std::string_view InLoggerName, std::string_view InPayload) + : m_LoggerName(InLoggerName) + , m_Level(InPoint.Level) + , m_Time(LogClock::now()) + , m_Source(InPoint.Location) + , m_Payload(InPayload) + , m_Point(&InPoint) + { + } + + std::string_view GetPayload() const { return m_Payload; } + int GetThreadId() const { return m_ThreadId; } + LogClock::time_point GetTime() const { return m_Time; } + LogLevel GetLevel() const { return m_Level; } + std::string_view GetLoggerName() const { return m_LoggerName; } + const SourceLocation& GetSource() const { return m_Source; } + const LogPoint& GetLogPoint() const { return *m_Point; } + + void SetThreadId(int InThreadId) { m_ThreadId = InThreadId; } + void SetPayload(std::string_view InPayload) { m_Payload = InPayload; } + void SetLoggerName(std::string_view InName) { m_LoggerName = InName; } + void SetLevel(LogLevel InLevel) { m_Level = InLevel; } + void SetTime(LogClock::time_point InTime) { m_Time = InTime; } + void SetSource(const SourceLocation& InSource) { m_Source = InSource; } + + mutable size_t ColorRangeStart = 0; + mutable size_t ColorRangeEnd = 0; + +private: + static constexpr LogPoint s_DefaultPoints[LogLevelCount] = { + {{}, Trace, {}}, + {{}, Debug, {}}, + {{}, Info, {}}, + {{}, Warn, {}}, + {{}, Err, {}}, + {{}, Critical, {}}, + {{}, Off, {}}, + }; + + std::string_view m_LoggerName; + LogLevel m_Level = Off; + std::chrono::system_clock::time_point m_Time; + SourceLocation m_Source; + std::string_view 
m_Payload; + const LogPoint* m_Point = &s_DefaultPoints[Off]; + int m_ThreadId = 0; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/memorybuffer.h b/src/zencore/include/zencore/logging/memorybuffer.h new file mode 100644 index 000000000..cd0ff324f --- /dev/null +++ b/src/zencore/include/zencore/logging/memorybuffer.h @@ -0,0 +1,11 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen::logging { + +using MemoryBuffer = fmt::basic_memory_buffer; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/messageonlyformatter.h b/src/zencore/include/zencore/logging/messageonlyformatter.h new file mode 100644 index 000000000..ce25fe9a6 --- /dev/null +++ b/src/zencore/include/zencore/logging/messageonlyformatter.h @@ -0,0 +1,22 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include +#include + +namespace zen::logging { + +class MessageOnlyFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr Clone() const override { return std::make_unique(); } +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/msvcsink.h b/src/zencore/include/zencore/logging/msvcsink.h new file mode 100644 index 000000000..48ea1b915 --- /dev/null +++ b/src/zencore/include/zencore/logging/msvcsink.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include + +#if ZEN_PLATFORM_WINDOWS + +# include + +namespace zen::logging { + +class MsvcSink : public Sink +{ +public: + MsvcSink(); + ~MsvcSink() override = default; + + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr InFormatter) override; + +private: + std::mutex m_Mutex; + std::unique_ptr m_Formatter; +}; + +} // namespace zen::logging + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/include/zencore/logging/nullsink.h b/src/zencore/include/zencore/logging/nullsink.h new file mode 100644 index 000000000..7ac5677c6 --- /dev/null +++ b/src/zencore/include/zencore/logging/nullsink.h @@ -0,0 +1,17 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen::logging { + +class NullSink : public Sink +{ +public: + void Log(const LogMessage& /*Msg*/) override {} + void Flush() override {} + void SetFormatter(std::unique_ptr /*InFormatter*/) override {} +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/registry.h b/src/zencore/include/zencore/logging/registry.h new file mode 100644 index 000000000..a4d3692d2 --- /dev/null +++ b/src/zencore/include/zencore/logging/registry.h @@ -0,0 +1,70 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace zen::logging { + +class Registry +{ +public: + using LogLevels = std::span>; + + static Registry& Instance(); + void Shutdown(); + + void Register(Ref InLogger); + void Drop(const std::string& Name); + Ref Get(const std::string& Name); + + void SetDefaultLogger(Ref InLogger); + Logger* DefaultLoggerRaw(); + Ref DefaultLogger(); + + void SetGlobalLevel(LogLevel Level); + LogLevel GetGlobalLevel() const; + void SetLevels(LogLevels Levels, LogLevel* DefaultLevel); + + void FlushAll(); + void FlushOn(LogLevel Level); + void FlushEvery(std::chrono::seconds Interval); + + // Change formatter on all registered loggers + void SetFormatter(std::unique_ptr InFormatter); + + // Apply function to all registered loggers. Note that the function will + // be called while the registry mutex is held, so it should be fast and + // not attempt to call back into the registry. + template + void ApplyAll(Func&& F) + { + ApplyAllImpl([](void* Ctx, Ref L) { (*static_cast*>(Ctx))(std::move(L)); }, &F); + } + + // Set error handler for all loggers in the registry. The handler is called + // if any logger encounters an error during logging or flushing. + // The caller must ensure the handler outlives the registry. + void SetErrorHandler(ErrorHandler* Handler); + +private: + void ApplyAllImpl(void (*Func)(void*, Ref), void* Context); + + Registry(); + ~Registry(); + + Registry(const Registry&) = delete; + Registry& operator=(const Registry&) = delete; + + struct Impl; + std::unique_ptr m_Impl; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/sink.h b/src/zencore/include/zencore/logging/sink.h new file mode 100644 index 000000000..172176a4e --- /dev/null +++ b/src/zencore/include/zencore/logging/sink.h @@ -0,0 +1,34 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include +#include +#include + +#include +#include + +namespace zen::logging { + +class Sink : public RefCounted +{ +public: + virtual ~Sink() = default; + + virtual void Log(const LogMessage& Msg) = 0; + virtual void Flush() = 0; + + virtual void SetFormatter(std::unique_ptr InFormatter) = 0; + + bool ShouldLog(LogLevel InLevel) const { return InLevel >= m_Level.load(std::memory_order_relaxed); } + void SetLevel(LogLevel InLevel) { m_Level.store(InLevel, std::memory_order_relaxed); } + LogLevel GetLevel() const { return m_Level.load(std::memory_order_relaxed); } + +protected: + std::atomic m_Level{Trace}; +}; + +using SinkPtr = Ref; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/tracesink.h b/src/zencore/include/zencore/logging/tracesink.h new file mode 100644 index 000000000..e63d838b4 --- /dev/null +++ b/src/zencore/include/zencore/logging/tracesink.h @@ -0,0 +1,23 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +namespace zen::logging { + +/** + * A logging sink that forwards log messages to the trace system. + * + * Work-in-progress, not fully implemented. 
+ */ + +class TraceSink : public Sink +{ +public: + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr InFormatter) override; +}; + +} // namespace zen::logging diff --git a/src/zencore/include/zencore/sentryintegration.h b/src/zencore/include/zencore/sentryintegration.h index faf1238b7..a4e33d69e 100644 --- a/src/zencore/include/zencore/sentryintegration.h +++ b/src/zencore/include/zencore/sentryintegration.h @@ -11,11 +11,9 @@ #if ZEN_USE_SENTRY -# include +# include -ZEN_THIRD_PARTY_INCLUDES_START -# include -ZEN_THIRD_PARTY_INCLUDES_END +# include namespace sentry { @@ -53,7 +51,7 @@ private: std::string m_SentryUserName; std::string m_SentryHostName; std::string m_SentryId; - std::shared_ptr m_SentryLogger; + Ref m_SentryLogger; }; } // namespace zen diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp index ebd68de09..099518637 100644 --- a/src/zencore/logging.cpp +++ b/src/zencore/logging.cpp @@ -2,208 +2,128 @@ #include "zencore/logging.h" +#include +#include +#include +#include +#include #include #include #include #include -ZEN_THIRD_PARTY_INCLUDES_START -#include -#include -#include -#include -ZEN_THIRD_PARTY_INCLUDES_END +#include #if ZEN_PLATFORM_WINDOWS # pragma section(".zlog$a", read) -# pragma section(".zlog$f", read) -# pragma section(".zlog$m", read) -# pragma section(".zlog$s", read) +# pragma section(".zlog$l", read) # pragma section(".zlog$z", read) #endif namespace zen { -// We shadow the underlying spdlog default logger, in order to avoid a bunch of overhead LoggerRef TheDefaultLogger; +LoggerRef +Log() +{ + if (TheDefaultLogger) + { + return TheDefaultLogger; + } + return zen::logging::ConsoleLog(); +} + } // namespace zen namespace zen::logging { -using MemoryBuffer_t = fmt::basic_memory_buffer; - -struct LoggingContext -{ - inline LoggingContext(); - inline ~LoggingContext(); - - zen::logging::MemoryBuffer_t MessageBuffer; - - inline std::string_view Message() const { return 
std::string_view(MessageBuffer.data(), MessageBuffer.size()); } -}; +////////////////////////////////////////////////////////////////////////// -LoggingContext::LoggingContext() +LoggerRef +LogCategory::Logger() { -} + // This should be thread safe since zen::logging::Get() will return + // the same logger instance for the same category name. Also the + // LoggerRef is simply a pointer. + if (!m_LoggerRef) + { + m_LoggerRef = zen::logging::Get(m_CategoryName); + } -LoggingContext::~LoggingContext() -{ + return m_LoggerRef; } -////////////////////////////////////////////////////////////////////////// - static inline bool -IsErrorLevel(int LogLevel) +IsErrorLevel(LogLevel InLevel) { - return (LogLevel == zen::logging::level::Err || LogLevel == zen::logging::level::Critical); + return (InLevel == Err || InLevel == Critical); }; -static_assert(sizeof(spdlog::source_loc) == sizeof(SourceLocation)); -static_assert(offsetof(spdlog::source_loc, filename) == offsetof(SourceLocation, filename)); -static_assert(offsetof(spdlog::source_loc, line) == offsetof(SourceLocation, line)); -static_assert(offsetof(spdlog::source_loc, funcname) == offsetof(SourceLocation, funcname)); - void -EmitLogMessage(LoggerRef& Logger, int LogLevel, const std::string_view Message) +EmitLogMessage(LoggerRef& Logger, const LogPoint& Lp, fmt::format_args Args) { ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - Logger.SpdLogger->log(InLevel, Message); - if (IsErrorLevel(LogLevel)) - { - if (LoggerRef ErrLogger = zen::logging::ErrorLog()) - { - ErrLogger.SpdLogger->log(InLevel, Message); - } - } -} -void -EmitLogMessage(LoggerRef& Logger, int LogLevel, std::string_view Format, fmt::format_args Args) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitLogMessage(Logger, LogLevel, LogCtx.Message()); -} + Logger->Log(Lp, Args); 
-void -EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, const std::string_view Message) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::source_loc& Location = *reinterpret_cast(&InLocation); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - Logger.SpdLogger->log(Location, InLevel, Message); - if (IsErrorLevel(LogLevel)) + if (IsErrorLevel(Lp.Level)) { if (LoggerRef ErrLogger = zen::logging::ErrorLog()) { - ErrLogger.SpdLogger->log(Location, InLevel, Message); + ErrLogger->Log(Lp, Args); } } } void -EmitLogMessage(LoggerRef& Logger, const SourceLocation& InLocation, int LogLevel, std::string_view Format, fmt::format_args Args) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitLogMessage(Logger, InLocation, LogLevel, LogCtx.Message()); -} - -void -EmitConsoleLogMessage(int LogLevel, const std::string_view Message) -{ - ZEN_MEMSCOPE(ELLMTag::Logging); - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - ConsoleLog().SpdLogger->log(InLevel, Message); -} - -#define ZEN_COLOR_YELLOW "\033[0;33m" -#define ZEN_COLOR_RED "\033[0;31m" -#define ZEN_BRIGHT_COLOR_RED "\033[1;31m" -#define ZEN_COLOR_RESET "\033[0m" - -void -EmitConsoleLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) +EmitConsoleLogMessage(const LogPoint& Lp, fmt::format_args Args) { ZEN_MEMSCOPE(ELLMTag::Logging); - zen::logging::LoggingContext LogCtx; - - // We are not using a format option for console which include log level since it would interfere with normal console output - - const spdlog::level::level_enum InLevel = (spdlog::level::level_enum)LogLevel; - switch (InLevel) - { - case spdlog::level::level_enum::warn: - fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET); - break; - case spdlog::level::level_enum::err: - 
fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET); - break; - case spdlog::level::level_enum::critical: - fmt::format_to(fmt::appender(LogCtx.MessageBuffer), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET); - break; - default: - break; - } - fmt::vformat_to(fmt::appender(LogCtx.MessageBuffer), Format, Args); - zen::logging::EmitConsoleLogMessage(LogLevel, LogCtx.Message()); + ConsoleLog()->Log(Lp, Args); } } // namespace zen::logging -namespace zen::logging::level { +namespace zen::logging { -spdlog::level::level_enum -to_spdlog_level(LogLevel NewLogLevel) -{ - return static_cast((int)NewLogLevel); -} +constinit std::string_view LevelNames[] = {std::string_view("trace", 5), + std::string_view("debug", 5), + std::string_view("info", 4), + std::string_view("warning", 7), + std::string_view("error", 5), + std::string_view("critical", 8), + std::string_view("off", 3)}; LogLevel -to_logging_level(spdlog::level::level_enum NewLogLevel) -{ - return static_cast((int)NewLogLevel); -} - -constinit std::string_view LevelNames[] = {ZEN_LEVEL_NAME_TRACE, - ZEN_LEVEL_NAME_DEBUG, - ZEN_LEVEL_NAME_INFO, - ZEN_LEVEL_NAME_WARNING, - ZEN_LEVEL_NAME_ERROR, - ZEN_LEVEL_NAME_CRITICAL, - ZEN_LEVEL_NAME_OFF}; - -level::LogLevel ParseLogLevelString(std::string_view Name) { - for (int Level = 0; Level < level::LogLevelCount; ++Level) + for (int Level = 0; Level < LogLevelCount; ++Level) { if (LevelNames[Level] == Name) - return static_cast(Level); + { + return static_cast(Level); + } } if (Name == "warn") { - return level::Warn; + return Warn; } if (Name == "err") { - return level::Err; + return Err; } - return level::Off; + return Off; } std::string_view -ToStringView(level::LogLevel Level) +ToStringView(LogLevel Level) { - if (int(Level) < level::LogLevelCount) + if (int(Level) < LogLevelCount) { return LevelNames[int(Level)]; } @@ -211,17 +131,17 @@ ToStringView(level::LogLevel Level) return "None"; } -} // namespace zen::logging::level +} // 
namespace zen::logging ////////////////////////////////////////////////////////////////////////// namespace zen::logging { RwLock LogLevelsLock; -std::string LogLevels[level::LogLevelCount]; +std::string LogLevels[LogLevelCount]; void -ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) +ConfigureLogLevels(LogLevel Level, std::string_view Loggers) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -230,18 +150,18 @@ ConfigureLogLevels(level::LogLevel Level, std::string_view Loggers) } void -RefreshLogLevels(level::LogLevel* DefaultLevel) +RefreshLogLevels(LogLevel* DefaultLevel) { ZEN_MEMSCOPE(ELLMTag::Logging); - spdlog::details::registry::log_levels Levels; + std::vector> Levels; { RwLock::SharedLockScope _(LogLevelsLock); - for (int i = 0; i < level::LogLevelCount; ++i) + for (int i = 0; i < LogLevelCount; ++i) { - level::LogLevel CurrentLevel{i}; + LogLevel CurrentLevel{i}; std::string_view Spec = LogLevels[i]; @@ -260,24 +180,16 @@ RefreshLogLevels(level::LogLevel* DefaultLevel) Spec = {}; } - Levels[LoggerName] = to_spdlog_level(CurrentLevel); + Levels.emplace_back(std::move(LoggerName), CurrentLevel); } } } - if (DefaultLevel) - { - spdlog::level::level_enum SpdDefaultLevel = to_spdlog_level(*DefaultLevel); - spdlog::details::registry::instance().set_levels(Levels, &SpdDefaultLevel); - } - else - { - spdlog::details::registry::instance().set_levels(Levels, nullptr); - } + Registry::Instance().SetLevels(Levels, DefaultLevel); } void -RefreshLogLevels(level::LogLevel DefaultLevel) +RefreshLogLevels(LogLevel DefaultLevel) { RefreshLogLevels(&DefaultLevel); } @@ -289,15 +201,15 @@ RefreshLogLevels() } void -SetLogLevel(level::LogLevel NewLogLevel) +SetLogLevel(LogLevel NewLogLevel) { - spdlog::set_level(to_spdlog_level(NewLogLevel)); + Registry::Instance().SetGlobalLevel(NewLogLevel); } -level::LogLevel +LogLevel GetLogLevel() { - return level::to_logging_level(spdlog::get_level()); + return Registry::Instance().GetGlobalLevel(); } LoggerRef @@ -312,10 +224,10 
@@ SetDefault(std::string_view NewDefaultLoggerId) { ZEN_MEMSCOPE(ELLMTag::Logging); - auto NewDefaultLogger = spdlog::get(std::string(NewDefaultLoggerId)); + Ref NewDefaultLogger = Registry::Instance().Get(std::string(NewDefaultLoggerId)); ZEN_ASSERT(NewDefaultLogger); - spdlog::set_default_logger(NewDefaultLogger); + Registry::Instance().SetDefaultLogger(NewDefaultLogger); TheDefaultLogger = LoggerRef(*NewDefaultLogger); } @@ -338,11 +250,11 @@ SetErrorLog(std::string_view NewErrorLoggerId) } else { - auto NewErrorLogger = spdlog::get(std::string(NewErrorLoggerId)); + Ref NewErrorLogger = Registry::Instance().Get(std::string(NewErrorLoggerId)); ZEN_ASSERT(NewErrorLogger); - TheErrorLogger = LoggerRef(*NewErrorLogger.get()); + TheErrorLogger = LoggerRef(*NewErrorLogger.Get()); } } @@ -353,39 +265,75 @@ Get(std::string_view Name) { ZEN_MEMSCOPE(ELLMTag::Logging); - std::shared_ptr Logger = spdlog::get(std::string(Name)); + Ref FoundLogger = Registry::Instance().Get(std::string(Name)); - if (!Logger) + if (!FoundLogger) { g_LoggerMutex.WithExclusiveLock([&] { - Logger = spdlog::get(std::string(Name)); + FoundLogger = Registry::Instance().Get(std::string(Name)); - if (!Logger) + if (!FoundLogger) { - Logger = Default().SpdLogger->clone(std::string(Name)); - spdlog::apply_logger_env_levels(Logger); - spdlog::register_logger(Logger); + FoundLogger = Default()->Clone(std::string(Name)); + Registry::Instance().Register(FoundLogger); } }); } - return *Logger; + return *FoundLogger; } -std::once_flag ConsoleInitFlag; -std::shared_ptr ConLogger; +std::once_flag ConsoleInitFlag; +Ref ConLogger; void SuppressConsoleLog() { + ZEN_MEMSCOPE(ELLMTag::Logging); + if (ConLogger) { - spdlog::drop("console"); + Registry::Instance().Drop("console"); ConLogger = {}; } - ConLogger = spdlog::null_logger_mt("console"); + + SinkPtr NullSinkPtr(new NullSink()); + ConLogger = Ref(new Logger("console", std::vector{NullSinkPtr})); + Registry::Instance().Register(ConLogger); } +#define 
ZEN_COLOR_YELLOW "\033[0;33m" +#define ZEN_COLOR_RED "\033[0;31m" +#define ZEN_BRIGHT_COLOR_RED "\033[1;31m" +#define ZEN_COLOR_RESET "\033[0m" + +class ConsoleFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + switch (Msg.GetLevel()) + { + case Warn: + fmt::format_to(fmt::appender(Dest), ZEN_COLOR_YELLOW "Warning: " ZEN_COLOR_RESET); + break; + case Err: + fmt::format_to(fmt::appender(Dest), ZEN_BRIGHT_COLOR_RED "Error: " ZEN_COLOR_RESET); + break; + case Critical: + fmt::format_to(fmt::appender(Dest), ZEN_COLOR_RED "Critical: " ZEN_COLOR_RESET); + break; + default: + break; + } + + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr Clone() const override { return std::make_unique(); } +}; + LoggerRef ConsoleLog() { @@ -394,10 +342,10 @@ ConsoleLog() std::call_once(ConsoleInitFlag, [&] { if (!ConLogger) { - ConLogger = spdlog::stdout_color_mt("console"); - spdlog::apply_logger_env_levels(ConLogger); - - ConLogger->set_pattern("%v"); + SinkPtr ConsoleSink(new AnsiColorStdoutSink()); + ConsoleSink->SetFormatter(std::make_unique()); + ConLogger = Ref(new Logger("console", std::vector{ConsoleSink})); + Registry::Instance().Register(ConLogger); } }); @@ -407,9 +355,11 @@ ConsoleLog() void ResetConsoleLog() { + ZEN_MEMSCOPE(ELLMTag::Logging); + LoggerRef ConLog = ConsoleLog(); - ConLog.SpdLogger->set_pattern("%v"); + ConLog->SetFormatter(std::make_unique()); } void @@ -417,13 +367,15 @@ InitializeLogging() { ZEN_MEMSCOPE(ELLMTag::Logging); - TheDefaultLogger = *spdlog::default_logger_raw(); + TheDefaultLogger = *Registry::Instance().DefaultLoggerRaw(); } void ShutdownLogging() { - spdlog::shutdown(); + ZEN_MEMSCOPE(ELLMTag::Logging); + + Registry::Instance().Shutdown(); TheDefaultLogger = {}; } @@ -457,7 +409,7 @@ EnableVTMode() void FlushLogging() { - spdlog::details::registry::instance().flush_all(); + Registry::Instance().FlushAll(); } } // namespace 
zen::logging @@ -465,21 +417,27 @@ FlushLogging() namespace zen { bool -LoggerRef::ShouldLog(int Level) const +LoggerRef::ShouldLog(logging::LogLevel Level) const { - return SpdLogger->should_log(static_cast(Level)); + return m_Logger->ShouldLog(Level); } void -LoggerRef::SetLogLevel(logging::level::LogLevel NewLogLevel) +LoggerRef::SetLogLevel(logging::LogLevel NewLogLevel) { - SpdLogger->set_level(to_spdlog_level(NewLogLevel)); + m_Logger->SetLevel(NewLogLevel); } -logging::level::LogLevel +logging::LogLevel LoggerRef::GetLogLevel() { - return logging::level::to_logging_level(SpdLogger->level()); + return m_Logger->GetLevel(); +} + +void +LoggerRef::Flush() +{ + m_Logger->Flush(); } thread_local ScopedActivityBase* t_ScopeStack = nullptr; diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp new file mode 100644 index 000000000..9b9959862 --- /dev/null +++ b/src/zencore/logging/ansicolorsink.cpp @@ -0,0 +1,178 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include +#include +#include + +#include +#include + +namespace zen::logging { + +// Default formatter replicating spdlog's %+ pattern: +// [YYYY-MM-DD HH:MM:SS.mmm] [logger_name] [level] message\n +class DefaultConsoleFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + // timestamp + auto Secs = std::chrono::duration_cast(Msg.GetTime().time_since_epoch()); + if (Secs != m_LastLogSecs) + { + m_LastLogSecs = Secs; + m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + } + + Dest.push_back('['); + helpers::AppendInt(m_CachedLocalTm.tm_year + 1900, Dest); + Dest.push_back('-'); + helpers::Pad2(m_CachedLocalTm.tm_mon + 1, Dest); + Dest.push_back('-'); + helpers::Pad2(m_CachedLocalTm.tm_mday, Dest); + Dest.push_back(' '); + helpers::Pad2(m_CachedLocalTm.tm_hour, Dest); + Dest.push_back(':'); + helpers::Pad2(m_CachedLocalTm.tm_min, Dest); + Dest.push_back(':'); + helpers::Pad2(m_CachedLocalTm.tm_sec, Dest); + Dest.push_back('.'); + auto Millis = helpers::TimeFraction(Msg.GetTime()); + helpers::Pad3(static_cast(Millis.count()), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + + // logger name + if (Msg.GetLoggerName().size() > 0) + { + Dest.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + } + + // level (colored range) + Dest.push_back('['); + Msg.ColorRangeStart = Dest.size(); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + Msg.ColorRangeEnd = Dest.size(); + Dest.push_back(']'); + Dest.push_back(' '); + + // message + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr Clone() const override { return std::make_unique(); } + +private: + std::chrono::seconds m_LastLogSecs{0}; + std::tm m_CachedLocalTm{}; +}; + +static constexpr std::string_view s_Reset = "\033[m"; + +static std::string_view +GetColorForLevel(LogLevel InLevel) +{ 
+ using namespace std::string_view_literals; + switch (InLevel) + { + case Trace: + return "\033[37m"sv; // white + case Debug: + return "\033[36m"sv; // cyan + case Info: + return "\033[32m"sv; // green + case Warn: + return "\033[33m\033[1m"sv; // bold yellow + case Err: + return "\033[31m\033[1m"sv; // bold red + case Critical: + return "\033[1m\033[41m"sv; // bold on red background + default: + return s_Reset; + } +} + +struct AnsiColorStdoutSink::Impl +{ + Impl() : m_Formatter(std::make_unique()) {} + + void Log(const LogMessage& Msg) + { + std::lock_guard Lock(m_Mutex); + + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); + + if (Msg.ColorRangeEnd > Msg.ColorRangeStart) + { + // Print pre-color range + fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File); + + // Print color + std::string_view Color = GetColorForLevel(Msg.GetLevel()); + fwrite(Color.data(), 1, Color.size(), m_File); + + // Print colored range + fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File); + + // Reset color + fwrite(s_Reset.data(), 1, s_Reset.size(), m_File); + + // Print remainder + fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File); + } + else + { + fwrite(Formatted.data(), 1, Formatted.size(), m_File); + } + + fflush(m_File); + } + + void Flush() + { + std::lock_guard Lock(m_Mutex); + fflush(m_File); + } + + void SetFormatter(std::unique_ptr InFormatter) + { + std::lock_guard Lock(m_Mutex); + m_Formatter = std::move(InFormatter); + } + +private: + std::mutex m_Mutex; + std::unique_ptr m_Formatter; + FILE* m_File = stdout; +}; + +AnsiColorStdoutSink::AnsiColorStdoutSink() : m_Impl(std::make_unique()) +{ +} + +AnsiColorStdoutSink::~AnsiColorStdoutSink() = default; + +void +AnsiColorStdoutSink::Log(const LogMessage& Msg) +{ + m_Impl->Log(Msg); +} + +void +AnsiColorStdoutSink::Flush() +{ + m_Impl->Flush(); +} + +void +AnsiColorStdoutSink::SetFormatter(std::unique_ptr InFormatter) 
+{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +} // namespace zen::logging diff --git a/src/zencore/logging/asyncsink.cpp b/src/zencore/logging/asyncsink.cpp new file mode 100644 index 000000000..02bf9f3ba --- /dev/null +++ b/src/zencore/logging/asyncsink.cpp @@ -0,0 +1,212 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#include +#include +#include + +#include +#include +#include + +namespace zen::logging { + +struct AsyncLogMessage +{ + enum class Type : uint8_t + { + Log, + Flush, + Shutdown + }; + + Type MsgType = Type::Log; + + // Points to the LogPoint from upstream logging code. LogMessage guarantees + // this is always valid (either a static LogPoint from ZEN_LOG macros or one + // of the per-level default LogPoints). + const LogPoint* Point = nullptr; + + int ThreadId = 0; + std::string OwnedPayload; + std::string OwnedLoggerName; + std::chrono::system_clock::time_point Time; + + std::shared_ptr> FlushPromise; +}; + +struct AsyncSink::Impl +{ + explicit Impl(std::vector InSinks) : m_Sinks(std::move(InSinks)) + { + m_WorkerThread = std::thread([this]() { + zen::SetCurrentThreadName("AsyncLog"); + WorkerLoop(); + }); + } + + ~Impl() + { + AsyncLogMessage ShutdownMsg; + ShutdownMsg.MsgType = AsyncLogMessage::Type::Shutdown; + m_Queue.Enqueue(std::move(ShutdownMsg)); + + if (m_WorkerThread.joinable()) + { + m_WorkerThread.join(); + } + } + + void Log(const LogMessage& Msg) + { + AsyncLogMessage AsyncMsg; + AsyncMsg.OwnedPayload = std::string(Msg.GetPayload()); + AsyncMsg.OwnedLoggerName = std::string(Msg.GetLoggerName()); + AsyncMsg.ThreadId = Msg.GetThreadId(); + AsyncMsg.Time = Msg.GetTime(); + AsyncMsg.Point = &Msg.GetLogPoint(); + AsyncMsg.MsgType = AsyncLogMessage::Type::Log; + + m_Queue.Enqueue(std::move(AsyncMsg)); + } + + void Flush() + { + auto Promise = std::make_shared>(); + auto Future = Promise->get_future(); + + AsyncLogMessage FlushMsg; + FlushMsg.MsgType = AsyncLogMessage::Type::Flush; + FlushMsg.FlushPromise = 
std::move(Promise); + + m_Queue.Enqueue(std::move(FlushMsg)); + + Future.get(); + } + + void SetFormatter(std::unique_ptr InFormatter) + { + for (auto& CurrentSink : m_Sinks) + { + CurrentSink->SetFormatter(InFormatter->Clone()); + } + } + +private: + void ForwardLogToSinks(const AsyncLogMessage& AsyncMsg) + { + LogMessage Reconstructed(*AsyncMsg.Point, AsyncMsg.OwnedLoggerName, AsyncMsg.OwnedPayload); + Reconstructed.SetTime(AsyncMsg.Time); + Reconstructed.SetThreadId(AsyncMsg.ThreadId); + + for (auto& CurrentSink : m_Sinks) + { + if (CurrentSink->ShouldLog(Reconstructed.GetLevel())) + { + try + { + CurrentSink->Log(Reconstructed); + } + catch (const std::exception&) + { + } + } + } + } + + void FlushSinks() + { + for (auto& CurrentSink : m_Sinks) + { + try + { + CurrentSink->Flush(); + } + catch (const std::exception&) + { + } + } + } + + void WorkerLoop() + { + AsyncLogMessage Msg; + while (m_Queue.WaitAndDequeue(Msg)) + { + switch (Msg.MsgType) + { + case AsyncLogMessage::Type::Log: + { + ForwardLogToSinks(Msg); + break; + } + + case AsyncLogMessage::Type::Flush: + { + FlushSinks(); + if (Msg.FlushPromise) + { + Msg.FlushPromise->set_value(); + } + break; + } + + case AsyncLogMessage::Type::Shutdown: + { + m_Queue.CompleteAdding(); + + AsyncLogMessage Remaining; + while (m_Queue.WaitAndDequeue(Remaining)) + { + if (Remaining.MsgType == AsyncLogMessage::Type::Log) + { + ForwardLogToSinks(Remaining); + } + else if (Remaining.MsgType == AsyncLogMessage::Type::Flush) + { + FlushSinks(); + if (Remaining.FlushPromise) + { + Remaining.FlushPromise->set_value(); + } + } + } + + FlushSinks(); + return; + } + } + } + } + + std::vector m_Sinks; + BlockingQueue m_Queue; + std::thread m_WorkerThread; +}; + +AsyncSink::AsyncSink(std::vector InSinks) : m_Impl(std::make_unique(std::move(InSinks))) +{ +} + +AsyncSink::~AsyncSink() = default; + +void +AsyncSink::Log(const LogMessage& Msg) +{ + m_Impl->Log(Msg); +} + +void +AsyncSink::Flush() +{ + m_Impl->Flush(); +} + +void 
+AsyncSink::SetFormatter(std::unique_ptr InFormatter) +{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +} // namespace zen::logging diff --git a/src/zencore/logging/logger.cpp b/src/zencore/logging/logger.cpp new file mode 100644 index 000000000..dd1675bb1 --- /dev/null +++ b/src/zencore/logging/logger.cpp @@ -0,0 +1,142 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include +#include + +#include +#include + +namespace zen::logging { + +struct Logger::Impl +{ + std::string m_Name; + std::vector m_Sinks; + ErrorHandler* m_ErrorHandler = nullptr; +}; + +Logger::Logger(std::string_view InName, SinkPtr InSink) : m_Impl(std::make_unique()) +{ + m_Impl->m_Name = InName; + m_Impl->m_Sinks.push_back(std::move(InSink)); +} + +Logger::Logger(std::string_view InName, std::span InSinks) : m_Impl(std::make_unique()) +{ + m_Impl->m_Name = InName; + m_Impl->m_Sinks.assign(InSinks.begin(), InSinks.end()); +} + +Logger::~Logger() = default; + +void +Logger::Log(const LogPoint& Point, fmt::format_args Args) +{ + if (!ShouldLog(Point.Level)) + { + return; + } + + fmt::basic_memory_buffer Buffer; + fmt::vformat_to(fmt::appender(Buffer), Point.FormatString, Args); + + LogMessage LogMsg(Point, m_Impl->m_Name, std::string_view(Buffer.data(), Buffer.size())); + LogMsg.SetThreadId(GetCurrentThreadId()); + SinkIt(LogMsg); + FlushIfNeeded(Point.Level); +} + +void +Logger::SinkIt(const LogMessage& Msg) +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + if (CurrentSink->ShouldLog(Msg.GetLevel())) + { + try + { + CurrentSink->Log(Msg); + } + catch (const std::exception& Ex) + { + if (m_Impl->m_ErrorHandler) + { + m_Impl->m_ErrorHandler->HandleError(Ex.what()); + } + } + } + } +} + +void +Logger::FlushIfNeeded(LogLevel InLevel) +{ + if (InLevel >= m_FlushLevel.load(std::memory_order_relaxed)) + { + Flush(); + } +} + +void +Logger::Flush() +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + try + { + CurrentSink->Flush(); + } + catch (const std::exception& Ex) + { + if 
(m_Impl->m_ErrorHandler) + { + m_Impl->m_ErrorHandler->HandleError(Ex.what()); + } + } + } +} + +void +Logger::SetSinks(std::vector InSinks) +{ + m_Impl->m_Sinks = std::move(InSinks); +} + +void +Logger::AddSink(SinkPtr InSink) +{ + m_Impl->m_Sinks.push_back(std::move(InSink)); +} + +void +Logger::SetErrorHandler(ErrorHandler* Handler) +{ + m_Impl->m_ErrorHandler = Handler; +} + +void +Logger::SetFormatter(std::unique_ptr InFormatter) +{ + for (auto& CurrentSink : m_Impl->m_Sinks) + { + CurrentSink->SetFormatter(InFormatter->Clone()); + } +} + +std::string_view +Logger::Name() const +{ + return m_Impl->m_Name; +} + +Ref +Logger::Clone(std::string_view NewName) const +{ + Ref Cloned(new Logger(NewName, m_Impl->m_Sinks)); + Cloned->SetLevel(m_Level.load(std::memory_order_relaxed)); + Cloned->SetFlushLevel(m_FlushLevel.load(std::memory_order_relaxed)); + Cloned->SetErrorHandler(m_Impl->m_ErrorHandler); + return Cloned; +} + +} // namespace zen::logging diff --git a/src/zencore/logging/msvcsink.cpp b/src/zencore/logging/msvcsink.cpp new file mode 100644 index 000000000..457a4d6e1 --- /dev/null +++ b/src/zencore/logging/msvcsink.cpp @@ -0,0 +1,80 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include + +#if ZEN_PLATFORM_WINDOWS + +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::logging { + +// Default formatter for MSVC debug output: [level] message\n +// For error/critical messages with source info, prepends file(line): so that +// the message is clickable in the Visual Studio Output window. 
+class DefaultMsvcFormatter : public Formatter +{ +public: + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override + { + const auto& Source = Msg.GetSource(); + if (Msg.GetLevel() >= LogLevel::Err && Source) + { + helpers::AppendStringView(Source.Filename, Dest); + Dest.push_back('('); + helpers::AppendInt(Source.Line, Dest); + Dest.push_back(')'); + Dest.push_back(':'); + Dest.push_back(' '); + } + + Dest.push_back('['); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + Dest.push_back(']'); + Dest.push_back(' '); + helpers::AppendStringView(Msg.GetPayload(), Dest); + Dest.push_back('\n'); + } + + std::unique_ptr Clone() const override { return std::make_unique(); } +}; + +MsvcSink::MsvcSink() : m_Formatter(std::make_unique()) +{ +} + +void +MsvcSink::Log(const LogMessage& Msg) +{ + std::lock_guard Lock(m_Mutex); + + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); + + // Null-terminate for OutputDebugStringA + Formatted.push_back('\0'); + + OutputDebugStringA(Formatted.data()); +} + +void +MsvcSink::Flush() +{ + // Nothing to flush for OutputDebugString +} + +void +MsvcSink::SetFormatter(std::unique_ptr InFormatter) +{ + std::lock_guard Lock(m_Mutex); + m_Formatter = std::move(InFormatter); +} + +} // namespace zen::logging + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zencore/logging/registry.cpp b/src/zencore/logging/registry.cpp new file mode 100644 index 000000000..3ed1fb0df --- /dev/null +++ b/src/zencore/logging/registry.cpp @@ -0,0 +1,330 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace zen::logging { + +struct Registry::Impl +{ + Impl() + { + // Create default logger with a stdout color sink + SinkPtr DefaultSink(new AnsiColorStdoutSink()); + m_DefaultLogger = Ref(new Logger("", DefaultSink)); + m_Loggers[""] = m_DefaultLogger; + } + + ~Impl() { StopPeriodicFlush(); } + + void Register(Ref InLogger) + { + std::lock_guard Lock(m_Mutex); + if (m_ErrorHandler) + { + InLogger->SetErrorHandler(m_ErrorHandler); + } + m_Loggers[std::string(InLogger->Name())] = std::move(InLogger); + } + + void Drop(const std::string& Name) + { + std::lock_guard Lock(m_Mutex); + m_Loggers.erase(Name); + } + + Ref Get(const std::string& Name) + { + std::lock_guard Lock(m_Mutex); + auto It = m_Loggers.find(Name); + if (It != m_Loggers.end()) + { + return It->second; + } + return {}; + } + + void SetDefaultLogger(Ref InLogger) + { + std::lock_guard Lock(m_Mutex); + if (InLogger) + { + m_Loggers[std::string(InLogger->Name())] = InLogger; + } + m_DefaultLogger = std::move(InLogger); + } + + Logger* DefaultLoggerRaw() { return m_DefaultLogger.Get(); } + + Ref DefaultLogger() + { + std::lock_guard Lock(m_Mutex); + return m_DefaultLogger; + } + + void SetGlobalLevel(LogLevel Level) + { + m_GlobalLevel.store(Level, std::memory_order_relaxed); + std::lock_guard Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetLevel(Level); + } + } + + LogLevel GetGlobalLevel() const { return m_GlobalLevel.load(std::memory_order_relaxed); } + + void SetLevels(Registry::LogLevels Levels, LogLevel* DefaultLevel) + { + std::lock_guard Lock(m_Mutex); + + if (DefaultLevel) + { + m_GlobalLevel.store(*DefaultLevel, std::memory_order_relaxed); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetLevel(*DefaultLevel); + } + } + + for (auto& [LoggerName, Level] : Levels) + { + auto It = m_Loggers.find(LoggerName); + if (It != m_Loggers.end()) + { + 
It->second->SetLevel(Level); + } + } + } + + void FlushAll() + { + std::lock_guard Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + try + { + CurLogger->Flush(); + } + catch (const std::exception&) + { + } + } + } + + void FlushOn(LogLevel Level) + { + std::lock_guard Lock(m_Mutex); + m_FlushLevel = Level; + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetFlushLevel(Level); + } + } + + void FlushEvery(std::chrono::seconds Interval) + { + StopPeriodicFlush(); + + m_PeriodicFlushRunning.store(true, std::memory_order_relaxed); + + m_FlushThread = std::thread([this, Interval] { + while (m_PeriodicFlushRunning.load(std::memory_order_relaxed)) + { + { + std::unique_lock Lock(m_PeriodicFlushMutex); + m_PeriodicFlushCv.wait_for(Lock, Interval, [this] { return !m_PeriodicFlushRunning.load(std::memory_order_relaxed); }); + } + + if (m_PeriodicFlushRunning.load(std::memory_order_relaxed)) + { + FlushAll(); + } + } + }); + } + + void SetFormatter(std::unique_ptr InFormatter) + { + std::lock_guard Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetFormatter(InFormatter->Clone()); + } + } + + void ApplyAll(void (*Func)(void*, Ref), void* Context) + { + std::lock_guard Lock(m_Mutex); + for (auto& [Name, CurLogger] : m_Loggers) + { + Func(Context, CurLogger); + } + } + + void SetErrorHandler(ErrorHandler* Handler) + { + std::lock_guard Lock(m_Mutex); + m_ErrorHandler = Handler; + for (auto& [Name, CurLogger] : m_Loggers) + { + CurLogger->SetErrorHandler(Handler); + } + } + + void Shutdown() + { + StopPeriodicFlush(); + FlushAll(); + + std::lock_guard Lock(m_Mutex); + m_Loggers.clear(); + m_DefaultLogger = nullptr; + } + +private: + void StopPeriodicFlush() + { + if (m_FlushThread.joinable()) + { + m_PeriodicFlushRunning.store(false, std::memory_order_relaxed); + { + std::lock_guard Lock(m_PeriodicFlushMutex); + m_PeriodicFlushCv.notify_one(); + } + m_FlushThread.join(); + } + } + + std::mutex m_Mutex; + std::unordered_map> 
m_Loggers; + Ref m_DefaultLogger; + std::atomic m_GlobalLevel{Trace}; + LogLevel m_FlushLevel{Off}; + ErrorHandler* m_ErrorHandler = nullptr; + + // Periodic flush + std::atomic m_PeriodicFlushRunning{false}; + std::mutex m_PeriodicFlushMutex; + std::condition_variable m_PeriodicFlushCv; + std::thread m_FlushThread; +}; + +Registry& +Registry::Instance() +{ + static Registry s_Instance; + return s_Instance; +} + +Registry::Registry() : m_Impl(std::make_unique()) +{ +} + +Registry::~Registry() = default; + +void +Registry::Register(Ref InLogger) +{ + m_Impl->Register(std::move(InLogger)); +} + +void +Registry::Drop(const std::string& Name) +{ + m_Impl->Drop(Name); +} + +Ref +Registry::Get(const std::string& Name) +{ + return m_Impl->Get(Name); +} + +void +Registry::SetDefaultLogger(Ref InLogger) +{ + m_Impl->SetDefaultLogger(std::move(InLogger)); +} + +Logger* +Registry::DefaultLoggerRaw() +{ + return m_Impl->DefaultLoggerRaw(); +} + +Ref +Registry::DefaultLogger() +{ + return m_Impl->DefaultLogger(); +} + +void +Registry::SetGlobalLevel(LogLevel Level) +{ + m_Impl->SetGlobalLevel(Level); +} + +LogLevel +Registry::GetGlobalLevel() const +{ + return m_Impl->GetGlobalLevel(); +} + +void +Registry::SetLevels(LogLevels Levels, LogLevel* DefaultLevel) +{ + m_Impl->SetLevels(Levels, DefaultLevel); +} + +void +Registry::FlushAll() +{ + m_Impl->FlushAll(); +} + +void +Registry::FlushOn(LogLevel Level) +{ + m_Impl->FlushOn(Level); +} + +void +Registry::FlushEvery(std::chrono::seconds Interval) +{ + m_Impl->FlushEvery(Interval); +} + +void +Registry::SetFormatter(std::unique_ptr InFormatter) +{ + m_Impl->SetFormatter(std::move(InFormatter)); +} + +void +Registry::ApplyAllImpl(void (*Func)(void*, Ref), void* Context) +{ + m_Impl->ApplyAll(Func, Context); +} + +void +Registry::SetErrorHandler(ErrorHandler* Handler) +{ + m_Impl->SetErrorHandler(Handler); +} + +void +Registry::Shutdown() +{ + m_Impl->Shutdown(); +} + +} // namespace zen::logging diff --git 
a/src/zencore/logging/tracesink.cpp b/src/zencore/logging/tracesink.cpp new file mode 100644 index 000000000..e3533327b --- /dev/null +++ b/src/zencore/logging/tracesink.cpp @@ -0,0 +1,88 @@ + +// Copyright Epic Games, Inc. All Rights Reserved. + +#include +#include +#include +#include +#include + +namespace zen::logging { + +UE_TRACE_CHANNEL_DEFINE(LogChannel) + +UE_TRACE_EVENT_BEGIN(Logging, LogCategory, NoSync | Important) + UE_TRACE_EVENT_FIELD(const void*, CategoryPointer) + UE_TRACE_EVENT_FIELD(uint8_t, DefaultVerbosity) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, Name) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Logging, LogMessageSpec, NoSync | Important) + UE_TRACE_EVENT_FIELD(const void*, LogPoint) + UE_TRACE_EVENT_FIELD(const void*, CategoryPointer) + UE_TRACE_EVENT_FIELD(int32_t, Line) + UE_TRACE_EVENT_FIELD(uint8_t, Verbosity) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FileName) + UE_TRACE_EVENT_FIELD(UE::Trace::AnsiString, FormatString) +UE_TRACE_EVENT_END() + +UE_TRACE_EVENT_BEGIN(Logging, LogMessage, NoSync) + UE_TRACE_EVENT_FIELD(const void*, LogPoint) + UE_TRACE_EVENT_FIELD(uint64_t, Cycle) + UE_TRACE_EVENT_FIELD(uint8_t[], FormatArgs) +UE_TRACE_EVENT_END() + +void +TraceLogCategory(const logging::Logger* Category, const char* Name, logging::LogLevel DefaultVerbosity) +{ + uint16_t NameLen = uint16_t(strlen(Name)); + UE_TRACE_LOG(Logging, LogCategory, LogChannel, NameLen * sizeof(ANSICHAR)) + << LogCategory.CategoryPointer(Category) << LogCategory.DefaultVerbosity(uint8_t(DefaultVerbosity)) + << LogCategory.Name(Name, NameLen); +} + +void +TraceLogMessageSpec(const void* LogPoint, + const logging::Logger* Category, + logging::LogLevel Verbosity, + const std::string_view File, + int32_t Line, + const std::string_view Format) +{ + uint16_t FileNameLen = uint16_t(File.size()); + uint16_t FormatStringLen = uint16_t(Format.size()); + uint32_t DataSize = (FileNameLen * sizeof(ANSICHAR)) + (FormatStringLen * sizeof(ANSICHAR)); + 
UE_TRACE_LOG(Logging, LogMessageSpec, LogChannel, DataSize) + << LogMessageSpec.LogPoint(LogPoint) << LogMessageSpec.CategoryPointer(Category) << LogMessageSpec.Line(Line) + << LogMessageSpec.Verbosity(uint8_t(Verbosity)) << LogMessageSpec.FileName(File.data(), FileNameLen) + << LogMessageSpec.FormatString(Format.data(), FormatStringLen); +} + +void +TraceLogMessageInternal(const void* LogPoint, int32_t EncodedFormatArgsSize, const uint8_t* EncodedFormatArgs) +{ + UE_TRACE_LOG(Logging, LogMessage, LogChannel) << LogMessage.LogPoint(LogPoint) << LogMessage.Cycle(GetHifreqTimerValue()) + << LogMessage.FormatArgs(EncodedFormatArgs, EncodedFormatArgsSize); +} + +////////////////////////////////////////////////////////////////////////// + +void +TraceSink::Log(const LogMessage& Msg) +{ + ZEN_UNUSED(Msg); +} + +void +TraceSink::Flush() +{ +} + +void +TraceSink::SetFormatter(std::unique_ptr /*InFormatter*/) +{ + // This sink doesn't use a formatter since it just forwards the raw format + // args to the trace system +} + +} // namespace zen::logging diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index bfff114c3..e39b8438d 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -4,29 +4,23 @@ #include #include +#include +#include #include #include #include #include -#if ZEN_PLATFORM_LINUX +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC # include +# include #endif -#if ZEN_PLATFORM_MAC -# include -#endif - -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END - #if ZEN_USE_SENTRY # define SENTRY_BUILD_STATIC 1 ZEN_THIRD_PARTY_INCLUDES_START # include -# include ZEN_THIRD_PARTY_INCLUDES_END namespace sentry { @@ -44,76 +38,58 @@ struct SentryAssertImpl : zen::AssertImpl const zen::CallstackFrames* Callstack) override; }; -class sentry_sink final : public spdlog::sinks::base_sink +static constexpr sentry_level_t MapToSentryLevel[zen::logging::LogLevelCount] = {SENTRY_LEVEL_DEBUG, + 
SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_INFO, + SENTRY_LEVEL_WARNING, + SENTRY_LEVEL_ERROR, + SENTRY_LEVEL_FATAL, + SENTRY_LEVEL_DEBUG}; + +class SentrySink final : public zen::logging::Sink { public: - sentry_sink(); - ~sentry_sink(); + SentrySink() = default; + ~SentrySink() = default; + + void Log(const zen::logging::LogMessage& Msg) override + { + if (Msg.GetLevel() != zen::logging::Err && Msg.GetLevel() != zen::logging::Critical) + { + return; + } + try + { + std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Msg.GetSource().Filename, Msg.GetSource().Line); + sentry_value_t Event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[Msg.GetLevel()], + /* logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(Event, NULL, 0); + sentry_capture_event(Event); + } + catch (const std::exception&) + { + // If our logging with Message formatting fails we do a non-allocating version and just post the payload raw + char TmpBuffer[256]; + size_t MaxCopy = zen::Min(Msg.GetPayload().size(), size_t(255)); + memcpy(TmpBuffer, Msg.GetPayload().data(), MaxCopy); + TmpBuffer[MaxCopy] = '\0'; + sentry_value_t Event = sentry_value_new_message_event( + /* level */ SENTRY_LEVEL_ERROR, + /* logger */ nullptr, + /* message */ TmpBuffer); + sentry_event_value_add_stacktrace(Event, NULL, 0); + sentry_capture_event(Event); + } + } -protected: - void sink_it_(const spdlog::details::log_msg& msg) override; - void flush_() override; + void Flush() override {} + void SetFormatter(std::unique_ptr) override {} }; ////////////////////////////////////////////////////////////////////////// -static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG, - SENTRY_LEVEL_DEBUG, - SENTRY_LEVEL_INFO, - SENTRY_LEVEL_WARNING, - SENTRY_LEVEL_ERROR, - SENTRY_LEVEL_FATAL, - SENTRY_LEVEL_DEBUG}; - -sentry_sink::sentry_sink() -{ -} -sentry_sink::~sentry_sink() -{ -} - -void -sentry_sink::sink_it_(const 
spdlog::details::log_msg& msg) -{ - if (msg.level != spdlog::level::err && msg.level != spdlog::level::critical) - { - return; - } - try - { - auto MaybeNullString = [](const char* Ptr) { return Ptr ? Ptr : ""; }; - std::string Message = fmt::format("{}\n{}({}) [{}]", - msg.payload, - MaybeNullString(msg.source.filename), - msg.source.line, - MaybeNullString(msg.source.funcname)); - sentry_value_t event = sentry_value_new_message_event( - /* level */ MapToSentryLevel[msg.level], - /* logger */ nullptr, - /* message */ Message.c_str()); - sentry_event_value_add_stacktrace(event, NULL, 0); - sentry_capture_event(event); - } - catch (const std::exception&) - { - // If our logging with Message formatting fails we do a non-allocating version and just post the msg.payload raw - char TmpBuffer[256]; - size_t MaxCopy = zen::Min(msg.payload.size(), size_t(255)); - memcpy(TmpBuffer, msg.payload.data(), MaxCopy); - TmpBuffer[MaxCopy] = '\0'; - sentry_value_t event = sentry_value_new_message_event( - /* level */ SENTRY_LEVEL_ERROR, - /* logger */ nullptr, - /* message */ TmpBuffer); - sentry_event_value_add_stacktrace(event, NULL, 0); - sentry_capture_event(event); - } -} -void -sentry_sink::flush_() -{ -} - void SentryAssertImpl::OnAssert(const char* Filename, int LineNumber, @@ -340,7 +316,9 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine sentry_set_user(SentryUserObject); - m_SentryLogger = spdlog::create("sentry"); + logging::SinkPtr SentrySink(new sentry::SentrySink()); + m_SentryLogger = Ref(new logging::Logger("sentry", std::vector{SentrySink})); + logging::Registry::Instance().Register(m_SentryLogger); logging::SetErrorLog("sentry"); m_SentryAssert = std::make_unique(); @@ -354,7 +332,7 @@ SentryIntegration::LogStartupInformation() { // Initialize the sentry-sdk log category at Warn level to reduce startup noise. 
// The level can be overridden via --log-debug=sentry-sdk or --log-info=sentry-sdk - LogSentry.Logger().SetLogLevel(logging::level::Warn); + LogSentry.Logger().SetLogLevel(logging::Warn); if (m_IsInitialized) { diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 0bae139bd..089e376bb 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -143,11 +143,11 @@ TestRunner::ApplyCommandLine(int Argc, char const* const* Argv) { if (Argv[i] == "--debug"sv) { - zen::logging::SetLogLevel(zen::logging::level::Debug); + zen::logging::SetLogLevel(zen::logging::Debug); } else if (Argv[i] == "--verbose"sv) { - zen::logging::SetLogLevel(zen::logging::level::Trace); + zen::logging::SetLogLevel(zen::logging::Trace); } } diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index 2f81b7ec8..171f4c533 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -26,7 +26,6 @@ target('zencore') end add_deps("zenbase") - add_deps("spdlog") add_deps("utfcpp") add_deps("oodle") add_deps("blake3") diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp index d82474705..8c29a8962 100644 --- a/src/zencore/zencore.cpp +++ b/src/zencore/zencore.cpp @@ -147,7 +147,7 @@ AssertImpl::OnAssert(const char* Filename, int LineNumber, const char* FunctionN Message.push_back('\0'); // We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log - ZEN_LOG(Log(), zen::logging::level::Err, "{}", Message.data()); + ZEN_LOG(Log(), zen::logging::Err, "{}", Message.data()); zen::logging::FlushLogging(); } diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 33f182df9..2cf051d14 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -150,7 +150,7 @@ inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // Logger.SetLogLevel(logging::level::Trace); + // Logger.SetLogLevel(logging::Trace); return Logger; } @@ -1256,7 +1256,7 @@ 
HttpServerConnection::HandleRequest() const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server.m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server.m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 850dafdca..021b941bd 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -383,7 +383,7 @@ HttpPluginConnectionHandler::HandleRequest() const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -480,7 +480,7 @@ HttpPluginConnectionHandler::HandleRequest() const std::vector& ResponseBuffers = Response->ResponseBuffers(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers); diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp index 9135d5425..489324aba 100644 --- a/src/zenhttp/transports/dlltransport.cpp +++ b/src/zenhttp/transports/dlltransport.cpp @@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa void DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { - logging::level::LogLevel Level; - // clang-format off switch (PluginLogLevel) { - case LogLevel::Trace: Level = logging::level::Trace; break; - case LogLevel::Debug: Level = logging::level::Debug; break; - case LogLevel::Info: Level = 
logging::level::Info; break; - case LogLevel::Warn: Level = logging::level::Warn; break; - case LogLevel::Err: Level = logging::level::Err; break; - case LogLevel::Critical: Level = logging::level::Critical; break; - default: Level = logging::level::Off; break; + case LogLevel::Trace: + ZEN_TRACE("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Debug: + ZEN_DEBUG("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Info: + ZEN_INFO("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Warn: + ZEN_WARN("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Err: + ZEN_ERROR("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Critical: + ZEN_CRITICAL("[{}] {}", m_PluginName, Message); + return; + + default: + ZEN_UNUSED(Message); + break; } - // clang-format on - ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t diff --git a/src/zenremotestore/include/zenremotestore/operationlogoutput.h b/src/zenremotestore/include/zenremotestore/operationlogoutput.h index 6f10ab156..32b95f50f 100644 --- a/src/zenremotestore/include/zenremotestore/operationlogoutput.h +++ b/src/zenremotestore/include/zenremotestore/operationlogoutput.h @@ -11,7 +11,7 @@ class OperationLogOutput { public: virtual ~OperationLogOutput() {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) = 0; + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) = 0; virtual void SetLogOperationName(std::string_view Name) = 0; virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) = 0; @@ -60,19 +60,17 @@ public: OperationLogOutput* CreateStandardLogOutput(LoggerRef Log); -#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) 
\ - do \ - { \ - using namespace std::literals; \ - ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ - OutputTarget.EmitLogMessage(InLevel, fmtstr, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ +#define ZEN_OPERATION_LOG(OutputTarget, InLevel, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + static constinit zen::logging::LogPoint LogPoint{{}, InLevel, std::string_view(fmtstr)}; \ + ZEN_CHECK_FORMAT_STRING(fmtstr##sv, ##__VA_ARGS__); \ + (OutputTarget).EmitLogMessage(LogPoint, zen::logging::LogCaptureArguments(__VA_ARGS__)); \ } while (false) -#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Info, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Debug, fmtstr, ##__VA_ARGS__) -#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) \ - ZEN_OPERATION_LOG((OutputTarget), zen::logging::level::Warn, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_INFO(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Info, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_DEBUG(OutputTarget, fmtstr, ...) ZEN_OPERATION_LOG(OutputTarget, zen::logging::Debug, fmtstr, ##__VA_ARGS__) +#define ZEN_OPERATION_LOG_WARN(OutputTarget, fmtstr, ...) 
ZEN_OPERATION_LOG(OutputTarget, zen::logging::Warn, fmtstr, ##__VA_ARGS__) } // namespace zen diff --git a/src/zenremotestore/operationlogoutput.cpp b/src/zenremotestore/operationlogoutput.cpp index 7ed93c947..5ed844c9d 100644 --- a/src/zenremotestore/operationlogoutput.cpp +++ b/src/zenremotestore/operationlogoutput.cpp @@ -3,6 +3,7 @@ #include #include +#include ZEN_THIRD_PARTY_INCLUDES_START #include @@ -30,13 +31,11 @@ class StandardLogOutput : public OperationLogOutput { public: StandardLogOutput(LoggerRef& Log) : m_Log(Log) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override { - if (m_Log.ShouldLog(LogLevel)) + if (m_Log.ShouldLog(Point.Level)) { - fmt::basic_memory_buffer MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); - ZEN_LOG(m_Log, LogLevel, "{}", std::string_view(MessageBuffer.data(), MessageBuffer.size())); + m_Log->Log(Point, Args); } } @@ -47,7 +46,7 @@ public: } virtual void SetLogOperationProgress(uint32_t StepIndex, uint32_t StepCount) override { - const size_t PercentDone = StepCount > 0u ? gsl::narrow((100 * StepIndex) / StepCount) : 0u; + [[maybe_unused]] const size_t PercentDone = StepCount > 0u ? gsl::narrow((100 * StepIndex) / StepCount) : 0u; ZEN_OPERATION_LOG_INFO(*this, "{}: {}%", m_LogOperationName, PercentDone); } virtual uint32_t GetProgressUpdateDelayMS() override { return 2000; } @@ -59,13 +58,14 @@ public: private: LoggerRef m_Log; std::string m_LogOperationName; + LoggerRef Log() { return m_Log; } }; void StandardLogOutputProgressBar::UpdateState(const State& NewState, bool DoLinebreak) { ZEN_UNUSED(DoLinebreak); - const size_t PercentDone = + [[maybe_unused]] const size_t PercentDone = NewState.TotalCount > 0u ? 
gsl::narrow((100 * (NewState.TotalCount - NewState.RemainingCount)) / NewState.TotalCount) : 0u; std::string Task = NewState.Task; switch (NewState.Status) diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index 78f6014df..c8c5f201d 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -246,13 +246,12 @@ namespace remotestore_impl { { public: JobContextLogOutput(JobContext* OptionalContext) : m_OptionalContext(OptionalContext) {} - virtual void EmitLogMessage(int LogLevel, std::string_view Format, fmt::format_args Args) override + virtual void EmitLogMessage(const logging::LogPoint& Point, fmt::format_args Args) override { - ZEN_UNUSED(LogLevel); if (m_OptionalContext) { fmt::basic_memory_buffer MessageBuffer; - fmt::vformat_to(fmt::appender(MessageBuffer), Format, Args); + fmt::vformat_to(fmt::appender(MessageBuffer), Point.FormatString, Args); remotestore_impl::ReportMessage(m_OptionalContext, std::string_view(MessageBuffer.data(), MessageBuffer.size())); } } diff --git a/src/zenserver-test/logging-tests.cpp b/src/zenserver-test/logging-tests.cpp index f284f0371..2e530ff92 100644 --- a/src/zenserver-test/logging-tests.cpp +++ b/src/zenserver-test/logging-tests.cpp @@ -71,7 +71,7 @@ TEST_CASE("logging.file.default") // entry written by the default logger's console sink must therefore not appear // in captured stdout. (The "console" named logger — used by ZEN_CONSOLE_* // macros — may still emit plain-text messages without a level marker, so we -// check for the absence of the full_formatter "[info]" prefix rather than the +// check for the absence of the FullFormatter "[info]" prefix rather than the // message text itself.) 
TEST_CASE("logging.console.quiet") { diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index bd36d731f..8d5400294 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -9,6 +9,7 @@ # include # include # include +# include # include # include # include @@ -17,7 +18,7 @@ # include # include # include -# include +# include # include # include @@ -85,8 +86,9 @@ main(int argc, char** argv) zen::logging::InitializeLogging(); - zen::logging::SetLogLevel(zen::logging::level::Debug); - spdlog::set_formatter(std::make_unique("test", std::chrono::system_clock::now())); + zen::logging::SetLogLevel(zen::logging::Debug); + zen::logging::Registry::Instance().SetFormatter( + std::make_unique("test", std::chrono::system_clock::now())); std::filesystem::path ProgramBaseDir = GetRunningExecutablePath().parent_path(); std::filesystem::path TestBaseDir = std::filesystem::current_path() / ".test"; diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp index d8d53b0e3..5fa81ff9f 100644 --- a/src/zenserver/diag/diagsvcs.cpp +++ b/src/zenserver/diag/diagsvcs.cpp @@ -12,9 +12,7 @@ #include #include -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END +#include namespace zen { @@ -64,7 +62,7 @@ HttpHealthService::HttpHealthService() [this](HttpRouterRequest& RoutedReq) { HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); - zen::Log().SpdLogger->flush(); + zen::Log().Flush(); std::filesystem::path Path = [&] { RwLock::SharedLockScope _(m_InfoLock); diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp index 75a8efc09..178c3d3b5 100644 --- a/src/zenserver/diag/logging.cpp +++ b/src/zenserver/diag/logging.cpp @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include #include @@ -14,10 +16,6 @@ #include "otlphttp.h" -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen { void @@ -43,13 +41,12 @@ 
InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) std::filesystem::path HttpLogPath = InOptions.DataDir / "logs" / "http.log"; zen::CreateDirectories(HttpLogPath.parent_path()); - auto HttpSink = std::make_shared(HttpLogPath, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ true); - auto HttpLogger = std::make_shared("http_requests", HttpSink); - spdlog::apply_logger_env_levels(HttpLogger); - spdlog::register_logger(HttpLogger); + logging::SinkPtr HttpSink(new zen::logging::RotatingFileSink(HttpLogPath, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true)); + Ref HttpLogger(new logging::Logger("http_requests", std::vector{HttpSink})); + logging::Registry::Instance().Register(HttpLogger); if (WithCacheService) { @@ -57,33 +54,30 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) std::filesystem::path CacheLogPath = InOptions.DataDir / "logs" / "z$.log"; zen::CreateDirectories(CacheLogPath.parent_path()); - auto CacheSink = std::make_shared(CacheLogPath, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ false); - auto CacheLogger = std::make_shared("z$", CacheSink); - spdlog::apply_logger_env_levels(CacheLogger); - spdlog::register_logger(CacheLogger); + logging::SinkPtr CacheSink(new zen::logging::RotatingFileSink(CacheLogPath, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ false)); + Ref CacheLogger(new logging::Logger("z$", std::vector{CacheSink})); + logging::Registry::Instance().Register(CacheLogger); // Jupiter - only log upstream HTTP traffic to file - auto JupiterLogger = std::make_shared("jupiter", FileSink); - spdlog::apply_logger_env_levels(JupiterLogger); - spdlog::register_logger(JupiterLogger); + Ref JupiterLogger(new logging::Logger("jupiter", std::vector{FileSink})); + logging::Registry::Instance().Register(JupiterLogger); // Zen - only log upstream HTTP 
traffic to file - auto ZenClientLogger = std::make_shared("zenclient", FileSink); - spdlog::apply_logger_env_levels(ZenClientLogger); - spdlog::register_logger(ZenClientLogger); + Ref ZenClientLogger(new logging::Logger("zenclient", std::vector{FileSink})); + logging::Registry::Instance().Register(ZenClientLogger); } #if ZEN_WITH_OTEL if (!InOptions.LoggingConfig.OtelEndpointUri.empty()) { // TODO: Should sanity check that endpoint is reachable? Also, a valid URI? - auto OtelSink = std::make_shared(InOptions.LoggingConfig.OtelEndpointUri); - zen::logging::Default().SpdLogger->sinks().push_back(std::move(OtelSink)); + logging::SinkPtr OtelSink(new zen::logging::OtelHttpProtobufSink(InOptions.LoggingConfig.OtelEndpointUri)); + zen::logging::Default()->AddSink(std::move(OtelSink)); } #endif @@ -91,9 +85,10 @@ InitializeServerLogging(const ZenServerConfig& InOptions, bool WithCacheService) const zen::Oid ServerSessionId = zen::GetSessionId(); - spdlog::apply_all([&](auto Logger) { + static constinit logging::LogPoint SessionIdPoint{{}, logging::Info, "server session id: {}"}; + logging::Registry::Instance().ApplyAll([&](auto Logger) { ZEN_MEMSCOPE(ELLMTag::Logging); - Logger->info("server session id: {}", ServerSessionId); + Logger->Log(SessionIdPoint, fmt::make_format_args(ServerSessionId)); }); } diff --git a/src/zenserver/diag/otlphttp.cpp b/src/zenserver/diag/otlphttp.cpp index d62ccccb6..1434c9331 100644 --- a/src/zenserver/diag/otlphttp.cpp +++ b/src/zenserver/diag/otlphttp.cpp @@ -53,7 +53,7 @@ OtelHttpProtobufSink::TraceRecorder::RecordSpans(zen::otel::TraceId Trace, std:: } void -OtelHttpProtobufSink::log(const spdlog::details::log_msg& Msg) +OtelHttpProtobufSink::Log(const LogMessage& Msg) { { std::string Data = m_Encoder.FormatOtelProtobuf(Msg); @@ -74,7 +74,7 @@ OtelHttpProtobufSink::log(const spdlog::details::log_msg& Msg) } } void -OtelHttpProtobufSink::flush() +OtelHttpProtobufSink::Flush() { } diff --git a/src/zenserver/diag/otlphttp.h 
b/src/zenserver/diag/otlphttp.h index 2281bdcc0..8254af04d 100644 --- a/src/zenserver/diag/otlphttp.h +++ b/src/zenserver/diag/otlphttp.h @@ -3,7 +3,7 @@ #pragma once -#include +#include #include #include #include @@ -14,12 +14,12 @@ namespace zen::logging { /** - * OTLP/HTTP sink for spdlog + * OTLP/HTTP sink for logging * * Sends log messages and traces to an OpenTelemetry collector via OTLP over HTTP */ -class OtelHttpProtobufSink : public spdlog::sinks::sink +class OtelHttpProtobufSink : public Sink { public: // Note that this URI should be the base URI of the OTLP HTTP endpoint, e.g. @@ -31,10 +31,9 @@ public: OtelHttpProtobufSink& operator=(const OtelHttpProtobufSink&) = delete; private: - virtual void log(const spdlog::details::log_msg& Msg) override; - virtual void flush() override; - virtual void set_pattern(const std::string& pattern) override { ZEN_UNUSED(pattern); } - virtual void set_formatter(std::unique_ptr sink_formatter) override { ZEN_UNUSED(sink_formatter); } + virtual void Log(const LogMessage& Msg) override; + virtual void Flush() override; + virtual void SetFormatter(std::unique_ptr) override {} void RecordSpans(zen::otel::TraceId Trace, std::span Spans); @@ -61,4 +60,4 @@ private: } // namespace zen::logging -#endif \ No newline at end of file +#endif diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index c764cbde6..09ecc48e5 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -246,7 +246,7 @@ test_main(int argc, char** argv) # endif // ZEN_PLATFORM_WINDOWS zen::logging::InitializeLogging(); - zen::logging::SetLogLevel(zen::logging::level::Debug); + zen::logging::SetLogLevel(zen::logging::Debug); zen::MaximizeOpenFileCount(); diff --git a/src/zenserver/storage/admin/admin.cpp b/src/zenserver/storage/admin/admin.cpp index 19155e02b..c9f999c69 100644 --- a/src/zenserver/storage/admin/admin.cpp +++ b/src/zenserver/storage/admin/admin.cpp @@ -716,7 +716,7 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, 
"logs", [this](HttpRouterRequest& Req) { CbObjectWriter Obj; - auto LogLevel = logging::level::ToStringView(logging::GetLogLevel()); + auto LogLevel = logging::ToStringView(logging::GetLogLevel()); Obj.AddString("loglevel", std::string_view(LogLevel.data(), LogLevel.size())); Obj.AddString("Logfile", PathToUtf8(m_LogPaths.AbsLogPath)); Obj.BeginObject("cache"); @@ -767,8 +767,8 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, } if (std::string Param(Params.GetValue("loglevel")); Param.empty() == false) { - logging::level::LogLevel NewLevel = logging::level::ParseLogLevelString(Param); - std::string_view LogLevel = logging::level::ToStringView(NewLevel); + logging::LogLevel NewLevel = logging::ParseLogLevelString(Param); + std::string_view LogLevel = logging::ToStringView(NewLevel); if (LogLevel != Param) { return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest, diff --git a/src/zenstore/projectstore.cpp b/src/zenstore/projectstore.cpp index 3f705d12c..1706c9105 100644 --- a/src/zenstore/projectstore.cpp +++ b/src/zenstore/projectstore.cpp @@ -4360,7 +4360,7 @@ ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcMa , m_DiskWriteBlocker(Gc.GetDiskWriteBlocker()) { ZEN_INFO("initializing project store at '{}'", m_ProjectBasePath); - // m_Log.set_level(spdlog::level::debug); + // m_Log.SetLogLevel(zen::logging::Debug); m_Gc.AddGcStorage(this); m_Gc.AddGcReferencer(*this); m_Gc.AddGcReferenceLocker(*this); diff --git a/src/zentelemetry/include/zentelemetry/otlpencoder.h b/src/zentelemetry/include/zentelemetry/otlpencoder.h index ed6665781..f280aa9ec 100644 --- a/src/zentelemetry/include/zentelemetry/otlpencoder.h +++ b/src/zentelemetry/include/zentelemetry/otlpencoder.h @@ -13,9 +13,9 @@ # include # include -namespace spdlog { namespace details { - struct log_msg; -}} // namespace spdlog::details +namespace zen::logging { +struct LogMessage; +} // namespace zen::logging namespace zen::otel { enum class Resource : 
protozero::pbf_tag_type; @@ -46,7 +46,7 @@ public: void AddResourceAttribute(const std::string_view& Key, const std::string_view& Value); void AddResourceAttribute(const std::string_view& Key, int64_t Value); - std::string FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const; + std::string FormatOtelProtobuf(const logging::LogMessage& Msg) const; std::string FormatOtelMetrics() const; std::string FormatOtelTrace(zen::otel::TraceId Trace, std::span Spans) const; diff --git a/src/zentelemetry/otlpencoder.cpp b/src/zentelemetry/otlpencoder.cpp index 677545066..5477c5381 100644 --- a/src/zentelemetry/otlpencoder.cpp +++ b/src/zentelemetry/otlpencoder.cpp @@ -3,9 +3,9 @@ #include "zentelemetry/otlpencoder.h" #include +#include #include -#include #include #include @@ -29,49 +29,49 @@ OtlpEncoder::~OtlpEncoder() } static int -MapSeverity(const spdlog::level::level_enum Level) +MapSeverity(const logging::LogLevel Level) { switch (Level) { - case spdlog::level::critical: + case logging::Critical: return otel::SEVERITY_NUMBER_FATAL; - case spdlog::level::err: + case logging::Err: return otel::SEVERITY_NUMBER_ERROR; - case spdlog::level::warn: + case logging::Warn: return otel::SEVERITY_NUMBER_WARN; - case spdlog::level::info: + case logging::Info: return otel::SEVERITY_NUMBER_INFO; - case spdlog::level::debug: + case logging::Debug: return otel::SEVERITY_NUMBER_DEBUG; default: - case spdlog::level::trace: + case logging::Trace: return otel::SEVERITY_NUMBER_TRACE; } } static const char* -MapSeverityText(const spdlog::level::level_enum Level) +MapSeverityText(const logging::LogLevel Level) { switch (Level) { - case spdlog::level::critical: + case logging::Critical: return "fatal"; - case spdlog::level::err: + case logging::Err: return "error"; - case spdlog::level::warn: + case logging::Warn: return "warn"; - case spdlog::level::info: + case logging::Info: return "info"; - case spdlog::level::debug: + case logging::Debug: return "debug"; default: - case 
spdlog::level::trace: + case logging::Trace: return "trace"; } } std::string -OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const +OtlpEncoder::FormatOtelProtobuf(const logging::LogMessage& Msg) const { std::string Data; @@ -98,7 +98,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const protozero::pbf_builder IsBuilder{SlBuilder, otel::ScopeLogs::required_InstrumentationScope_scope}; - IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.logger_name.data(), Msg.logger_name.size()); + IsBuilder.add_string(otel::InstrumentationScope::string_name, Msg.GetLoggerName().data(), Msg.GetLoggerName().size()); } // LogRecord log_records @@ -106,13 +106,13 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const protozero::pbf_builder LrBuilder{SlBuilder, otel::ScopeLogs::required_repeated_LogRecord_log_records}; LrBuilder.add_fixed64(otel::LogRecord::required_fixed64_time_unix_nano, - std::chrono::duration_cast(Msg.time.time_since_epoch()).count()); + std::chrono::duration_cast(Msg.GetTime().time_since_epoch()).count()); - const int Severity = MapSeverity(Msg.level); + const int Severity = MapSeverity(Msg.GetLevel()); LrBuilder.add_enum(otel::LogRecord::optional_SeverityNumber_severity_number, Severity); - LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.level)); + LrBuilder.add_string(otel::LogRecord::optional_string_severity_text, MapSeverityText(Msg.GetLevel())); otel::TraceId TraceId; const otel::SpanId SpanId = otel::Span::GetCurrentSpanId(TraceId); @@ -127,7 +127,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const { protozero::pbf_builder BodyBuilder{LrBuilder, otel::LogRecord::optional_anyvalue_body}; - BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.payload.data(), Msg.payload.size()); + BodyBuilder.add_string(otel::AnyValue::string_string_value, Msg.GetPayload().data(), Msg.GetPayload().size()); } // 
attributes @@ -139,7 +139,7 @@ OtlpEncoder::FormatOtelProtobuf(const spdlog::details::log_msg& Msg) const { protozero::pbf_builder AvBuilder{KvBuilder, otel::KeyValue::AnyValue_value}; - AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.thread_id); + AvBuilder.add_int64(otel::AnyValue::int64_int_value, Msg.GetThreadId()); } } } diff --git a/src/zentelemetry/xmake.lua b/src/zentelemetry/xmake.lua index 7739c0a08..cd9a18ec4 100644 --- a/src/zentelemetry/xmake.lua +++ b/src/zentelemetry/xmake.lua @@ -6,5 +6,5 @@ target('zentelemetry') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "protozero", "spdlog") + add_deps("zencore", "protozero") add_deps("robin-map") diff --git a/src/zenutil/config/commandlineoptions.cpp b/src/zenutil/config/commandlineoptions.cpp index 2344354b3..25f5522d8 100644 --- a/src/zenutil/config/commandlineoptions.cpp +++ b/src/zenutil/config/commandlineoptions.cpp @@ -2,6 +2,7 @@ #include +#include #include #include diff --git a/src/zenutil/config/loggingconfig.cpp b/src/zenutil/config/loggingconfig.cpp index 9ec816b1b..5092c60aa 100644 --- a/src/zenutil/config/loggingconfig.cpp +++ b/src/zenutil/config/loggingconfig.cpp @@ -21,13 +21,13 @@ ZenLoggingCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenLoggingCon ("log-id", "Specify id for adding context to log output", cxxopts::value(LoggingConfig.LogId)) ("quiet", "Configure console logger output to level WARN", cxxopts::value(LoggingConfig.QuietConsole)->default_value("false")) ("noconsole", "Disable console logging", cxxopts::value(LoggingConfig.NoConsoleOutput)->default_value("false")) - ("log-trace", "Change selected loggers to level TRACE", cxxopts::value(LoggingConfig.Loggers[logging::level::Trace])) - ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value(LoggingConfig.Loggers[logging::level::Debug])) - ("log-info", "Change selected loggers to level INFO", 
cxxopts::value(LoggingConfig.Loggers[logging::level::Info])) - ("log-warn", "Change selected loggers to level WARN", cxxopts::value(LoggingConfig.Loggers[logging::level::Warn])) - ("log-error", "Change selected loggers to level ERROR", cxxopts::value(LoggingConfig.Loggers[logging::level::Err])) - ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value(LoggingConfig.Loggers[logging::level::Critical])) - ("log-off", "Change selected loggers to level OFF", cxxopts::value(LoggingConfig.Loggers[logging::level::Off])) + ("log-trace", "Change selected loggers to level TRACE", cxxopts::value(LoggingConfig.Loggers[logging::Trace])) + ("log-debug", "Change selected loggers to level DEBUG", cxxopts::value(LoggingConfig.Loggers[logging::Debug])) + ("log-info", "Change selected loggers to level INFO", cxxopts::value(LoggingConfig.Loggers[logging::Info])) + ("log-warn", "Change selected loggers to level WARN", cxxopts::value(LoggingConfig.Loggers[logging::Warn])) + ("log-error", "Change selected loggers to level ERROR", cxxopts::value(LoggingConfig.Loggers[logging::Err])) + ("log-critical", "Change selected loggers to level CRITICAL", cxxopts::value(LoggingConfig.Loggers[logging::Critical])) + ("log-off", "Change selected loggers to level OFF", cxxopts::value(LoggingConfig.Loggers[logging::Off])) ("otlp-endpoint", "OpenTelemetry endpoint URI (e.g http://localhost:4318)", cxxopts::value(LoggingConfig.OtelEndpointUri)) ; // clang-format on @@ -47,7 +47,7 @@ ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig) if (LoggingConfig.QuietConsole) { bool HasExplicitConsoleLevel = false; - for (int i = 0; i < logging::level::LogLevelCount; ++i) + for (int i = 0; i < logging::LogLevelCount; ++i) { if (LoggingConfig.Loggers[i].find("console") != std::string::npos) { @@ -58,7 +58,7 @@ ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig) if (!HasExplicitConsoleLevel) { - std::string& WarnLoggers = 
LoggingConfig.Loggers[logging::level::Warn]; + std::string& WarnLoggers = LoggingConfig.Loggers[logging::Warn]; if (!WarnLoggers.empty()) { WarnLoggers += ","; @@ -67,9 +67,9 @@ ApplyLoggingOptions(cxxopts::Options& options, ZenLoggingConfig& LoggingConfig) } } - for (int i = 0; i < logging::level::LogLevelCount; ++i) + for (int i = 0; i < logging::LogLevelCount; ++i) { - logging::ConfigureLogLevels(logging::level::LogLevel(i), LoggingConfig.Loggers[i]); + logging::ConfigureLogLevels(logging::LogLevel(i), LoggingConfig.Loggers[i]); } logging::RefreshLogLevels(); } diff --git a/src/zenutil/include/zenutil/config/loggingconfig.h b/src/zenutil/include/zenutil/config/loggingconfig.h index 6d6f64b30..b55b2d9f7 100644 --- a/src/zenutil/include/zenutil/config/loggingconfig.h +++ b/src/zenutil/include/zenutil/config/loggingconfig.h @@ -17,7 +17,7 @@ struct ZenLoggingConfig bool NoConsoleOutput = false; // Control default use of stdout for diagnostics bool QuietConsole = false; // Configure console logger output to level WARN std::filesystem::path AbsLogFile; // Absolute path to main log file - std::string Loggers[logging::level::LogLevelCount]; + std::string Loggers[logging::LogLevelCount]; std::string LogId; // Id for tagging log output std::string OtelEndpointUri; // OpenTelemetry endpoint URI }; diff --git a/src/zenutil/include/zenutil/logging.h b/src/zenutil/include/zenutil/logging.h index 85ddc86cd..95419c274 100644 --- a/src/zenutil/include/zenutil/logging.h +++ b/src/zenutil/include/zenutil/logging.h @@ -3,19 +3,12 @@ #pragma once #include +#include #include #include #include -namespace spdlog::sinks { -class sink; -} - -namespace spdlog { -using sink_ptr = std::shared_ptr; -} - ////////////////////////////////////////////////////////////////////////// // // Logging utilities @@ -45,6 +38,6 @@ void FinishInitializeLogging(const LoggingOptions& LoggingOptions); void InitializeLogging(const LoggingOptions& LoggingOptions); void ShutdownLogging(); -spdlog::sink_ptr 
GetFileSink(); +logging::SinkPtr GetFileSink(); } // namespace zen diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h index 9f245becd..33cb94dae 100644 --- a/src/zenutil/include/zenutil/logging/fullformatter.h +++ b/src/zenutil/include/zenutil/logging/fullformatter.h @@ -2,21 +2,19 @@ #pragma once +#include +#include #include #include #include -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END - namespace zen::logging { -class full_formatter final : public spdlog::formatter +class FullFormatter final : public Formatter { public: - full_formatter(std::string_view LogId, std::chrono::time_point Epoch) + FullFormatter(std::string_view LogId, std::chrono::time_point Epoch) : m_Epoch(Epoch) , m_LogId(LogId) , m_LinePrefix(128, ' ') @@ -24,16 +22,19 @@ public: { } - full_formatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} - virtual std::unique_ptr clone() const override + virtual std::unique_ptr Clone() const override { ZEN_MEMSCOPE(ELLMTag::Logging); - // Note: this does not properly clone m_UseFullDate - return std::make_unique(m_LogId, m_Epoch); + if (m_UseFullDate) + { + return std::make_unique(m_LogId); + } + return std::make_unique(m_LogId, m_Epoch); } - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutBuffer) override + virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -44,38 +45,38 @@ public: std::chrono::seconds TimestampSeconds; - std::chrono::milliseconds millis; + std::chrono::milliseconds Millis; if (m_UseFullDate) { - TimestampSeconds = std::chrono::duration_cast(msg.time.time_since_epoch()); + TimestampSeconds = std::chrono::duration_cast(Msg.GetTime().time_since_epoch()); if (TimestampSeconds != m_LastLogSecs) { 
RwLock::ExclusiveLockScope _(m_TimestampLock); m_LastLogSecs = TimestampSeconds; - m_CachedLocalTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); + m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); m_CachedDatetime.clear(); m_CachedDatetime.push_back('['); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); m_CachedDatetime.push_back(' '); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); + helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); m_CachedDatetime.push_back('.'); } - millis = spdlog::details::fmt_helper::time_fraction(msg.time); + Millis = helpers::TimeFraction(Msg.GetTime()); } else { - auto ElapsedTime = msg.time - m_Epoch; + auto ElapsedTime = Msg.GetTime() - m_Epoch; TimestampSeconds = std::chrono::duration_cast(ElapsedTime); if (m_CacheTimestamp.load() != TimestampSeconds) @@ -93,15 +94,15 @@ public: m_CachedDatetime.clear(); m_CachedDatetime.push_back('['); - spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime); + helpers::Pad2(LogHours, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime); + 
helpers::Pad2(LogMins, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime); + helpers::Pad2(LogSecs, m_CachedDatetime); m_CachedDatetime.push_back('.'); } - millis = std::chrono::duration_cast(ElapsedTime - TimestampSeconds); + Millis = std::chrono::duration_cast(ElapsedTime - TimestampSeconds); } { @@ -109,44 +110,43 @@ public: OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); } - spdlog::details::fmt_helper::pad3(static_cast(millis.count()), OutBuffer); + helpers::Pad3(static_cast(Millis.count()), OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); if (!m_LogId.empty()) { OutBuffer.push_back('['); - spdlog::details::fmt_helper::append_string_view(m_LogId, OutBuffer); + helpers::AppendStringView(m_LogId, OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } // append logger name if exists - if (msg.logger_name.size() > 0) + if (Msg.GetLoggerName().size() > 0) { OutBuffer.push_back('['); - spdlog::details::fmt_helper::append_string_view(msg.logger_name, OutBuffer); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } OutBuffer.push_back('['); // wrap the level name with color - msg.color_range_start = OutBuffer.size(); - spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), OutBuffer); - msg.color_range_end = OutBuffer.size(); + Msg.ColorRangeStart = OutBuffer.size(); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); + Msg.ColorRangeEnd = OutBuffer.size(); OutBuffer.push_back(']'); OutBuffer.push_back(' '); // add source location if present - if (!msg.source.empty()) + if (Msg.GetSource()) { OutBuffer.push_back('['); - const char* filename = - spdlog::details::short_filename_formatter::basename(msg.source.filename); - spdlog::details::fmt_helper::append_string_view(filename, OutBuffer); + const char* Filename = 
helpers::ShortFilename(Msg.GetSource().Filename); + helpers::AppendStringView(Filename, OutBuffer); OutBuffer.push_back(':'); - spdlog::details::fmt_helper::append_int(msg.source.line, OutBuffer); + helpers::AppendInt(Msg.GetSource().Line, OutBuffer); OutBuffer.push_back(']'); OutBuffer.push_back(' '); } @@ -156,8 +156,9 @@ public: const size_t LinePrefixCount = Min(OutBuffer.size(), m_LinePrefix.size()); - auto ItLineBegin = msg.payload.begin(); - auto ItMessageEnd = msg.payload.end(); + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); bool IsFirstline = true; { @@ -170,9 +171,9 @@ public: } else { - spdlog::details::fmt_helper::append_string_view(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); + helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); } - spdlog::details::fmt_helper::append_string_view(spdlog::string_view_t(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); }; while (ItLineEnd != ItMessageEnd) @@ -187,7 +188,7 @@ public: if (ItLineBegin != ItMessageEnd) { EmitLine(); - spdlog::details::fmt_helper::append_string_view("\n"sv, OutBuffer); + helpers::AppendStringView("\n"sv, OutBuffer); } } } @@ -197,7 +198,7 @@ private: std::tm m_CachedLocalTm; std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; std::atomic m_CacheTimestamp{std::chrono::seconds(87654321)}; - spdlog::memory_buf_t m_CachedDatetime; + MemoryBuffer m_CachedDatetime; std::string m_LogId; std::string m_LinePrefix; bool m_UseFullDate = true; diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h index 3f660e421..216b1b5e5 100644 --- a/src/zenutil/include/zenutil/logging/jsonformatter.h +++ b/src/zenutil/include/zenutil/logging/jsonformatter.h @@ -2,27 +2,26 @@ #pragma once +#include 
+#include #include #include #include - -ZEN_THIRD_PARTY_INCLUDES_START -#include -ZEN_THIRD_PARTY_INCLUDES_END +#include namespace zen::logging { using namespace std::literals; -class json_formatter final : public spdlog::formatter +class JsonFormatter final : public Formatter { public: - json_formatter(std::string_view LogId) : m_LogId(LogId) {} + JsonFormatter(std::string_view LogId) : m_LogId(LogId) {} - virtual std::unique_ptr clone() const override { return std::make_unique(m_LogId); } + virtual std::unique_ptr Clone() const override { return std::make_unique(m_LogId); } - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override + virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -30,141 +29,132 @@ public: using std::chrono::milliseconds; using std::chrono::seconds; - auto secs = std::chrono::duration_cast(msg.time.time_since_epoch()); - if (secs != m_LastLogSecs) + auto Secs = std::chrono::duration_cast(Msg.GetTime().time_since_epoch()); + if (Secs != m_LastLogSecs) { - m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); - m_LastLogSecs = secs; - } - - const auto& tm_time = m_CachedTm; + RwLock::ExclusiveLockScope _(m_TimestampLock); + m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_LastLogSecs = Secs; - // cache the date/time part for the next second. - - if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) - { + // cache the date/time part for the next second. 
m_CachedDatetime.clear(); - spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime); m_CachedDatetime.push_back('-'); - spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime); m_CachedDatetime.push_back(' '); - spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime); m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime); m_CachedDatetime.push_back('.'); - - m_CacheTimestamp = secs; } - dest.append("{"sv); - dest.append("\"time\": \""sv); - dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - auto millis = spdlog::details::fmt_helper::time_fraction(msg.time); - spdlog::details::fmt_helper::pad3(static_cast(millis.count()), dest); - dest.append("\", "sv); + helpers::AppendStringView("{"sv, Dest); + helpers::AppendStringView("\"time\": \""sv, Dest); + { + RwLock::SharedLockScope _(m_TimestampLock); + Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + } + auto Millis = helpers::TimeFraction(Msg.GetTime()); + helpers::Pad3(static_cast(Millis.count()), Dest); + helpers::AppendStringView("\", "sv, Dest); - dest.append("\"status\": \""sv); - dest.append(spdlog::level::to_string_view(msg.level)); - dest.append("\", "sv); + helpers::AppendStringView("\"status\": \""sv, Dest); + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + helpers::AppendStringView("\", "sv, Dest); - 
dest.append("\"source\": \""sv); - dest.append("zenserver"sv); - dest.append("\", "sv); + helpers::AppendStringView("\"source\": \""sv, Dest); + helpers::AppendStringView("zenserver"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); - dest.append("\"service\": \""sv); - dest.append("zencache"sv); - dest.append("\", "sv); + helpers::AppendStringView("\"service\": \""sv, Dest); + helpers::AppendStringView("zencache"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); if (!m_LogId.empty()) { - dest.append("\"id\": \""sv); - dest.append(m_LogId); - dest.append("\", "sv); + helpers::AppendStringView("\"id\": \""sv, Dest); + helpers::AppendStringView(m_LogId, Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (msg.logger_name.size() > 0) + if (Msg.GetLoggerName().size() > 0) { - dest.append("\"logger.name\": \""sv); - dest.append(msg.logger_name); - dest.append("\", "sv); + helpers::AppendStringView("\"logger.name\": \""sv, Dest); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (msg.thread_id != 0) + if (Msg.GetThreadId() != 0) { - dest.append("\"logger.thread_name\": \""sv); - spdlog::details::fmt_helper::pad_uint(msg.thread_id, 0, dest); - dest.append("\", "sv); + helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); + helpers::PadUint(Msg.GetThreadId(), 0, Dest); + helpers::AppendStringView("\", "sv, Dest); } - if (!msg.source.empty()) + if (Msg.GetSource()) { - dest.append("\"file\": \""sv); - WriteEscapedString( - dest, - spdlog::details::short_filename_formatter::basename(msg.source.filename)); - dest.append("\","sv); - - dest.append("\"line\": \""sv); - dest.append(fmt::format("{}", msg.source.line)); - dest.append("\","sv); - - dest.append("\"logger.method_name\": \""sv); - WriteEscapedString(dest, msg.source.funcname); - dest.append("\", "sv); + helpers::AppendStringView("\"file\": \""sv, Dest); + WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); + 
helpers::AppendStringView("\","sv, Dest); + + helpers::AppendStringView("\"line\": \""sv, Dest); + helpers::AppendInt(Msg.GetSource().Line, Dest); + helpers::AppendStringView("\","sv, Dest); } - dest.append("\"message\": \""sv); - WriteEscapedString(dest, msg.payload); - dest.append("\""sv); + helpers::AppendStringView("\"message\": \""sv, Dest); + WriteEscapedString(Dest, Msg.GetPayload()); + helpers::AppendStringView("\""sv, Dest); - dest.append("}\n"sv); + helpers::AppendStringView("}\n"sv, Dest); } private: - static inline const std::unordered_map SpecialCharacterMap{{'\b', "\\b"sv}, - {'\f', "\\f"sv}, - {'\n', "\\n"sv}, - {'\r', "\\r"sv}, - {'\t', "\\t"sv}, - {'"', "\\\""sv}, - {'\\', "\\\\"sv}}; - - static void WriteEscapedString(spdlog::memory_buf_t& dest, const spdlog::string_view_t& payload) + static inline const std::unordered_map s_SpecialCharacterMap{{'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}}; + + static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text) { - const char* RangeStart = payload.begin(); - for (const char* It = RangeStart; It != payload.end(); ++It) + const char* RangeStart = Text.data(); + const char* End = Text.data() + Text.size(); + for (const char* It = RangeStart; It != End; ++It) { - if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end()) { if (RangeStart != It) { - dest.append(RangeStart, It); + Dest.append(RangeStart, It); } - dest.append(SpecialIt->second); + helpers::AppendStringView(SpecialIt->second, Dest); RangeStart = It + 1; } } - if (RangeStart != payload.end()) + if (RangeStart != End) { - dest.append(RangeStart, payload.end()); + Dest.append(RangeStart, End); } }; std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; std::chrono::seconds m_LastLogSecs{0}; - std::chrono::seconds 
m_CacheTimestamp{0}; - spdlog::memory_buf_t m_CachedDatetime; + MemoryBuffer m_CachedDatetime; std::string m_LogId; + RwLock m_TimestampLock; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h index 8901b7779..cebc5b110 100644 --- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h +++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h @@ -3,14 +3,11 @@ #pragma once #include +#include +#include +#include #include -ZEN_THIRD_PARTY_INCLUDES_START -#include -#include -#include -ZEN_THIRD_PARTY_INCLUDES_END - #include #include @@ -19,13 +16,14 @@ namespace zen::logging { // Basically the same functionality as spdlog::sinks::rotating_file_sink with the biggest difference // being that it just ignores any errors when writing/rotating files and keeps chugging on. // It will keep trying to log, and if it starts to work it will continue to log. -class RotatingFileSink : public spdlog::sinks::sink +class RotatingFileSink : public Sink { public: RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false) : m_BaseFilename(BaseFilename) , m_MaxSize(MaxSize) , m_MaxFiles(MaxFiles) + , m_Formatter(std::make_unique()) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -76,18 +74,21 @@ public: RotatingFileSink& operator=(const RotatingFileSink&) = delete; RotatingFileSink& operator=(RotatingFileSink&&) = delete; - virtual void log(const spdlog::details::log_msg& msg) override + virtual void Log(const LogMessage& Msg) override { ZEN_MEMSCOPE(ELLMTag::Logging); try { - spdlog::memory_buf_t Formatted; - if (TrySinkIt(msg, Formatted)) + MemoryBuffer Formatted; + if (TrySinkIt(Msg, Formatted)) { return; } - while (true) + + // This intentionally has no limit on the number of retries, see + // comment above. 
+ for (;;) { { RwLock::ExclusiveLockScope RotateLock(m_Lock); @@ -113,7 +114,7 @@ public: // Silently eat errors } } - virtual void flush() override + virtual void Flush() override { if (!m_NeedFlush) { @@ -138,28 +139,14 @@ public: m_NeedFlush = false; } - virtual void set_pattern(const std::string& pattern) override + virtual void SetFormatter(std::unique_ptr InFormatter) override { ZEN_MEMSCOPE(ELLMTag::Logging); try { RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = spdlog::details::make_unique(pattern); - } - catch (const std::exception&) - { - // Silently eat errors - } - } - virtual void set_formatter(std::unique_ptr sink_formatter) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = std::move(sink_formatter); + m_Formatter = std::move(InFormatter); } catch (const std::exception&) { @@ -186,11 +173,17 @@ private: return; } - // If we fail to rotate, try extending the current log file m_CurrentSize = m_CurrentFile.FileSize(OutEc); + if (OutEc) + { + // FileSize failed but we have an open file — reset to 0 + // so we can at least attempt writes from the start + m_CurrentSize = 0; + OutEc.clear(); + } } - bool TrySinkIt(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& OutFormatted) + bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -199,15 +192,15 @@ private: { return false; } - m_Formatter->format(msg, OutFormatted); - size_t add_size = OutFormatted.size(); - size_t write_pos = m_CurrentSize.fetch_add(add_size); - if (write_pos + add_size > m_MaxSize) + m_Formatter->Format(Msg, OutFormatted); + size_t AddSize = OutFormatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) { return false; } std::error_code Ec; - m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), write_pos, Ec); + m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); if (Ec) { 
return false; @@ -216,7 +209,7 @@ private: return true; } - bool TrySinkIt(const spdlog::memory_buf_t& Formatted) + bool TrySinkIt(const MemoryBuffer& Formatted) { ZEN_MEMSCOPE(ELLMTag::Logging); @@ -225,15 +218,15 @@ private: { return false; } - size_t add_size = Formatted.size(); - size_t write_pos = m_CurrentSize.fetch_add(add_size); - if (write_pos + add_size > m_MaxSize) + size_t AddSize = Formatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) { return false; } std::error_code Ec; - m_CurrentFile.Write(Formatted.data(), Formatted.size(), write_pos, Ec); + m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); if (Ec) { return false; @@ -242,14 +235,14 @@ private: return true; } - RwLock m_Lock; - const std::filesystem::path m_BaseFilename; - std::unique_ptr m_Formatter; - std::atomic_size_t m_CurrentSize; - const std::size_t m_MaxSize; - const std::size_t m_MaxFiles; - BasicFile m_CurrentFile; - std::atomic m_NeedFlush = false; + RwLock m_Lock; + const std::filesystem::path m_BaseFilename; + const std::size_t m_MaxSize; + const std::size_t m_MaxFiles; + std::unique_ptr m_Formatter; + std::atomic_size_t m_CurrentSize; + BasicFile m_CurrentFile; + std::atomic m_NeedFlush = false; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/testformatter.h b/src/zenutil/include/zenutil/logging/testformatter.h deleted file mode 100644 index 0b0c191fb..000000000 --- a/src/zenutil/include/zenutil/logging/testformatter.h +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. 
- -#pragma once - -#include - -#include - -namespace zen::logging { - -class full_test_formatter final : public spdlog::formatter -{ -public: - full_test_formatter(std::string_view LogId, std::chrono::time_point Epoch) : m_Epoch(Epoch), m_LogId(LogId) - { - } - - virtual std::unique_ptr clone() const override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - return std::make_unique(m_LogId, m_Epoch); - } - - static constexpr bool UseDate = false; - - virtual void format(const spdlog::details::log_msg& msg, spdlog::memory_buf_t& dest) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - using namespace std::literals; - - if constexpr (UseDate) - { - auto secs = std::chrono::duration_cast(msg.time.time_since_epoch()); - if (secs != m_LastLogSecs) - { - m_CachedTm = spdlog::details::os::localtime(spdlog::log_clock::to_time_t(msg.time)); - m_LastLogSecs = secs; - } - } - - const auto& tm_time = m_CachedTm; - - // cache the date/time part for the next second. - auto duration = msg.time - m_Epoch; - auto secs = std::chrono::duration_cast(duration); - - if (m_CacheTimestamp != secs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - - m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - - if constexpr (UseDate) - { - spdlog::details::fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - - spdlog::details::fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - spdlog::details::fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); - } - else - { - int Count = int(secs.count()); - - const int LogSecs = Count % 60; - Count /= 60; - - const int LogMins = Count % 60; - 
Count /= 60; - - const int LogHours = Count; - - spdlog::details::fmt_helper::pad2(LogHours, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogMins, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - spdlog::details::fmt_helper::pad2(LogSecs, m_CachedDatetime); - } - - m_CachedDatetime.push_back('.'); - - m_CacheTimestamp = secs; - } - - { - RwLock::SharedLockScope _(m_TimestampLock); - dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - - auto millis = spdlog::details::fmt_helper::time_fraction(msg.time); - spdlog::details::fmt_helper::pad3(static_cast(millis.count()), dest); - dest.push_back(']'); - dest.push_back(' '); - - if (!m_LogId.empty()) - { - dest.push_back('['); - spdlog::details::fmt_helper::append_string_view(m_LogId, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - // append logger name if exists - if (msg.logger_name.size() > 0) - { - dest.push_back('['); - spdlog::details::fmt_helper::append_string_view(msg.logger_name, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - dest.push_back('['); - // wrap the level name with color - msg.color_range_start = dest.size(); - spdlog::details::fmt_helper::append_string_view(spdlog::level::to_string_view(msg.level), dest); - msg.color_range_end = dest.size(); - dest.push_back(']'); - dest.push_back(' '); - - // add source location if present - if (!msg.source.empty()) - { - dest.push_back('['); - const char* filename = - spdlog::details::short_filename_formatter::basename(msg.source.filename); - spdlog::details::fmt_helper::append_string_view(filename, dest); - dest.push_back(':'); - spdlog::details::fmt_helper::append_int(msg.source.line, dest); - dest.push_back(']'); - dest.push_back(' '); - } - - spdlog::details::fmt_helper::append_string_view(msg.payload, dest); - spdlog::details::fmt_helper::append_string_view("\n"sv, dest); - } - -private: - std::chrono::time_point m_Epoch; - std::tm m_CachedTm; - std::chrono::seconds 
m_LastLogSecs{std::chrono::seconds(87654321)}; - std::chrono::seconds m_CacheTimestamp{std::chrono::seconds(87654321)}; - spdlog::memory_buf_t m_CachedDatetime; - std::string m_LogId; - RwLock m_TimestampLock; -}; - -} // namespace zen::logging diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging.cpp index 54ac30c5d..1258ca155 100644 --- a/src/zenutil/logging.cpp +++ b/src/zenutil/logging.cpp @@ -2,18 +2,15 @@ #include "zenutil/logging.h" -ZEN_THIRD_PARTY_INCLUDES_START -#include -#include -#include -#include -#include -ZEN_THIRD_PARTY_INCLUDES_END - #include #include #include #include +#include +#include +#include +#include +#include #include #include #include @@ -27,9 +24,9 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { static bool g_IsLoggingInitialized; -spdlog::sink_ptr g_FileSink; +logging::SinkPtr g_FileSink; -spdlog::sink_ptr +logging::SinkPtr GetFileSink() { return g_FileSink; @@ -52,33 +49,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) zen::logging::InitializeLogging(); zen::logging::EnableVTMode(); - bool IsAsync = LogOptions.AllowAsync; - - if (LogOptions.IsDebug) - { - IsAsync = false; - } - - if (LogOptions.IsTest) - { - IsAsync = false; - } - - if (IsAsync) - { - const int QueueSize = 8192; - const int ThreadCount = 1; - spdlog::init_thread_pool(QueueSize, ThreadCount, [&] { SetCurrentThreadName("spdlog_async"); }); - - auto AsyncSink = spdlog::create_async("main"); - zen::logging::SetDefault("main"); - } - // Sinks - spdlog::sink_ptr FileSink; - - // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + logging::SinkPtr FileSink; if (!LogOptions.AbsLogFile.empty()) { @@ -87,17 +60,17 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) zen::CreateDirectories(LogOptions.AbsLogFile.parent_path()); } - FileSink = std::make_shared(LogOptions.AbsLogFile, - /* max size */ 128 * 1024 * 1024, - /* max files */ 16, - /* rotate on open */ true); + FileSink = 
logging::SinkPtr(new zen::logging::RotatingFileSink(LogOptions.AbsLogFile, + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true)); if (LogOptions.AbsLogFile.extension() == ".json") { - FileSink->set_formatter(std::make_unique(LogOptions.LogId)); + FileSink->SetFormatter(std::make_unique(LogOptions.LogId)); } else { - FileSink->set_formatter(std::make_unique(LogOptions.LogId)); // this will have a date prefix + FileSink->SetFormatter(std::make_unique(LogOptions.LogId)); // this will have a date prefix } } @@ -127,7 +100,7 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) Message.push_back('\0'); // We use direct ZEN_LOG here instead of ZEN_ERROR as we don't care about *this* code location in the log - ZEN_LOG(Log(), zen::logging::level::Critical, "{}", Message.data()); + ZEN_LOG(Log(), zen::logging::Critical, "{}", Message.data()); zen::logging::FlushLogging(); } catch (const std::exception&) @@ -143,9 +116,9 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) // Default LoggerRef DefaultLogger = zen::logging::Default(); - auto& Sinks = DefaultLogger.SpdLogger->sinks(); - Sinks.clear(); + // Collect sinks into a local vector first so we can optionally wrap them + std::vector Sinks; if (LogOptions.NoConsoleOutput) { @@ -153,10 +126,10 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) } else { - auto ConsoleSink = std::make_shared(); + logging::SinkPtr ConsoleSink(new logging::AnsiColorStdoutSink()); if (LogOptions.QuietConsole) { - ConsoleSink->set_level(spdlog::level::warn); + ConsoleSink->SetLevel(logging::Warn); } Sinks.push_back(ConsoleSink); } @@ -169,40 +142,54 @@ BeginInitializeLogging(const LoggingOptions& LogOptions) #if ZEN_PLATFORM_WINDOWS if (zen::IsDebuggerPresent() && LogOptions.IsDebug) { - auto DebugSink = std::make_shared(); - DebugSink->set_level(spdlog::level::debug); + logging::SinkPtr DebugSink(new logging::MsvcSink()); + DebugSink->SetLevel(logging::Debug); Sinks.push_back(DebugSink); 
} #endif - spdlog::set_error_handler([](const std::string& msg) { - if (msg == std::bad_alloc().what()) - { - // Don't report out of memory in spdlog as we usually log in response to errors which will cause another OOM crashing the - // program - return; - } - // Bypass zen logging wrapping to reduce potential other error sources - if (auto ErrLogger = zen::logging::ErrorLog()) + bool IsAsync = LogOptions.AllowAsync && !LogOptions.IsDebug && !LogOptions.IsTest; + + if (IsAsync) + { + std::vector AsyncSinks; + AsyncSinks.emplace_back(new logging::AsyncSink(std::move(Sinks))); + DefaultLogger->SetSinks(std::move(AsyncSinks)); + } + else + { + DefaultLogger->SetSinks(std::move(Sinks)); + } + + static struct : logging::ErrorHandler + { + void HandleError(const std::string_view& ErrorMsg) override { + if (ErrorMsg == std::bad_alloc().what()) + { + return; + } + static constinit logging::LogPoint ErrorPoint{{}, logging::Err, "{}"}; + if (auto ErrLogger = zen::logging::ErrorLog()) + { + try + { + ErrLogger->Log(ErrorPoint, fmt::make_format_args(ErrorMsg)); + } + catch (const std::exception&) + { + } + } try { - ErrLogger.SpdLogger->log(spdlog::level::err, msg); + Log()->Log(ErrorPoint, fmt::make_format_args(ErrorMsg)); } catch (const std::exception&) { - // Just ignore any errors when in error handler } } - try - { - Log().SpdLogger->error(msg); - } - catch (const std::exception&) - { - // Just ignore any errors when in error handler - } - }); + } s_ErrorHandler; + logging::Registry::Instance().SetErrorHandler(&s_ErrorHandler); g_FileSink = std::move(FileSink); } @@ -212,24 +199,24 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) { ZEN_MEMSCOPE(ELLMTag::Logging); - logging::level::LogLevel LogLevel = logging::level::Info; + logging::LogLevel LogLevel = logging::Info; if (LogOptions.IsDebug) { - LogLevel = logging::level::Debug; + LogLevel = logging::Debug; } if (LogOptions.IsTest || LogOptions.IsVerbose) { - LogLevel = logging::level::Trace; + LogLevel = 
logging::Trace; } // Configure all registered loggers according to settings logging::RefreshLogLevels(LogLevel); - spdlog::flush_on(spdlog::level::err); - spdlog::flush_every(std::chrono::seconds{2}); - spdlog::set_formatter(std::make_unique( + logging::Registry::Instance().FlushOn(logging::Err); + logging::Registry::Instance().FlushEvery(std::chrono::seconds{2}); + logging::Registry::Instance().SetFormatter(std::make_unique( LogOptions.LogId, std::chrono::system_clock::now() - std::chrono::milliseconds(GetTimeSinceProcessStart()))); // default to duration prefix @@ -242,16 +229,17 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) { if (LogOptions.AbsLogFile.extension() == ".json") { - g_FileSink->set_formatter(std::make_unique(LogOptions.LogId)); + g_FileSink->SetFormatter(std::make_unique(LogOptions.LogId)); } else { - g_FileSink->set_formatter(std::make_unique(LogOptions.LogId)); // this will have a date prefix + g_FileSink->SetFormatter(std::make_unique(LogOptions.LogId)); // this will have a date prefix } const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); - spdlog::apply_all([&](auto Logger) { Logger->info("log starting at {}", StartLogTime); }); + static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); } g_IsLoggingInitialized = true; @@ -268,7 +256,7 @@ ShutdownLogging() zen::logging::ShutdownLogging(); - g_FileSink.reset(); + g_FileSink = nullptr; } } // namespace zen diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index bc33adf9e..1d5be5977 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -6,7 +6,7 @@ target('zenutil') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "zenhttp", "spdlog") + add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") diff --git 
a/src/zenvfs/xmake.lua b/src/zenvfs/xmake.lua index 7f790c2d4..47665a5d5 100644 --- a/src/zenvfs/xmake.lua +++ b/src/zenvfs/xmake.lua @@ -6,5 +6,5 @@ target('zenvfs') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "spdlog") + add_deps("zencore") -- cgit v1.2.3 From 07649104761ee910b667adb2b865c4e2fd0979c9 Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 9 Mar 2026 12:03:25 +0100 Subject: added auto-detection logic for console colour output (#817) Add auto-detection of colour support to `AnsiColorStdoutSink`. **New `ColorMode` enum** (`On`, `Off`, `Auto`) added to the header, accepted by the `AnsiColorStdoutSink` constructor. Defaults to `Auto`, so all existing call sites are unaffected. **`Auto` mode detection logic** (in `IsColorTerminal()`): 1. **TTY check** -- if stdout is not a terminal, colour is disabled. 2. **`NO_COLOR`** -- respects the no-color.org convention. If set, colour is disabled. 3. **`COLORTERM`** -- if set (e.g. `truecolor`, `24bit`), colour is enabled. 4. **`TERM`** -- rejects `dumb`; accepts known colour-capable terminals via substring match: `alacritty`, `ansi`, `color`, `console`, `cygwin`, `gnome`, `konsole`, `kterm`, `linux`, `msys`, `putty`, `rxvt`, `screen`, `tmux`, `vt100`, `vt102`, `xterm`. Substring matching covers variants like `xterm-256color` and `rxvt-unicode`. 5. **Fallback** -- Windows defaults to colour enabled (modern console supports ANSI natively); other platforms default to disabled. When colour is disabled, ANSI escape sequences are omitted entirely from the output. NOTE: this doesn't currently apply to all paths which do logging in zen as they may be determining their colour output mode separately from `AnsiColorStdoutSink`.
--- .../include/zencore/logging/ansicolorsink.h | 9 +- src/zencore/logging/ansicolorsink.cpp | 103 ++++++++++++++++++++- 2 files changed, 107 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h index 9f859e8d7..5060a8393 100644 --- a/src/zencore/include/zencore/logging/ansicolorsink.h +++ b/src/zencore/include/zencore/logging/ansicolorsink.h @@ -8,10 +8,17 @@ namespace zen::logging { +enum class ColorMode +{ + On, + Off, + Auto +}; + class AnsiColorStdoutSink : public Sink { public: - AnsiColorStdoutSink(); + explicit AnsiColorStdoutSink(ColorMode Mode = ColorMode::Auto); ~AnsiColorStdoutSink() override; void Log(const LogMessage& Msg) override; diff --git a/src/zencore/logging/ansicolorsink.cpp b/src/zencore/logging/ansicolorsink.cpp index 9b9959862..540d22359 100644 --- a/src/zencore/logging/ansicolorsink.cpp +++ b/src/zencore/logging/ansicolorsink.cpp @@ -5,8 +5,19 @@ #include #include +#include #include +#if defined(_WIN32) +# include +# define ZEN_ISATTY _isatty +# define ZEN_FILENO _fileno +#else +# include +# define ZEN_ISATTY isatty +# define ZEN_FILENO fileno +#endif + namespace zen::logging { // Default formatter replicating spdlog's %+ pattern: @@ -98,7 +109,90 @@ GetColorForLevel(LogLevel InLevel) struct AnsiColorStdoutSink::Impl { - Impl() : m_Formatter(std::make_unique()) {} + explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique()), m_UseColor(ResolveColorMode(Mode)) {} + + static bool IsColorTerminal() + { + // If stdout is not a TTY, no color + if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) + { + return false; + } + + // NO_COLOR convention (https://no-color.org/) + if (std::getenv("NO_COLOR") != nullptr) + { + return false; + } + + // COLORTERM is set by terminals that support color (e.g. 
"truecolor", "24bit") + if (std::getenv("COLORTERM") != nullptr) + { + return true; + } + + // Check TERM for known color-capable values + const char* Term = std::getenv("TERM"); + if (Term != nullptr) + { + std::string_view TermView(Term); + // "dumb" terminals do not support color + if (TermView == "dumb") + { + return false; + } + // Match against known color-capable terminal types. + // TERM often includes suffixes like "-256color", so we use substring matching. + constexpr std::string_view ColorTerms[] = { + "alacritty", + "ansi", + "color", + "console", + "cygwin", + "gnome", + "konsole", + "kterm", + "linux", + "msys", + "putty", + "rxvt", + "screen", + "tmux", + "vt100", + "vt102", + "xterm", + }; + for (std::string_view Candidate : ColorTerms) + { + if (TermView.find(Candidate) != std::string_view::npos) + { + return true; + } + } + } + +#if defined(_WIN32) + // Windows console supports ANSI color by default in modern versions + return true; +#else + // Unknown terminal — be conservative + return false; +#endif + } + + static bool ResolveColorMode(ColorMode Mode) + { + switch (Mode) + { + case ColorMode::On: + return true; + case ColorMode::Off: + return false; + case ColorMode::Auto: + default: + return IsColorTerminal(); + } + } void Log(const LogMessage& Msg) { @@ -107,7 +201,7 @@ struct AnsiColorStdoutSink::Impl MemoryBuffer Formatted; m_Formatter->Format(Msg, Formatted); - if (Msg.ColorRangeEnd > Msg.ColorRangeStart) + if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart) { // Print pre-color range fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File); @@ -148,10 +242,11 @@ struct AnsiColorStdoutSink::Impl private: std::mutex m_Mutex; std::unique_ptr m_Formatter; - FILE* m_File = stdout; + FILE* m_File = stdout; + bool m_UseColor = true; }; -AnsiColorStdoutSink::AnsiColorStdoutSink() : m_Impl(std::make_unique()) +AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique(Mode)) { } -- cgit v1.2.3 From 
f9d8cbcb3573b47b639b7bd73d3a4eed17653d71 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Mon, 9 Mar 2026 13:08:00 +0100 Subject: add fallback for zencache multirange (#816) * clean up BuildStorageResolveResult to allow capabilities * add check for multirange request capability * add MaxRangeCountPerRequest capabilities * project export tests * add InMemoryBuildStorageCache * progress and logging improvements * fix ElapsedSeconds calculations in fileremoteprojectstore.cpp * oplogs/builds test script --- src/zen/cmds/builds_cmd.cpp | 139 +- src/zen/cmds/projectstore_cmd.cpp | 26 +- src/zenhttp/include/zenhttp/formatters.h | 2 +- src/zenremotestore/builds/buildstoragecache.cpp | 209 ++- .../builds/buildstorageoperations.cpp | 391 +++-- src/zenremotestore/builds/buildstorageutil.cpp | 97 +- .../include/zenremotestore/builds/buildstorage.h | 2 - .../zenremotestore/builds/buildstoragecache.h | 10 +- .../zenremotestore/builds/buildstorageutil.h | 36 +- .../include/zenremotestore/chunking/chunkblock.h | 1 + .../include/zenremotestore/jupiter/jupiterhost.h | 3 +- .../projectstore/buildsremoteprojectstore.h | 22 +- .../projectstore/remoteprojectstore.h | 68 +- src/zenremotestore/jupiter/jupiterhost.cpp | 11 +- .../projectstore/buildsremoteprojectstore.cpp | 314 +--- .../projectstore/fileremoteprojectstore.cpp | 301 +++- .../projectstore/jupiterremoteprojectstore.cpp | 81 +- .../projectstore/projectstoreoperations.cpp | 14 +- .../projectstore/remoteprojectstore.cpp | 1694 ++++++++++++++------ .../projectstore/zenremoteprojectstore.cpp | 164 +- .../storage/buildstore/httpbuildstore.cpp | 13 + src/zenserver/storage/buildstore/httpbuildstore.h | 2 + .../storage/projectstore/httpprojectstore.cpp | 263 ++- 23 files changed, 2426 insertions(+), 1437 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index e5cbafbea..b4b4df7c9 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -1594,7 +1594,7 @@ 
namespace builds_impl { } } } - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { if (SB.Size() > 0) { @@ -1649,9 +1649,9 @@ namespace builds_impl { } if (Options.PrimeCacheOnly) { - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { - Storage.BuildCacheStorage->Flush(5000, [](intptr_t Remaining) { + Storage.CacheStorage->Flush(5000, [](intptr_t Remaining) { if (!IsQuiet) { if (Remaining == 0) @@ -2826,47 +2826,47 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) BuildStorageResolveResult ResolveRes = ResolveBuildStorage(*Output, ClientSettings, m_Host, m_OverrideHost, m_ZenCacheHost, ZenCacheResolveMode::All, m_Verbose); - if (!ResolveRes.HostUrl.empty()) + if (!ResolveRes.Cloud.Address.empty()) { - ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2; + ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2; Result.BuildStorageHttp = - std::make_unique(ResolveRes.HostUrl, ClientSettings, []() { return AbortFlag.load(); }); + std::make_unique(ResolveRes.Cloud.Address, ClientSettings, []() { return AbortFlag.load(); }); - Result.BuildStorage = CreateJupiterBuildStorage(Log(), + Result.BuildStorage = CreateJupiterBuildStorage(Log(), *Result.BuildStorageHttp, StorageStats, m_Namespace, m_Bucket, m_AllowRedirect, TempPath / "storage"); - Result.StorageName = ResolveRes.HostName; + Result.BuildStorageHost = ResolveRes.Cloud; - uint64_t HostLatencyNs = ResolveRes.HostLatencySec >= 0 ? uint64_t(ResolveRes.HostLatencySec * 1000000000.0) : 0; + uint64_t HostLatencyNs = ResolveRes.Cloud.LatencySec >= 0 ? uint64_t(ResolveRes.Cloud.LatencySec * 1000000000.0) : 0; - StorageDescription = fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'. Latency: {}", - ResolveRes.HostName, - (ResolveRes.HostUrl == ResolveRes.HostName) ? 
"" : fmt::format(" {}", ResolveRes.HostUrl), - Result.BuildStorageHttp->GetSessionId(), - m_Namespace, - m_Bucket, - NiceLatencyNs(HostLatencyNs)); - Result.BuildStorageLatencySec = ResolveRes.HostLatencySec; + StorageDescription = + fmt::format("Cloud {}{}. SessionId: '{}'. Namespace '{}', Bucket '{}'. Latency: {}", + ResolveRes.Cloud.Name, + (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address), + Result.BuildStorageHttp->GetSessionId(), + m_Namespace, + m_Bucket, + NiceLatencyNs(HostLatencyNs)); - if (!ResolveRes.CacheUrl.empty()) + if (!ResolveRes.Cache.Address.empty()) { Result.CacheHttp = std::make_unique( - ResolveRes.CacheUrl, + ResolveRes.Cache.Address, HttpClientSettings{ .LogCategory = "httpcacheclient", .ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, + .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2, .AllowResume = true, .RetryCount = 0, .Verbose = m_VerboseHttp, .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, []() { return AbortFlag.load(); }); - Result.BuildCacheStorage = + Result.CacheStorage = CreateZenBuildStorageCache(*Result.CacheHttp, StorageCacheStats, m_Namespace, @@ -2874,19 +2874,17 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) TempPath / "zencache", BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background) : GetTinyWorkerPool(EWorkloadType::Background)); - Result.CacheName = ResolveRes.CacheName; + Result.CacheHost = ResolveRes.Cache; - uint64_t CacheLatencyNs = ResolveRes.CacheLatencySec >= 0 ? uint64_t(ResolveRes.CacheLatencySec * 1000000000.0) : 0; + uint64_t CacheLatencyNs = ResolveRes.Cache.LatencySec >= 0 ? uint64_t(ResolveRes.Cache.LatencySec * 1000000000.0) : 0; CacheDescription = fmt::format("Zen {}{}. SessionId: '{}'. 
Latency: {}", - ResolveRes.CacheName, - (ResolveRes.CacheUrl == ResolveRes.CacheName) ? "" : fmt::format(" {}", ResolveRes.CacheUrl), + ResolveRes.Cache.Name, + (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? "" : fmt::format(" {}", ResolveRes.Cache.Address), Result.CacheHttp->GetSessionId(), NiceLatencyNs(CacheLatencyNs)); - Result.CacheLatencySec = ResolveRes.CacheLatencySec; - if (!m_Namespace.empty()) { CacheDescription += fmt::format(". Namespace '{}'", m_Namespace); @@ -2902,41 +2900,56 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { StorageDescription = fmt::format("folder {}", m_StoragePath); Result.BuildStorage = CreateFileBuildStorage(m_StoragePath, StorageStats, false, DefaultLatency, DefaultDelayPerKBSec); - Result.StorageName = fmt::format("Disk {}", m_StoragePath.stem()); + + Result.BuildStorageHost = BuildStorageResolveResult::Host{.Address = m_StoragePath.generic_string(), + .Name = "Disk", + .LatencySec = 1.0 / 100000, // 1 us + .Caps = {.MaxRangeCountPerRequest = 2048u}}; if (!m_ZenCacheHost.empty()) { - Result.CacheHttp = std::make_unique( - m_ZenCacheHost, - HttpClientSettings{ - .LogCategory = "httpcacheclient", - .ConnectTimeout = std::chrono::milliseconds{3000}, - .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = m_AssumeHttp2, - .AllowResume = true, - .RetryCount = 0, - .Verbose = m_VerboseHttp, - .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, - []() { return AbortFlag.load(); }); - Result.BuildCacheStorage = - CreateZenBuildStorageCache(*Result.CacheHttp, - StorageCacheStats, - m_Namespace, - m_Bucket, - TempPath / "zencache", - BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background) - : GetTinyWorkerPool(EWorkloadType::Background)); - Result.CacheName = m_ZenCacheHost; - - CacheDescription = fmt::format("Zen {}{}. 
SessionId: '{}'", Result.CacheName, "", Result.CacheHttp->GetSessionId()); - ; - if (!m_Namespace.empty()) - { - CacheDescription += fmt::format(". Namespace '{}'", m_Namespace); - } - if (!m_Bucket.empty()) + ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(m_ZenCacheHost, m_AssumeHttp2, m_VerboseHttp); + + if (TestResult.Success) { - CacheDescription += fmt::format(" Bucket '{}'", m_Bucket); + Result.CacheHttp = std::make_unique( + m_ZenCacheHost, + HttpClientSettings{ + .LogCategory = "httpcacheclient", + .ConnectTimeout = std::chrono::milliseconds{3000}, + .Timeout = std::chrono::milliseconds{30000}, + .AssumeHttp2 = m_AssumeHttp2, + .AllowResume = true, + .RetryCount = 0, + .Verbose = m_VerboseHttp, + .MaximumInMemoryDownloadSize = GetMaxMemoryBufferSize(DefaultMaxChunkBlockSize, m_BoostWorkerMemory)}, + []() { return AbortFlag.load(); }); + + Result.CacheStorage = + CreateZenBuildStorageCache(*Result.CacheHttp, + StorageCacheStats, + m_Namespace, + m_Bucket, + TempPath / "zencache", + BoostCacheBackgroundWorkerPool ? GetSmallWorkerPool(EWorkloadType::Background) + : GetTinyWorkerPool(EWorkloadType::Background)); + Result.CacheHost = + BuildStorageResolveResult::Host{.Address = m_ZenCacheHost, + .Name = m_ZenCacheHost, + .AssumeHttp2 = m_AssumeHttp2, + .LatencySec = TestResult.LatencySeconds, + .Caps = {.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest}}; + + CacheDescription = fmt::format("Zen {}. SessionId: '{}'", Result.CacheHost.Name, Result.CacheHttp->GetSessionId()); + + if (!m_Namespace.empty()) + { + CacheDescription += fmt::format(". 
Namespace '{}'", m_Namespace); + } + if (!m_Bucket.empty()) + { + CacheDescription += fmt::format(" Bucket '{}'", m_Bucket); + } } } } @@ -2948,7 +2961,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!IsQuiet) { ZEN_CONSOLE("Remote: {}", StorageDescription); - if (!Result.CacheName.empty()) + if (!Result.CacheHost.Name.empty()) { ZEN_CONSOLE("Cache : {}", CacheDescription); } @@ -3489,7 +3502,7 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) "Requests: {}\n" "Avg Request Time: {}\n" "Avg I/O Time: {}", - Storage.StorageName, + Storage.BuildStorageHost.Name, NiceBytes(StorageStats.TotalBytesRead.load()), NiceBytes(StorageStats.TotalBytesWritten.load()), StorageStats.TotalRequestCount.load(), @@ -3810,12 +3823,12 @@ BuildsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (!IsQuiet) { - if (Storage.BuildCacheStorage) + if (Storage.CacheStorage) { - ZEN_CONSOLE("Uploaded {} ({}) blobs", + ZEN_CONSOLE("Uploaded {} ({}) blobs to {}", StorageCacheStats.PutBlobCount.load(), NiceBytes(StorageCacheStats.PutBlobByteCount), - Storage.CacheName); + Storage.CacheHost.Name); } } diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index dfc6c1650..5ff591b54 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -2602,38 +2602,37 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a StorageInstance Storage; - ClientSettings.AssumeHttp2 = ResolveRes.HostAssumeHttp2; + ClientSettings.AssumeHttp2 = ResolveRes.Cloud.AssumeHttp2; ClientSettings.MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? 
RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u; - Storage.BuildStorageHttp = std::make_unique(ResolveRes.HostUrl, ClientSettings); - Storage.BuildStorageLatencySec = ResolveRes.HostLatencySec; + Storage.BuildStorageHttp = std::make_unique(ResolveRes.Cloud.Address, ClientSettings); + Storage.BuildStorageHost = ResolveRes.Cloud; BuildStorageCache::Statistics StorageCacheStats; std::atomic AbortFlag(false); - if (!ResolveRes.CacheUrl.empty()) + if (!ResolveRes.Cache.Address.empty()) { Storage.CacheHttp = std::make_unique( - ResolveRes.CacheUrl, + ResolveRes.Cache.Address, HttpClientSettings{ .LogCategory = "httpcacheclient", .ConnectTimeout = std::chrono::milliseconds{3000}, .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, + .AssumeHttp2 = ResolveRes.Cache.AssumeHttp2, .AllowResume = true, .RetryCount = 0, .MaximumInMemoryDownloadSize = m_BoostWorkerMemory ? RemoteStoreOptions::DefaultMaxBlockSize : 1024u * 1024u}, [&AbortFlag]() { return AbortFlag.load(); }); - Storage.CacheName = ResolveRes.CacheName; - Storage.CacheLatencySec = ResolveRes.CacheLatencySec; + Storage.CacheHost = ResolveRes.Cache; } if (!m_Quiet) { std::string StorageDescription = fmt::format("Cloud {}{}. SessionId {}. Namespace '{}', Bucket '{}'", - ResolveRes.HostName, - (ResolveRes.HostUrl == ResolveRes.HostName) ? "" : fmt::format(" {}", ResolveRes.HostUrl), + ResolveRes.Cloud.Name, + (ResolveRes.Cloud.Address == ResolveRes.Cloud.Name) ? "" : fmt::format(" {}", ResolveRes.Cloud.Address), Storage.BuildStorageHttp->GetSessionId(), m_Namespace, m_Bucket); @@ -2644,8 +2643,8 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a { std::string CacheDescription = fmt::format("Zen {}{}. SessionId {}. Namespace '{}', Bucket '{}'", - ResolveRes.CacheName, - (ResolveRes.CacheUrl == ResolveRes.CacheName) ? 
"" : fmt::format(" {}", ResolveRes.CacheUrl), + ResolveRes.Cache.Name, + (ResolveRes.Cache.Address == ResolveRes.Cache.Name) ? "" : fmt::format(" {}", ResolveRes.Cache.Address), Storage.CacheHttp->GetSessionId(), m_Namespace, m_Bucket); @@ -2661,11 +2660,10 @@ OplogDownloadCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a Storage.BuildStorage = CreateJupiterBuildStorage(Log(), *Storage.BuildStorageHttp, StorageStats, m_Namespace, m_Bucket, m_AllowRedirect, StorageTempPath); - Storage.StorageName = ResolveRes.HostName; if (Storage.CacheHttp) { - Storage.BuildCacheStorage = CreateZenBuildStorageCache( + Storage.CacheStorage = CreateZenBuildStorageCache( *Storage.CacheHttp, StorageCacheStats, m_Namespace, diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index addb00cb8..57ab01158 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -73,7 +73,7 @@ struct fmt::formatter if (Response.IsSuccess()) { return fmt::format_to(Ctx.out(), - "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}", ToString(Response.StatusCode), Response.UploadedBytes, Response.DownloadedBytes, diff --git a/src/zenremotestore/builds/buildstoragecache.cpp b/src/zenremotestore/builds/buildstoragecache.cpp index 53d33bd7e..00765903d 100644 --- a/src/zenremotestore/builds/buildstoragecache.cpp +++ b/src/zenremotestore/builds/buildstoragecache.cpp @@ -528,6 +528,192 @@ CreateZenBuildStorageCache(HttpClient& HttpClient, return std::make_unique(HttpClient, Stats, Namespace, Bucket, TempFolderPath, BackgroundWorkerPool); } +#if ZEN_WITH_TESTS + +class InMemoryBuildStorageCache : public BuildStorageCache +{ +public: + // MaxRangeSupported == 0 : no range requests are accepted, always return full blob + // MaxRangeSupported == 1 : single range is supported, multi range returns full blob + // MaxRangeSupported > 1 : multirange is 
supported up to MaxRangeSupported, more ranges returns empty blob (bad request) + explicit InMemoryBuildStorageCache(uint64_t MaxRangeSupported, + BuildStorageCache::Statistics& Stats, + double LatencySec = 0.0, + double DelayPerKBSec = 0.0) + : m_MaxRangeSupported(MaxRangeSupported) + , m_Stats(Stats) + , m_LatencySec(LatencySec) + , m_DelayPerKBSec(DelayPerKBSec) + { + } + void PutBuildBlob(const Oid&, const IoHash& RawHash, ZenContentType, const CompositeBuffer& Payload) override + { + IoBuffer Buf = Payload.Flatten().AsIoBuffer(); + Buf.MakeOwned(); + const uint64_t SentBytes = Buf.Size(); + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + { + std::lock_guard Lock(m_Mutex); + m_Entries[RawHash] = std::move(Buf); + } + m_Stats.PutBlobCount.fetch_add(1); + m_Stats.PutBlobByteCount.fetch_add(SentBytes); + } + + IoBuffer GetBuildBlob(const Oid&, const IoHash& RawHash, uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) override + { + uint64_t SentBytes = 0; + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + IoBuffer FullPayload; + { + std::lock_guard Lock(m_Mutex); + auto It = m_Entries.find(RawHash); + if (It == m_Entries.end()) + { + return {}; + } + FullPayload = It->second; + } + + if (RangeOffset != 0 || RangeBytes != (uint64_t)-1) + { + if (m_MaxRangeSupported == 0) + { + ReceivedBytes = FullPayload.Size(); + return FullPayload; + } + else + { + ReceivedBytes = (RangeBytes == (uint64_t)-1) ? 
FullPayload.Size() - RangeOffset : RangeBytes; + return IoBuffer(FullPayload, RangeOffset, RangeBytes); + } + } + else + { + ReceivedBytes = FullPayload.Size(); + return FullPayload; + } + } + + BuildBlobRanges GetBuildBlobRanges(const Oid&, const IoHash& RawHash, std::span> Ranges) override + { + ZEN_ASSERT(!Ranges.empty()); + uint64_t SentBytes = 0; + uint64_t ReceivedBytes = 0; + SimulateLatency(SentBytes, 0); + auto _ = MakeGuard([&]() { SimulateLatency(0, ReceivedBytes); }); + Stopwatch ExecutionTimer; + auto __ = MakeGuard([&]() { AddStatistic(ExecutionTimer.GetElapsedTimeUs(), ReceivedBytes, SentBytes); }); + if (m_MaxRangeSupported > 1 && Ranges.size() > m_MaxRangeSupported) + { + return {}; + } + IoBuffer FullPayload; + { + std::lock_guard Lock(m_Mutex); + auto It = m_Entries.find(RawHash); + if (It == m_Entries.end()) + { + return {}; + } + FullPayload = It->second; + } + + if (Ranges.size() > m_MaxRangeSupported) + { + // An empty Ranges signals to the caller: "full buffer given, use it for all requested ranges". 
+ ReceivedBytes = FullPayload.Size(); + return {.PayloadBuffer = FullPayload}; + } + else + { + uint64_t PayloadStart = Ranges.front().first; + uint64_t PayloadSize = Ranges.back().first + Ranges.back().second - PayloadStart; + IoBuffer RangeBuffer = IoBuffer(FullPayload, PayloadStart, PayloadSize); + std::vector> PayloadRanges; + PayloadRanges.reserve(Ranges.size()); + for (const std::pair& Range : Ranges) + { + PayloadRanges.push_back(std::make_pair(Range.first - PayloadStart, Range.second)); + } + ReceivedBytes = PayloadSize; + return {.PayloadBuffer = RangeBuffer, .Ranges = std::move(PayloadRanges)}; + } + } + + void PutBlobMetadatas(const Oid&, std::span, std::span) override {} + + std::vector GetBlobMetadatas(const Oid&, std::span Hashes) override + { + return std::vector(Hashes.size()); + } + + std::vector BlobsExists(const Oid&, std::span Hashes) override + { + std::lock_guard Lock(m_Mutex); + std::vector Result; + Result.reserve(Hashes.size()); + for (const IoHash& Hash : Hashes) + { + auto It = m_Entries.find(Hash); + Result.push_back({.HasBody = (It != m_Entries.end() && It->second)}); + } + return Result; + } + + void Flush(int32_t, std::function&&) override {} + +private: + void AddStatistic(uint64_t ElapsedTimeUs, uint64_t ReceivedBytes, uint64_t SentBytes) + { + m_Stats.TotalBytesWritten += SentBytes; + m_Stats.TotalBytesRead += ReceivedBytes; + m_Stats.TotalExecutionTimeUs += ElapsedTimeUs; + m_Stats.TotalRequestCount++; + SetAtomicMax(m_Stats.PeakSentBytes, SentBytes); + SetAtomicMax(m_Stats.PeakReceivedBytes, ReceivedBytes); + if (ElapsedTimeUs > 0) + { + SetAtomicMax(m_Stats.PeakBytesPerSec, (ReceivedBytes + SentBytes) * 1000000 / ElapsedTimeUs); + } + } + + void SimulateLatency(uint64_t SendBytes, uint64_t ReceiveBytes) + { + double SleepSec = m_LatencySec; + if (m_DelayPerKBSec > 0.0) + { + SleepSec += m_DelayPerKBSec * (double(SendBytes + ReceiveBytes) / 1024u); + } + if (SleepSec > 0) + { + Sleep(int(SleepSec * 1000)); + } + } + + uint64_t 
m_MaxRangeSupported = 0; + BuildStorageCache::Statistics& m_Stats; + const double m_LatencySec = 0.0; + const double m_DelayPerKBSec = 0.0; + std::mutex m_Mutex; + std::unordered_map m_Entries; +}; + +std::unique_ptr +CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported, BuildStorageCache::Statistics& Stats, double LatencySec, double DelayPerKBSec) +{ + return std::make_unique(MaxRangeSupported, Stats, LatencySec, DelayPerKBSec); +} +#endif // ZEN_WITH_TESTS + ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose) { @@ -542,15 +728,28 @@ TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const boo HttpClient::Response TestResponse = TestHttpClient.Get("/status/builds"); if (TestResponse.IsSuccess()) { - LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health"); + uint64_t MaxRangeCountPerRequest = 1; + CbObject StatusResponse = TestResponse.AsObject(); + if (StatusResponse["ok"].AsBool()) + { + MaxRangeCountPerRequest = StatusResponse["capabilities"].AsObjectView()["maxrangecountperrequest"].AsUInt64(1); + + LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health"); + + if (!LatencyResult.Success) + { + return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + } - if (!LatencyResult.Success) + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest}; + } + else { - return {.Success = false, .FailureReason = LatencyResult.FailureReason}; + return {.Success = false, + .FailureReason = fmt::format("ZenCache endpoint {}/status/builds did not respond with \"ok\"", BaseUrl)}; } - return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds}; } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; -}; +} } // namespace zen diff --git a/src/zenremotestore/builds/buildstorageoperations.cpp 
b/src/zenremotestore/builds/buildstorageoperations.cpp index 43a4937f0..f4b167b73 100644 --- a/src/zenremotestore/builds/buildstorageoperations.cpp +++ b/src/zenremotestore/builds/buildstorageoperations.cpp @@ -887,11 +887,12 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) ChunkBlockAnalyser BlockAnalyser( m_LogOutput, m_BlockDescriptions, - ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, - .IsVerbose = m_Options.IsVerbose, - .HostLatencySec = m_Storage.BuildStorageLatencySec, - .HostHighSpeedLatencySec = m_Storage.CacheLatencySec, - .HostMaxRangeCountPerRequest = BuildStorageBase::MaxRangeCountPerRequest}); + ChunkBlockAnalyser::Options{.IsQuiet = m_Options.IsQuiet, + .IsVerbose = m_Options.IsVerbose, + .HostLatencySec = m_Storage.BuildStorageHost.LatencySec, + .HostHighSpeedLatencySec = m_Storage.CacheHost.LatencySec, + .HostMaxRangeCountPerRequest = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest, + .HostHighSpeedMaxRangeCountPerRequest = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest}); std::vector NeededBlocks = BlockAnalyser.GetNeeded( m_RemoteLookup.ChunkHashToChunkIndex, @@ -974,7 +975,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } } - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { ZEN_TRACE_CPU("BlobCacheExistCheck"); Stopwatch Timer; @@ -993,7 +994,7 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } const std::vector CacheExistsResult = - m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); + m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); if (CacheExistsResult.size() == BlobHashes.size()) { @@ -1018,32 +1019,50 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) } std::vector BlockPartialDownloadModes; + if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Off) { BlockPartialDownloadModes.resize(m_BlockDescriptions.size(), 
ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } else { + ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + + switch (m_Options.PartialBlockRequestMode) + { + case EPartialBlockRequestMode::Off: + break; + case EPartialBlockRequestMode::ZenCacheOnly: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + break; + case EPartialBlockRequestMode::Mixed: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + case EPartialBlockRequestMode::All: + CachePartialDownloadMode = m_Storage.CacheHost.Caps.MaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest > 1 + ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + default: + ZEN_ASSERT(false); + break; + } + BlockPartialDownloadModes.reserve(m_BlockDescriptions.size()); for (uint32_t BlockIndex = 0; BlockIndex < m_BlockDescriptions.size(); BlockIndex++) { const bool BlockExistInCache = ExistsResult.ExistingBlobs.contains(m_BlockDescriptions[BlockIndex].BlockHash); - if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::All) - { - BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); - } - else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) - { - BlockPartialDownloadModes.push_back(BlockExistInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact - : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); - } - else if (m_Options.PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) - { - BlockPartialDownloadModes.push_back(BlockExistInCache - ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); - } + BlockPartialDownloadModes.push_back(BlockExistInCache ? 
CachePartialDownloadMode : CloudPartialDownloadMode); } } @@ -1527,20 +1546,20 @@ BuildsOperationUpdateFolder::Execute(FolderContent& OutLocalFolderState) IoBuffer BlockBuffer; const bool ExistsInCache = - m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); + m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); if (ExistsInCache) { - BlockBuffer = m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); + BlockBuffer = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); } if (!BlockBuffer) { BlockBuffer = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash); - if (BlockBuffer && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (BlockBuffer && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BlockBuffer))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(BlockBuffer))); } } if (!BlockBuffer) @@ -3103,10 +3122,10 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde const IoHash& ChunkHash = m_RemoteContent.ChunkedContent.ChunkHashes[RemoteChunkIndex]; // FilteredDownloadedBytesPerSecond.Start(); IoBuffer BuildBlob; - const bool ExistsInCache = m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash); + const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(ChunkHash); if (ExistsInCache) { - BuildBlob = m_Storage.BuildCacheStorage->GetBuildBlob(m_BuildId, ChunkHash); + BuildBlob = m_Storage.CacheStorage->GetBuildBlob(m_BuildId, ChunkHash); } if (BuildBlob) { @@ -3134,12 +3153,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde 
m_DownloadStats.DownloadedChunkCount++; m_DownloadStats.RequestsCompleteCount++; - if (Payload && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (Payload && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); } OnDownloaded(std::move(Payload)); @@ -3148,12 +3167,12 @@ BuildsOperationUpdateFolder::DownloadBuildBlob(uint32_t RemoteChunkInde else { BuildBlob = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, ChunkHash); - if (BuildBlob && m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (BuildBlob && m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - ChunkHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(BuildBlob))); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + ChunkHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(BuildBlob))); } if (!BuildBlob) { @@ -3273,34 +3292,7 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); } - if (m_Storage.BuildCacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash)) - { - BuildStorageCache::BuildBlobRanges RangeBuffers = - m_Storage.BuildCacheStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, Ranges); - if (RangeBuffers.PayloadBuffer) - { - if (!m_AbortFlag) - { - if (RangeBuffers.Ranges.size() != Ranges.size()) - { - throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", - Ranges.size(), - BlockDescription.BlockHash, - RangeBuffers.Ranges.size())); - } - - std::vector> BlockOffsetAndLengths = std::move(RangeBuffers.Ranges); - 
ProcessDownload(BlockDescription, - std::move(RangeBuffers.PayloadBuffer), - BlockRangeStartIndex, - BlockOffsetAndLengths, - OnDownloaded); - } - return; - } - } - - const size_t MaxRangesPerRequestToJupiter = BuildStorageBase::MaxRangeCountPerRequest; + const bool ExistsInCache = m_Storage.CacheStorage && ExistsResult.ExistingBlobs.contains(BlockDescription.BlockHash); size_t SubBlockRangeCount = BlockRangeCount; size_t SubRangeCountComplete = 0; @@ -3311,30 +3303,101 @@ BuildsOperationUpdateFolder::DownloadPartialBlock( { break; } - size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, MaxRangesPerRequestToJupiter); + + // First try to get subrange from cache. + // If not successful, try to get the ranges from the build store and adapt SubRangeCount... + size_t SubRangeStartIndex = BlockRangeStartIndex + SubRangeCountComplete; + if (ExistsInCache) + { + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.CacheHost.Caps.MaxRangeCountPerRequest); + + if (SubRangeCount == 1) + { + // Legacy single-range path, prefer that for max compatibility + + const std::pair SubRange = RangesSpan[SubRangeCountComplete]; + IoBuffer PayloadBuffer = + m_Storage.CacheStorage->GetBuildBlob(m_BuildId, BlockDescription.BlockHash, SubRange.first, SubRange.second); + if (m_AbortFlag) + { + break; + } + if (PayloadBuffer) + { + ProcessDownload(BlockDescription, + std::move(PayloadBuffer), + SubRangeStartIndex, + std::vector>{std::make_pair(0u, SubRange.second)}, + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + else + { + auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + BuildStorageCache::BuildBlobRanges RangeBuffers = + m_Storage.CacheStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); + if (m_AbortFlag) + { + break; + } + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) + { + SubRangeCount = Ranges.size() - SubRangeCountComplete; + 
ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount), + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; + } + else if (RangeBuffers.Ranges.size() == SubRangeCount) + { + ProcessDownload(BlockDescription, + std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangeBuffers.Ranges, + OnDownloaded); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + } + } + + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, m_Storage.BuildStorageHost.Caps.MaxRangeCountPerRequest); auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); BuildStorageBase::BuildBlobRanges RangeBuffers = m_Storage.BuildStorage->GetBuildBlobRanges(m_BuildId, BlockDescription.BlockHash, SubRanges); + if (m_AbortFlag) + { + break; + } if (RangeBuffers.PayloadBuffer) { - if (m_AbortFlag) - { - break; - } if (RangeBuffers.Ranges.empty()) { // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 // Upload to cache (if enabled) and use the whole payload for the remaining ranges - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlockDescription.BlockHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(std::vector{RangeBuffers.PayloadBuffer})); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector{RangeBuffers.PayloadBuffer})); + if (m_AbortFlag) + { + break; + } } SubRangeCount = Ranges.size() - SubRangeCountComplete; @@ -4932,12 +4995,12 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& const IoHash& BlockHash = OutBlocks.BlockDescriptions[BlockIndex].BlockHash; const uint64_t CompressedBlockSize = Payload.GetCompressedSize(); - if 
(m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlockHash, - ZenContentType::kCompressedBinary, - Payload.GetCompressed()); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlockHash, + ZenContentType::kCompressedBinary, + Payload.GetCompressed()); } m_Storage.BuildStorage->PutBuildBlob(m_BuildId, @@ -4955,11 +5018,11 @@ BuildsOperationUploadFolder::GenerateBuildBlocks(const ChunkedFolderContent& OutBlocks.BlockDescriptions[BlockIndex].ChunkRawHashes.size()); } - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector({BlockHash}), - std::vector({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector({BlockHash}), + std::vector({BlockMetaData})); } bool MetadataSucceeded = @@ -5803,11 +5866,11 @@ BuildsOperationUploadFolder::UploadBuildPart(ChunkingController& ChunkController { const CbObject BlockMetaData = BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector({BlockHash}), - std::vector({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector({BlockHash}), + std::vector({BlockMetaData})); } bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); if (MetadataSucceeded) @@ -6001,9 +6064,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co const CbObject BlockMetaData = BuildChunkBlockDescription(NewBlocks.BlockDescriptions[BlockIndex], NewBlocks.BlockMetaDatas[BlockIndex]); - if (m_Storage.BuildCacheStorage && 
m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); } m_Storage.BuildStorage->PutBuildBlob(m_BuildId, BlockHash, ZenContentType::kCompressedBinary, Payload); if (m_Options.IsVerbose) @@ -6017,11 +6080,11 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co UploadedBlockSize += PayloadSize; TempUploadStats.BlocksBytes += PayloadSize; - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBlobMetadatas(m_BuildId, - std::vector({BlockHash}), - std::vector({BlockMetaData})); + m_Storage.CacheStorage->PutBlobMetadatas(m_BuildId, + std::vector({BlockHash}), + std::vector({BlockMetaData})); } bool MetadataSucceeded = m_Storage.BuildStorage->PutBlockMetadata(m_BuildId, BlockHash, BlockMetaData); if (MetadataSucceeded) @@ -6084,9 +6147,9 @@ BuildsOperationUploadFolder::UploadPartBlobs(const ChunkedFolderContent& Co const uint64_t PayloadSize = Payload.GetSize(); - if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, RawHash, ZenContentType::kCompressedBinary, Payload); } if (PayloadSize >= LargeAttachmentSize) @@ -6830,14 +6893,14 @@ BuildsOperationPrimeCache::Execute() std::vector BlobsToDownload; BlobsToDownload.reserve(BuildBlobs.size()); - if (m_Storage.BuildCacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload) + if (m_Storage.CacheStorage && !BuildBlobs.empty() && !m_Options.ForceUpload) { ZEN_TRACE_CPU("BlobCacheExistCheck"); Stopwatch Timer; const std::vector 
BlobHashes(BuildBlobs.begin(), BuildBlobs.end()); const std::vector CacheExistsResult = - m_Storage.BuildCacheStorage->BlobsExists(m_BuildId, BlobHashes); + m_Storage.CacheStorage->BlobsExists(m_BuildId, BlobHashes); if (CacheExistsResult.size() == BlobHashes.size()) { @@ -6884,33 +6947,33 @@ BuildsOperationPrimeCache::Execute() for (size_t BlobIndex = 0; BlobIndex < BlobCount; BlobIndex++) { - Work.ScheduleWork( - m_NetworkPool, - [this, - &Work, - &BlobsToDownload, - BlobCount, - &LooseChunkRawSizes, - &CompletedDownloadCount, - &FilteredDownloadedBytesPerSecond, - &MultipartAttachmentCount, - BlobIndex](std::atomic&) { - if (!m_AbortFlag) - { - const IoHash& BlobHash = BlobsToDownload[BlobIndex]; + Work.ScheduleWork(m_NetworkPool, + [this, + &Work, + &BlobsToDownload, + BlobCount, + &LooseChunkRawSizes, + &CompletedDownloadCount, + &FilteredDownloadedBytesPerSecond, + &MultipartAttachmentCount, + BlobIndex](std::atomic&) { + if (!m_AbortFlag) + { + const IoHash& BlobHash = BlobsToDownload[BlobIndex]; - bool IsLargeBlob = false; + bool IsLargeBlob = false; - if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) - { - IsLargeBlob = It->second >= m_Options.LargeAttachmentSize; - } + if (auto It = LooseChunkRawSizes.find(BlobHash); It != LooseChunkRawSizes.end()) + { + IsLargeBlob = It->second >= m_Options.LargeAttachmentSize; + } - FilteredDownloadedBytesPerSecond.Start(); + FilteredDownloadedBytesPerSecond.Start(); - if (IsLargeBlob) - { - DownloadLargeBlob(*m_Storage.BuildStorage, + if (IsLargeBlob) + { + DownloadLargeBlob( + *m_Storage.BuildStorage, m_TempPath, m_BuildId, BlobHash, @@ -6926,12 +6989,12 @@ BuildsOperationPrimeCache::Execute() if (!m_AbortFlag) { - if (Payload && m_Storage.BuildCacheStorage) + if (Payload && m_Storage.CacheStorage) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(Payload))); + 
m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(Payload))); } } CompletedDownloadCount++; @@ -6940,32 +7003,32 @@ BuildsOperationPrimeCache::Execute() FilteredDownloadedBytesPerSecond.Stop(); } }); - } - else - { - IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); - m_DownloadStats.DownloadedBlockCount++; - m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); - m_DownloadStats.RequestsCompleteCount++; + } + else + { + IoBuffer Payload = m_Storage.BuildStorage->GetBuildBlob(m_BuildId, BlobHash); + m_DownloadStats.DownloadedBlockCount++; + m_DownloadStats.DownloadedBlockByteCount += Payload.GetSize(); + m_DownloadStats.RequestsCompleteCount++; - if (!m_AbortFlag) - { - if (Payload && m_Storage.BuildCacheStorage) - { - m_Storage.BuildCacheStorage->PutBuildBlob(m_BuildId, - BlobHash, - ZenContentType::kCompressedBinary, - CompositeBuffer(SharedBuffer(std::move(Payload)))); - } - } - CompletedDownloadCount++; - if (CompletedDownloadCount == BlobCount) - { - FilteredDownloadedBytesPerSecond.Stop(); - } - } - } - }); + if (!m_AbortFlag) + { + if (Payload && m_Storage.CacheStorage) + { + m_Storage.CacheStorage->PutBuildBlob(m_BuildId, + BlobHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(SharedBuffer(std::move(Payload)))); + } + } + CompletedDownloadCount++; + if (CompletedDownloadCount == BlobCount) + { + FilteredDownloadedBytesPerSecond.Stop(); + } + } + } + }); } Work.Wait(m_LogOutput.GetProgressUpdateDelayMS(), [&](bool IsAborted, bool IsPaused, std::ptrdiff_t PendingWork) { @@ -6977,10 +7040,10 @@ BuildsOperationPrimeCache::Execute() std::string DownloadRateString = (CompletedDownloadCount == BlobCount) ? "" : fmt::format(" {}bits/s", NiceNum(FilteredDownloadedBytesPerSecond.GetCurrent() * 8)); - std::string UploadDetails = m_Storage.BuildCacheStorage ? 
fmt::format(" {} ({}) uploaded.", - m_StorageCacheStats.PutBlobCount.load(), - NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) - : ""; + std::string UploadDetails = m_Storage.CacheStorage ? fmt::format(" {} ({}) uploaded.", + m_StorageCacheStats.PutBlobCount.load(), + NiceBytes(m_StorageCacheStats.PutBlobByteCount.load())) + : ""; std::string Details = fmt::format("{}/{} ({}{}) downloaded.{}", CompletedDownloadCount.load(), @@ -7005,13 +7068,13 @@ BuildsOperationPrimeCache::Execute() return; } - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { - m_Storage.BuildCacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { + m_Storage.CacheStorage->Flush(m_LogOutput.GetProgressUpdateDelayMS(), [this](intptr_t Remaining) -> bool { ZEN_UNUSED(Remaining); if (!m_Options.IsQuiet) { - ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheName); + ZEN_OPERATION_LOG_INFO(m_LogOutput, "Waiting for {} blobs to finish upload to '{}'", Remaining, m_Storage.CacheHost.Name); } return !m_AbortFlag; }); @@ -7221,7 +7284,7 @@ GetRemoteContent(OperationLogOutput& Output, bool AttemptFallback = false; OutBlockDescriptions = GetBlockDescriptions(Output, *Storage.BuildStorage, - Storage.BuildCacheStorage.get(), + Storage.CacheStorage.get(), BuildId, BlockRawHashes, AttemptFallback, diff --git a/src/zenremotestore/builds/buildstorageutil.cpp b/src/zenremotestore/builds/buildstorageutil.cpp index d65f18b9a..2ae726e29 100644 --- a/src/zenremotestore/builds/buildstorageutil.cpp +++ b/src/zenremotestore/builds/buildstorageutil.cpp @@ -63,13 +63,15 @@ ResolveBuildStorage(OperationLogOutput& Output, std::string HostUrl; std::string HostName; - double HostLatencySec = -1.0; + double HostLatencySec = -1.0; + uint64_t HostMaxRangeCountPerRequest = 1; std::string CacheUrl; std::string CacheName; - bool HostAssumeHttp2 = ClientSettings.AssumeHttp2; - bool CacheAssumeHttp2 = 
ClientSettings.AssumeHttp2; - double CacheLatencySec = -1.0; + bool HostAssumeHttp2 = ClientSettings.AssumeHttp2; + bool CacheAssumeHttp2 = ClientSettings.AssumeHttp2; + double CacheLatencySec = -1.0; + uint64_t CacheMaxRangeCountPerRequest = 1; JupiterServerDiscovery DiscoveryResponse; const std::string_view DiscoveryHost = Host.empty() ? OverrideHost : Host; @@ -100,9 +102,10 @@ ResolveBuildStorage(OperationLogOutput& Output, { ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", OverrideHost); } - HostUrl = OverrideHost; - HostName = GetHostNameFromUrl(OverrideHost); - HostLatencySec = TestResult.LatencySeconds; + HostUrl = OverrideHost; + HostName = GetHostNameFromUrl(OverrideHost); + HostLatencySec = TestResult.LatencySeconds; + HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; } else { @@ -137,10 +140,11 @@ ResolveBuildStorage(OperationLogOutput& Output, ZEN_OPERATION_LOG_INFO(Output, "Server endpoint at '{}/api/v1/status/servers' succeeded", ServerEndpoint.BaseUrl); } - HostUrl = ServerEndpoint.BaseUrl; - HostAssumeHttp2 = ServerEndpoint.AssumeHttp2; - HostName = ServerEndpoint.Name; - HostLatencySec = TestResult.LatencySeconds; + HostUrl = ServerEndpoint.BaseUrl; + HostAssumeHttp2 = ServerEndpoint.AssumeHttp2; + HostName = ServerEndpoint.Name; + HostLatencySec = TestResult.LatencySeconds; + HostMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; break; } else @@ -184,10 +188,11 @@ ResolveBuildStorage(OperationLogOutput& Output, ZEN_OPERATION_LOG_INFO(Output, "Cache endpoint at '{}/status/builds' succeeded", CacheEndpoint.BaseUrl); } - CacheUrl = CacheEndpoint.BaseUrl; - CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2; - CacheName = CacheEndpoint.Name; - CacheLatencySec = TestResult.LatencySeconds; + CacheUrl = CacheEndpoint.BaseUrl; + CacheAssumeHttp2 = CacheEndpoint.AssumeHttp2; + CacheName = CacheEndpoint.Name; + CacheLatencySec = TestResult.LatencySeconds; + CacheMaxRangeCountPerRequest = 
TestResult.MaxRangeCountPerRequest; break; } } @@ -225,9 +230,10 @@ ResolveBuildStorage(OperationLogOutput& Output, if (ZenCacheEndpointTestResult TestResult = TestZenCacheEndpoint(ZenCacheHost, /*AssumeHttp2*/ false, ClientSettings.Verbose); TestResult.Success) { - CacheUrl = ZenCacheHost; - CacheName = GetHostNameFromUrl(ZenCacheHost); - CacheLatencySec = TestResult.LatencySeconds; + CacheUrl = ZenCacheHost; + CacheName = GetHostNameFromUrl(ZenCacheHost); + CacheLatencySec = TestResult.LatencySeconds; + CacheMaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; } else { @@ -235,15 +241,34 @@ ResolveBuildStorage(OperationLogOutput& Output, } } - return BuildStorageResolveResult{.HostUrl = HostUrl, - .HostName = HostName, - .HostAssumeHttp2 = HostAssumeHttp2, - .HostLatencySec = HostLatencySec, + return BuildStorageResolveResult{ + .Cloud = {.Address = HostUrl, + .Name = HostName, + .AssumeHttp2 = HostAssumeHttp2, + .LatencySec = HostLatencySec, + .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = HostMaxRangeCountPerRequest}}, + .Cache = {.Address = CacheUrl, + .Name = CacheName, + .AssumeHttp2 = CacheAssumeHttp2, + .LatencySec = CacheLatencySec, + .Caps = BuildStorageResolveResult::Capabilities{.MaxRangeCountPerRequest = CacheMaxRangeCountPerRequest}}}; +} - .CacheUrl = CacheUrl, - .CacheName = CacheName, - .CacheAssumeHttp2 = CacheAssumeHttp2, - .CacheLatencySec = CacheLatencySec}; +std::vector +ParseBlockMetadatas(std::span BlockMetadatas) +{ + std::vector UnorderedList; + UnorderedList.reserve(BlockMetadatas.size()); + for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < BlockMetadatas.size(); CacheBlockMetadataIndex++) + { + const CbObject& CacheBlockMetadata = BlockMetadatas[CacheBlockMetadataIndex]; + ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata); + if (Description.BlockHash != IoHash::Zero) + { + UnorderedList.emplace_back(std::move(Description)); + } + } + return 
UnorderedList; } std::vector @@ -263,25 +288,15 @@ GetBlockDescriptions(OperationLogOutput& Output, if (OptionalCacheStorage && !BlockRawHashes.empty()) { std::vector CacheBlockMetadatas = OptionalCacheStorage->GetBlobMetadatas(BuildId, BlockRawHashes); - UnorderedList.reserve(CacheBlockMetadatas.size()); - for (size_t CacheBlockMetadataIndex = 0; CacheBlockMetadataIndex < CacheBlockMetadatas.size(); CacheBlockMetadataIndex++) + if (!CacheBlockMetadatas.empty()) { - const CbObject& CacheBlockMetadata = CacheBlockMetadatas[CacheBlockMetadataIndex]; - ChunkBlockDescription Description = ParseChunkBlockDescription(CacheBlockMetadata); - if (Description.BlockHash == IoHash::Zero) - { - ZEN_OPERATION_LOG_WARN(Output, "Unexpected/invalid block metadata received from remote cache, skipping block"); - } - else + UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas); + for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++) { - UnorderedList.emplace_back(std::move(Description)); + const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); } } - for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++) - { - const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; - BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); - } } if (UnorderedList.size() < BlockRawHashes.size()) diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h index ce3da41c1..da8437a58 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorage.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorage.h @@ -58,8 +58,6 @@ public: uint64_t RangeOffset = 0, uint64_t RangeBytes = (uint64_t)-1) = 0; - static constexpr size_t MaxRangeCountPerRequest = 128u; - struct 
BuildBlobRanges { IoBuffer PayloadBuffer; diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h index 67c93480b..24702df0f 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstoragecache.h @@ -69,11 +69,19 @@ std::unique_ptr CreateZenBuildStorageCache(HttpClient& H const std::filesystem::path& TempFolderPath, WorkerThreadPool& BackgroundWorkerPool); +#if ZEN_WITH_TESTS +std::unique_ptr CreateInMemoryBuildStorageCache(uint64_t MaxRangeSupported, + BuildStorageCache::Statistics& Stats, + double LatencySec = 0.0, + double DelayPerKBSec = 0.0); +#endif // ZEN_WITH_TESTS + struct ZenCacheEndpointTestResult { bool Success = false; std::string FailureReason; - double LatencySeconds = -1.0; + double LatencySeconds = -1.0; + uint64_t MaxRangeCountPerRequest = 1; }; ZenCacheEndpointTestResult TestZenCacheEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h index 764a24e61..7306188ca 100644 --- a/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h +++ b/src/zenremotestore/include/zenremotestore/builds/buildstorageutil.h @@ -14,15 +14,20 @@ class BuildStorageCache; struct BuildStorageResolveResult { - std::string HostUrl; - std::string HostName; - bool HostAssumeHttp2 = false; - double HostLatencySec = -1.0; - - std::string CacheUrl; - std::string CacheName; - bool CacheAssumeHttp2 = false; - double CacheLatencySec = -1.0; + struct Capabilities + { + uint64_t MaxRangeCountPerRequest = 1; + }; + struct Host + { + std::string Address; + std::string Name; + bool AssumeHttp2 = false; + double LatencySec = -1.0; + Capabilities Caps; + }; + Host Cloud; + Host Cache; }; enum class ZenCacheResolveMode @@ 
-52,14 +57,13 @@ std::vector GetBlockDescriptions(OperationLogOutput& Out struct StorageInstance { - std::unique_ptr BuildStorageHttp; - std::unique_ptr BuildStorage; - std::string StorageName; - double BuildStorageLatencySec = -1.0; + BuildStorageResolveResult::Host BuildStorageHost; + std::unique_ptr BuildStorageHttp; + std::unique_ptr BuildStorage; + + BuildStorageResolveResult::Host CacheHost; std::unique_ptr CacheHttp; - std::unique_ptr BuildCacheStorage; - std::string CacheName; - double CacheLatencySec = -1.0; + std::unique_ptr CacheStorage; }; } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h index 20b6fd371..c085f10e7 100644 --- a/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h +++ b/src/zenremotestore/include/zenremotestore/chunking/chunkblock.h @@ -31,6 +31,7 @@ struct ChunkBlockDescription : public ThinChunkBlockDescription std::vector ParseChunkBlockDescriptionList(const CbObjectView& BlocksObject); ChunkBlockDescription ParseChunkBlockDescription(const CbObjectView& BlockObject); +std::vector ParseBlockMetadatas(std::span BlockMetadatas); CbObject BuildChunkBlockDescription(const ChunkBlockDescription& Block, CbObjectView MetaData); ChunkBlockDescription GetChunkBlockDescription(const SharedBuffer& BlockPayload, const IoHash& RawHash); typedef std::function(const IoHash& RawHash)> FetchChunkFunc; diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h index 7bbf40dfa..bb41f9efc 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h +++ b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h @@ -28,7 +28,8 @@ struct JupiterEndpointTestResult { bool Success = false; std::string FailureReason; - double LatencySeconds = -1.0; + double LatencySeconds = -1.0; + uint64_t MaxRangeCountPerRequest = 1; }; 
JupiterEndpointTestResult TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpVerbose); diff --git a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h index 66dfcc62d..c058e1c1f 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h +++ b/src/zenremotestore/include/zenremotestore/projectstore/buildsremoteprojectstore.h @@ -2,6 +2,7 @@ #pragma once +#include #include namespace zen { @@ -10,9 +11,6 @@ class AuthMgr; struct BuildsRemoteStoreOptions : RemoteStoreOptions { - std::string Host; - std::string OverrideHost; - std::string ZenHost; std::string Namespace; std::string Bucket; Oid BuildId; @@ -22,20 +20,16 @@ struct BuildsRemoteStoreOptions : RemoteStoreOptions std::filesystem::path OidcExePath; bool ForceDisableBlocks = false; bool ForceDisableTempBlocks = false; - bool AssumeHttp2 = false; - bool PopulateCache = true; IoBuffer MetaData; size_t MaximumInMemoryDownloadSize = 1024u * 1024u; }; -std::shared_ptr CreateJupiterBuildsRemoteStore(LoggerRef InLog, - const BuildsRemoteStoreOptions& Options, - const std::filesystem::path& TempFilePath, - bool Quiet, - bool Unattended, - bool Hidden, - WorkerThreadPool& CacheBackgroundWorkerPool, - double& OutHostLatencySec, - double& OutCacheLatencySec); +struct BuildStorageResolveResult; + +std::shared_ptr CreateJupiterBuildsRemoteStore(LoggerRef InLog, + const BuildStorageResolveResult& ResolveResult, + std::function&& TokenProvider, + const BuildsRemoteStoreOptions& Options, + const std::filesystem::path& TempFilePath); } // namespace zen diff --git a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h index 42786d0f2..084d975a2 100644 --- a/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h +++ 
b/src/zenremotestore/include/zenremotestore/projectstore/remoteprojectstore.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -79,11 +80,6 @@ public: std::vector Blocks; }; - struct AttachmentExistsInCacheResult : public Result - { - std::vector HasBody; - }; - struct LoadAttachmentRangesResult : public Result { IoBuffer Bytes; @@ -128,28 +124,17 @@ public: virtual FinalizeResult FinalizeContainer(const IoHash& RawHash) = 0; virtual SaveAttachmentsResult SaveAttachments(const std::vector& Payloads) = 0; - virtual LoadContainerResult LoadContainer() = 0; - virtual GetKnownBlocksResult GetKnownBlocks() = 0; - virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) = 0; - virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) = 0; - - enum ESourceMode - { - kAny, - kCacheOnly, - kHostOnly - }; + virtual LoadContainerResult LoadContainer() = 0; + virtual GetKnownBlocksResult GetKnownBlocks() = 0; + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) = 0; - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode = ESourceMode::kAny) = 0; - - static constexpr size_t MaxRangeCountPerRequest = 128u; + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, - std::span> Ranges, - ESourceMode SourceMode = ESourceMode::kAny) = 0; - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode = ESourceMode::kAny) = 0; - - virtual void Flush() = 0; + std::span> Ranges) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) = 0; }; struct RemoteStoreOptions @@ -211,18 +196,29 @@ RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, bool IgnoreMissingAttachments, JobContext* OptionalContext); -RemoteProjectStore::Result 
LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - EPartialBlockRequestMode PartialBlockRequestMode, - double HostLatencySec, - double CacheLatencySec, - JobContext* OptionalContext); +struct LoadOplogContext +{ + CidStore& ChunkStore; + RemoteProjectStore& RemoteStore; + BuildStorageCache* OptionalCache = nullptr; + Oid CacheBuildId = Oid::Zero; + BuildStorageCache::Statistics* OptionalCacheStats = nullptr; + ProjectStore::Oplog& Oplog; + WorkerThreadPool& NetworkWorkerPool; + WorkerThreadPool& WorkerPool; + bool ForceDownload = false; + bool IgnoreMissingAttachments = false; + bool CleanOplog = false; + EPartialBlockRequestMode PartialBlockRequestMode = EPartialBlockRequestMode::All; + bool PopulateCache = false; + double StoreLatencySec = -1.0; + uint64_t StoreMaxRangeCountPerRequest = 1; + double CacheLatencySec = -1.0; + uint64_t CacheMaxRangeCountPerRequest = 1; + JobContext* OptionalJobContext = nullptr; +}; + +RemoteProjectStore::Result LoadOplog(LoadOplogContext&& Context); std::vector GetBlockHashesFromOplog(CbObjectView ContainerObject); std::vector GetBlocksFromOplog(CbObjectView ContainerObject, std::span IncludeBlockHashes); diff --git a/src/zenremotestore/jupiter/jupiterhost.cpp b/src/zenremotestore/jupiter/jupiterhost.cpp index 2583cfc84..4479c8b97 100644 --- a/src/zenremotestore/jupiter/jupiterhost.cpp +++ b/src/zenremotestore/jupiter/jupiterhost.cpp @@ -59,13 +59,22 @@ TestJupiterEndpoint(std::string_view BaseUrl, const bool AssumeHttp2, const bool HttpClient::Response TestResponse = TestHttpClient.Get("/health/live"); if (TestResponse.IsSuccess()) { + // TODO: dan.engelbrecht 20260305 - replace this naive nginx detection with proper capabilites end point once it exists in Jupiter + uint64_t MaxRangeCountPerRequest = 1; + if (auto It = 
TestResponse.Header.Entries.find("Server"); It != TestResponse.Header.Entries.end()) + { + if (StrCaseCompare(It->second.c_str(), "nginx", 5) == 0) + { + MaxRangeCountPerRequest = 128u; + } + } LatencyTestResult LatencyResult = MeasureLatency(TestHttpClient, "/health/ready"); if (!LatencyResult.Success) { return {.Success = false, .FailureReason = LatencyResult.FailureReason}; } - return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds}; + return {.Success = true, .LatencySeconds = LatencyResult.LatencySeconds, .MaxRangeCountPerRequest = MaxRangeCountPerRequest}; } return {.Success = false, .FailureReason = TestResponse.ErrorMessage("")}; } diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index 3400cdbf5..2282a31dd 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -7,8 +7,6 @@ #include #include -#include -#include #include #include #include @@ -26,18 +24,14 @@ class BuildsRemoteStore : public RemoteProjectStore public: BuildsRemoteStore(LoggerRef InLog, const HttpClientSettings& ClientSettings, - HttpClientSettings* OptionalCacheClientSettings, std::string_view HostUrl, - std::string_view CacheUrl, const std::filesystem::path& TempFilePath, - WorkerThreadPool& CacheBackgroundWorkerPool, std::string_view Namespace, std::string_view Bucket, const Oid& BuildId, const IoBuffer& MetaData, bool ForceDisableBlocks, - bool ForceDisableTempBlocks, - bool PopulateCache) + bool ForceDisableTempBlocks) : m_Log(InLog) , m_BuildStorageHttp(HostUrl, ClientSettings) , m_BuildStorage(CreateJupiterBuildStorage(Log(), @@ -53,20 +47,8 @@ public: , m_MetaData(MetaData) , m_EnableBlocks(!ForceDisableBlocks) , m_UseTempBlocks(!ForceDisableTempBlocks) - , m_PopulateCache(PopulateCache) { m_MetaData.MakeOwned(); - if (OptionalCacheClientSettings) - { - ZEN_ASSERT(!CacheUrl.empty()); - 
m_BuildCacheStorageHttp = std::make_unique(CacheUrl, *OptionalCacheClientSettings); - m_BuildCacheStorage = CreateZenBuildStorageCache(*m_BuildCacheStorageHttp, - m_StorageCacheStats, - Namespace, - Bucket, - TempFilePath, - CacheBackgroundWorkerPool); - } } virtual RemoteStoreInfo GetInfo() const override @@ -75,9 +57,8 @@ public: .UseTempBlockFiles = m_UseTempBlocks, .AllowChunking = true, .ContainerName = fmt::format("{}/{}/{}", m_Namespace, m_Bucket, m_BuildId), - .Description = fmt::format("[cloud] {}{}. SessionId: {}. {}/{}/{}"sv, + .Description = fmt::format("[cloud] {}. SessionId: {}. {}/{}/{}"sv, m_BuildStorageHttp.GetBaseUri(), - m_BuildCacheStorage ? fmt::format(" (Cache: {})", m_BuildCacheStorageHttp->GetBaseUri()) : ""sv, m_BuildStorageHttp.GetSessionId(), m_Namespace, m_Bucket, @@ -86,15 +67,13 @@ public: virtual Stats GetStats() const override { - return { - .m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load() + m_StorageCacheStats.TotalBytesWritten.load(), - .m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load() + m_StorageCacheStats.TotalBytesRead.load(), - .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000 + m_StorageCacheStats.TotalRequestTimeUs.load() * 1000, - .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load() + m_StorageCacheStats.TotalRequestCount.load(), - .m_PeakSentBytes = Max(m_BuildStorageStats.PeakSentBytes.load(), m_StorageCacheStats.PeakSentBytes.load()), - .m_PeakReceivedBytes = Max(m_BuildStorageStats.PeakReceivedBytes.load(), m_StorageCacheStats.PeakReceivedBytes.load()), - .m_PeakBytesPerSec = Max(m_BuildStorageStats.PeakBytesPerSec.load(), m_StorageCacheStats.PeakBytesPerSec.load()), - }; + return {.m_SentBytes = m_BuildStorageStats.TotalBytesWritten.load(), + .m_ReceivedBytes = m_BuildStorageStats.TotalBytesRead.load(), + .m_RequestTimeNS = m_BuildStorageStats.TotalRequestTimeUs.load() * 1000, + .m_RequestCount = m_BuildStorageStats.TotalRequestCount.load(), + .m_PeakSentBytes = 
m_BuildStorageStats.PeakSentBytes.load(), + .m_PeakReceivedBytes = m_BuildStorageStats.PeakReceivedBytes.load(), + .m_PeakBytesPerSec = m_BuildStorageStats.PeakBytesPerSec.load()}; } virtual bool GetExtendedStats(ExtendedStats& OutStats) const override @@ -109,11 +88,6 @@ public: } Result = true; } - if (m_BuildCacheStorage) - { - OutStats.m_ReceivedBytesPerSource.insert_or_assign("Cache", m_StorageCacheStats.TotalBytesRead); - Result = true; - } return Result; } @@ -462,11 +436,14 @@ public: return Result; } - virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override { std::unique_ptr Output(CreateStandardLogOutput(Log())); ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); + ZEN_ASSERT(OptionalCache == nullptr || CacheBuildId == m_BuildId); GetBlockDescriptionsResult Result; Stopwatch Timer; @@ -476,7 +453,7 @@ public: { Result.Blocks = zen::GetBlockDescriptions(*Output, *m_BuildStorage, - m_BuildCacheStorage.get(), + OptionalCache, m_BuildId, BlockHashes, /*AttemptFallback*/ false, @@ -506,49 +483,7 @@ public: return Result; } - virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override - { - AttachmentExistsInCacheResult Result; - Stopwatch Timer; - auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); - try - { - const std::vector CacheExistsResult = - m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes); - - if (CacheExistsResult.size() == RawHashes.size()) - { - Result.HasBody.reserve(CacheExistsResult.size()); - for (size_t BlobIndex = 0; BlobIndex < CacheExistsResult.size(); BlobIndex++) - { - Result.HasBody.push_back(CacheExistsResult[BlobIndex].HasBody); - } - } - } - catch (const HttpClientError& Ex) - { - Result.ErrorCode = MakeErrorCode(Ex); - Result.Reason = fmt::format("Remote cache: 
Failed finding known blobs for {}/{}/{}/{}. Reason: '{}'", - m_BuildStorageHttp.GetBaseUri(), - m_Namespace, - m_Bucket, - m_BuildId, - Ex.what()); - } - catch (const std::exception& Ex) - { - Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("Remote cache: Failed finding known blobs for {}/{}/{}/{}. Reason: '{}'", - m_BuildStorageHttp.GetBaseUri(), - m_Namespace, - m_Bucket, - m_BuildId, - Ex.what()); - } - return Result; - } - - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { ZEN_ASSERT(m_OplogBuildPartId != Oid::Zero); @@ -558,25 +493,7 @@ public: try { - if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) - { - IoBuffer CachedBlob = m_BuildCacheStorage->GetBuildBlob(m_BuildId, RawHash); - if (CachedBlob) - { - Result.Bytes = std::move(CachedBlob); - } - } - if (!Result.Bytes && SourceMode != ESourceMode::kCacheOnly) - { - Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); - if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) - { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - RawHash, - Result.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(Result.Bytes))); - } - } + Result.Bytes = m_BuildStorage->GetBuildBlob(m_BuildId, RawHash); } catch (const HttpClientError& Ex) { @@ -605,45 +522,20 @@ public: } virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, - std::span> Ranges, - ESourceMode SourceMode) override + std::span> Ranges) override { + ZEN_ASSERT(!Ranges.empty()); LoadAttachmentRangesResult Result; Stopwatch Timer; auto _ = MakeGuard([&Timer, &Result]() { Result.ElapsedSeconds = Timer.GetElapsedTimeUs() / 1000000.0; }); try { - if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) + BuildStorageBase::BuildBlobRanges BlobRanges = m_BuildStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); + if 
(BlobRanges.PayloadBuffer) { - BuildStorageCache::BuildBlobRanges BlobRanges = m_BuildCacheStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); - if (BlobRanges.PayloadBuffer) - { - Result.Bytes = std::move(BlobRanges.PayloadBuffer); - Result.Ranges = std::move(BlobRanges.Ranges); - } - } - if (!Result.Bytes && SourceMode != ESourceMode::kCacheOnly) - { - BuildStorageBase::BuildBlobRanges BlobRanges = m_BuildStorage->GetBuildBlobRanges(m_BuildId, RawHash, Ranges); - if (BlobRanges.PayloadBuffer) - { - Result.Bytes = std::move(BlobRanges.PayloadBuffer); - Result.Ranges = std::move(BlobRanges.Ranges); - - if (Result.Ranges.empty()) - { - // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3/Replicated - // Upload to cache (if enabled) - if (m_BuildCacheStorage && Result.Bytes && m_PopulateCache) - { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - RawHash, - Result.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(Result.Bytes))); - } - } - } + Result.Bytes = std::move(BlobRanges.PayloadBuffer); + Result.Ranges = std::move(BlobRanges.Ranges); } } catch (const HttpClientError& Ex) @@ -674,7 +566,7 @@ public: return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override { LoadAttachmentsResult Result; Stopwatch Timer; @@ -682,67 +574,20 @@ public: std::vector AttachmentsLeftToFind = RawHashes; - if (m_BuildCacheStorage && SourceMode != ESourceMode::kHostOnly) - { - std::vector ExistCheck = m_BuildCacheStorage->BlobsExists(m_BuildId, RawHashes); - if (ExistCheck.size() == RawHashes.size()) - { - AttachmentsLeftToFind.clear(); - for (size_t BlobIndex = 0; BlobIndex < RawHashes.size(); BlobIndex++) - { - const IoHash& Hash = RawHashes[BlobIndex]; - const BuildStorageCache::BlobExistsResult& BlobExists = ExistCheck[BlobIndex]; - if (BlobExists.HasBody) - { - IoBuffer 
CachedPayload = m_BuildCacheStorage->GetBuildBlob(m_BuildId, Hash); - if (CachedPayload) - { - Result.Chunks.emplace_back( - std::pair{Hash, - CompressedBuffer::FromCompressedNoValidate(std::move(CachedPayload))}); - } - else - { - AttachmentsLeftToFind.push_back(Hash); - } - } - else - { - AttachmentsLeftToFind.push_back(Hash); - } - } - } - } - for (const IoHash& Hash : AttachmentsLeftToFind) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; } ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast(ChunkResult.ElapsedSeconds * 1000))); - if (m_BuildCacheStorage && ChunkResult.Bytes && m_PopulateCache) - { - m_BuildCacheStorage->PutBuildBlob(m_BuildId, - Hash, - ChunkResult.Bytes.GetContentType(), - CompositeBuffer(SharedBuffer(ChunkResult.Bytes))); - } Result.Chunks.emplace_back( std::pair{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); } return Result; } - virtual void Flush() override - { - if (m_BuildCacheStorage) - { - m_BuildCacheStorage->Flush(100, [](intptr_t) { return false; }); - } - } - private: static int MakeErrorCode(const HttpClientError& Ex) { @@ -759,10 +604,6 @@ private: HttpClient m_BuildStorageHttp; std::unique_ptr m_BuildStorage; - BuildStorageCache::Statistics m_StorageCacheStats; - std::unique_ptr m_BuildCacheStorageHttp; - std::unique_ptr m_BuildCacheStorage; - const std::string m_Namespace; const std::string m_Bucket; const Oid m_BuildId; @@ -771,125 +612,34 @@ private: const bool m_EnableBlocks = true; const bool m_UseTempBlocks = true; const bool m_AllowRedirect = false; - const bool m_PopulateCache = true; }; std::shared_ptr -CreateJupiterBuildsRemoteStore(LoggerRef InLog, - const BuildsRemoteStoreOptions& Options, - const std::filesystem::path& TempFilePath, - bool Quiet, - bool Unattended, - bool Hidden, - WorkerThreadPool& 
CacheBackgroundWorkerPool, - double& OutHostLatencySec, - double& OutCacheLatencySec) +CreateJupiterBuildsRemoteStore(LoggerRef InLog, + const BuildStorageResolveResult& ResolveResult, + std::function&& TokenProvider, + const BuildsRemoteStoreOptions& Options, + const std::filesystem::path& TempFilePath) { - std::string Host = Options.Host; - if (!Host.empty() && Host.find("://"sv) == std::string::npos) - { - // Assume https URL - Host = fmt::format("https://{}"sv, Host); - } - std::string OverrideUrl = Options.OverrideHost; - if (!OverrideUrl.empty() && OverrideUrl.find("://"sv) == std::string::npos) - { - // Assume https URL - OverrideUrl = fmt::format("https://{}"sv, OverrideUrl); - } - std::string ZenHost = Options.ZenHost; - if (!ZenHost.empty() && ZenHost.find("://"sv) == std::string::npos) - { - // Assume https URL - ZenHost = fmt::format("https://{}"sv, ZenHost); - } - - // 1) openid-provider if given (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider - // 2) Access token as parameter in request - // 3) Environment variable (different win vs linux/mac) - // 4) Default openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider - - std::function TokenProvider; - if (!Options.OpenIdProvider.empty()) - { - TokenProvider = httpclientauth::CreateFromOpenIdProvider(Options.AuthManager, Options.OpenIdProvider); - } - else if (!Options.AccessToken.empty()) - { - TokenProvider = httpclientauth::CreateFromStaticToken(Options.AccessToken); - } - else if (!Options.OidcExePath.empty()) - { - if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(Options.OidcExePath, - Host.empty() ? 
OverrideUrl : Host, - Quiet, - Unattended, - Hidden); - TokenProviderMaybe) - { - TokenProvider = TokenProviderMaybe.value(); - } - } - - if (!TokenProvider) - { - TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(Options.AuthManager); - } - - BuildStorageResolveResult ResolveRes; - { - HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", - .AccessTokenProvider = TokenProvider, - .AssumeHttp2 = Options.AssumeHttp2, - .AllowResume = true, - .RetryCount = 2}; - - std::unique_ptr Output(CreateStandardLogOutput(InLog)); - - ResolveRes = - ResolveBuildStorage(*Output, ClientSettings, Host, OverrideUrl, ZenHost, ZenCacheResolveMode::Discovery, /*Verbose*/ false); - } - HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", .ConnectTimeout = std::chrono::milliseconds(3000), .Timeout = std::chrono::milliseconds(1800000), .AccessTokenProvider = std::move(TokenProvider), - .AssumeHttp2 = ResolveRes.HostAssumeHttp2, + .AssumeHttp2 = ResolveResult.Cloud.AssumeHttp2, .AllowResume = true, .RetryCount = 4, .MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize}; - std::unique_ptr CacheClientSettings; - - if (!ResolveRes.CacheUrl.empty()) - { - CacheClientSettings = - std::make_unique(HttpClientSettings{.LogCategory = "httpcacheclient", - .ConnectTimeout = std::chrono::milliseconds{3000}, - .Timeout = std::chrono::milliseconds{30000}, - .AssumeHttp2 = ResolveRes.CacheAssumeHttp2, - .AllowResume = true, - .RetryCount = 0, - .MaximumInMemoryDownloadSize = Options.MaximumInMemoryDownloadSize}); - } - std::shared_ptr RemoteStore = std::make_shared(InLog, ClientSettings, - CacheClientSettings.get(), - ResolveRes.HostUrl, - ResolveRes.CacheUrl, + ResolveResult.Cloud.Address, TempFilePath, - CacheBackgroundWorkerPool, Options.Namespace, Options.Bucket, Options.BuildId, Options.MetaData, Options.ForceDisableBlocks, - Options.ForceDisableTempBlocks, - Options.PopulateCache); - - OutHostLatencySec = ResolveRes.HostLatencySec; - 
OutCacheLatencySec = ResolveRes.CacheLatencySec; + Options.ForceDisableTempBlocks); return RemoteStore; } diff --git a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp index f950fd46c..bb21de12c 100644 --- a/src/zenremotestore/projectstore/fileremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/fileremoteprojectstore.cpp @@ -7,8 +7,12 @@ #include #include #include +#include #include #include +#include + +#include namespace zen { @@ -74,9 +78,11 @@ public: virtual SaveResult SaveContainer(const IoBuffer& Payload) override { - Stopwatch Timer; SaveResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + { CbObject ContainerObject = LoadCompactBinaryObject(Payload); @@ -87,6 +93,10 @@ public: { Result.Needs.insert(AttachmentHash); } + else if (std::filesystem::path AttachmentMetaPath = GetAttachmentMetaPath(AttachmentHash); IsFile(AttachmentMetaPath)) + { + BasicFile TouchIt(AttachmentMetaPath, BasicFile::Mode::kWrite); + } }); } @@ -112,14 +122,18 @@ public: Result.Reason = fmt::format("Failed saving oplog container to '{}'. 
Reason: {}", ContainerPath, Ex.what()); } AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } - virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash, ChunkBlockDescription&&) override + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, + const IoHash& RawHash, + ChunkBlockDescription&& BlockDescription) override { - Stopwatch Timer; - SaveAttachmentResult Result; + SaveAttachmentResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); if (!IsFile(ChunkPath)) { @@ -142,14 +156,33 @@ public: Result.Reason = fmt::format("Failed saving oplog attachment to '{}'. Reason: {}", ChunkPath, Ex.what()); } } + if (!Result.ErrorCode && BlockDescription.BlockHash != IoHash::Zero) + { + try + { + std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash); + CbObject MetaData = BuildChunkBlockDescription(BlockDescription, {}); + SharedBuffer MetaBuffer = MetaData.GetBuffer(); + BasicFile MetaFile; + MetaFile.Open(MetaPath, BasicFile::Mode::kTruncate); + MetaFile.Write(MetaBuffer.GetView(), 0); + } + catch (const std::exception& Ex) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("Failed saving block description to '{}'. 
Reason: {}", RawHash, Ex.what()); + } + } AddStats(Payload.GetSize(), 0, Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } virtual SaveAttachmentsResult SaveAttachments(const std::vector& Chunks) override { + SaveAttachmentsResult Result; + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); for (const SharedBuffer& Chunk : Chunks) { @@ -157,12 +190,10 @@ public: SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash(), {}); if (ChunkResult.ErrorCode) { - ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - return SaveAttachmentsResult{ChunkResult}; + Result = SaveAttachmentsResult{ChunkResult}; + break; } } - SaveAttachmentsResult Result; - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } @@ -172,21 +203,60 @@ public: virtual GetKnownBlocksResult GetKnownBlocks() override { + Stopwatch Timer; if (m_OptionalBaseName.empty()) { - return GetKnownBlocksResult{{.ErrorCode = static_cast(HttpResponseCode::NoContent)}}; + size_t MaxBlockCount = 10000; + + GetKnownBlocksResult Result; + + DirectoryContent Content; + GetDirectoryContent( + m_OutputPath, + DirectoryContentFlags::IncludeFiles | DirectoryContentFlags::Recursive | DirectoryContentFlags::IncludeModificationTick, + Content); + std::vector RecentOrder(Content.Files.size()); + std::iota(RecentOrder.begin(), RecentOrder.end(), 0u); + std::sort(RecentOrder.begin(), RecentOrder.end(), [&Content](size_t Lhs, size_t Rhs) { + return Content.FileModificationTicks[Lhs] > Content.FileModificationTicks[Rhs]; + }); + + for (size_t FileIndex : RecentOrder) + { + std::filesystem::path MetaPath = Content.Files[FileIndex]; + if (MetaPath.extension() == MetaExtension) + { + IoBuffer MetaFile = ReadFile(MetaPath).Flatten(); + CbValidateError Err; + CbObject ValidatedObject = 
ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err); + if (Err == CbValidateError::None) + { + ChunkBlockDescription Description = ParseChunkBlockDescription(ValidatedObject); + if (Description.BlockHash != IoHash::Zero) + { + Result.Blocks.emplace_back(std::move(Description)); + if (Result.Blocks.size() == MaxBlockCount) + { + break; + } + } + } + } + } + + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + return Result; } LoadContainerResult LoadResult = LoadContainer(m_OptionalBaseName); if (LoadResult.ErrorCode) { return GetKnownBlocksResult{LoadResult}; } - Stopwatch Timer; std::vector BlockHashes = GetBlockHashesFromOplog(LoadResult.ContainerObject); if (BlockHashes.empty()) { return GetKnownBlocksResult{{.ErrorCode = static_cast(HttpResponseCode::NoContent), - .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; + .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; } std::vector ExistingBlockHashes; for (const IoHash& RawHash : BlockHashes) @@ -200,15 +270,15 @@ public: if (ExistingBlockHashes.empty()) { return GetKnownBlocksResult{{.ErrorCode = static_cast(HttpResponseCode::NoContent), - .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; + .ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; } std::vector ThinKnownBlocks = GetBlocksFromOplog(LoadResult.ContainerObject, ExistingBlockHashes); - const size_t KnowBlockCount = ThinKnownBlocks.size(); + const size_t KnownBlockCount = ThinKnownBlocks.size(); - GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeUs() * 1000}}; - Result.Blocks.resize(KnowBlockCount); - for (size_t BlockIndex = 0; BlockIndex < KnowBlockCount; BlockIndex++) + GetKnownBlocksResult Result{{.ElapsedSeconds = LoadResult.ElapsedSeconds + Timer.GetElapsedTimeMs() / 1000.0}}; + Result.Blocks.resize(KnownBlockCount); + for (size_t BlockIndex = 0; BlockIndex < 
KnownBlockCount; BlockIndex++) { Result.Blocks[BlockIndex].BlockHash = ThinKnownBlocks[BlockIndex].BlockHash; Result.Blocks[BlockIndex].ChunkRawHashes = std::move(ThinKnownBlocks[BlockIndex].ChunkRawHashes); @@ -217,87 +287,141 @@ public: return Result; } - virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override { - ZEN_UNUSED(BlockHashes); - return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; - } + GetBlockDescriptionsResult Result; - virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override - { - return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; - } + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override - { - Stopwatch Timer; - LoadAttachmentResult Result; - if (SourceMode != ESourceMode::kCacheOnly) + Result.Blocks.reserve(BlockHashes.size()); + + uint64_t ByteCount = 0; + + std::vector UnorderedList; { - std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); - if (!IsFile(ChunkPath)) + if (OptionalCache) { - Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); - Result.Reason = - fmt::format("Failed loading oplog attachment from '{}'. 
Reason: 'The file does not exist'", ChunkPath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - return Result; + std::vector CacheBlockMetadatas = OptionalCache->GetBlobMetadatas(CacheBuildId, BlockHashes); + for (const CbObject& BlockObject : CacheBlockMetadatas) + { + ByteCount += BlockObject.GetSize(); + } + UnorderedList = ParseBlockMetadatas(CacheBlockMetadatas); } + + tsl::robin_map BlockDescriptionLookup; + BlockDescriptionLookup.reserve(BlockHashes.size()); + for (size_t DescriptionIndex = 0; DescriptionIndex < UnorderedList.size(); DescriptionIndex++) { - BasicFile ChunkFile; - ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); - Result.Bytes = ChunkFile.ReadAll(); + const ChunkBlockDescription& Description = UnorderedList[DescriptionIndex]; + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, DescriptionIndex); + } + + if (UnorderedList.size() < BlockHashes.size()) + { + for (const IoHash& RawHash : BlockHashes) + { + if (!BlockDescriptionLookup.contains(RawHash)) + { + std::filesystem::path MetaPath = GetAttachmentMetaPath(RawHash); + IoBuffer MetaFile = ReadFile(MetaPath).Flatten(); + ByteCount += MetaFile.GetSize(); + CbValidateError Err; + CbObject ValidatedObject = ValidateAndReadCompactBinaryObject(std::move(MetaFile), Err); + if (Err == CbValidateError::None) + { + ChunkBlockDescription Description = ParseChunkBlockDescription(ValidatedObject); + if (Description.BlockHash != IoHash::Zero) + { + BlockDescriptionLookup.insert_or_assign(Description.BlockHash, UnorderedList.size()); + UnorderedList.emplace_back(std::move(Description)); + } + } + } + } + } + + Result.Blocks.reserve(UnorderedList.size()); + for (const IoHash& RawHash : BlockHashes) + { + if (auto It = BlockDescriptionLookup.find(RawHash); It != BlockDescriptionLookup.end()) + { + Result.Blocks.emplace_back(std::move(UnorderedList[It->second])); + } } } + AddStats(0, ByteCount, Timer.GetElapsedTimeUs() * 1000); + return Result; + } + + virtual 
LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + LoadAttachmentResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!IsFile(ChunkPath)) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); + Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string()); + return Result; + } + { + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); + Result.Bytes = ChunkFile.ReadAll(); + } AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, - std::span> Ranges, - ESourceMode SourceMode) override + std::span> Ranges) override { - Stopwatch Timer; + ZEN_ASSERT(!Ranges.empty()); LoadAttachmentRangesResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); - if (!IsFile(ChunkPath)) - { - Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); - Result.Reason = - fmt::format("Failed loading oplog attachment from '{}'. 
Reason: 'The file does not exist'", ChunkPath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - return Result; - } - { - BasicFile ChunkFile; - ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); - uint64_t Start = Ranges.front().first; - uint64_t Length = Ranges.back().first + Ranges.back().second - Ranges.front().first; + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); - Result.Bytes = ChunkFile.ReadRange(Start, Length); - Result.Ranges.reserve(Ranges.size()); - for (const std::pair& Range : Ranges) - { - Result.Ranges.push_back(std::make_pair(Range.first - Start, Range.second)); - } + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!IsFile(ChunkPath)) + { + Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); + Result.Reason = fmt::format("Failed loading oplog attachment from '{}'. Reason: 'The file does not exist'", ChunkPath.string()); + return Result; + } + { + uint64_t Start = Ranges.front().first; + uint64_t Length = Ranges.back().first + Ranges.back().second - Ranges.front().first; + Result.Bytes = IoBufferBuilder::MakeFromFile(ChunkPath, Start, Length); + Result.Ranges.reserve(Ranges.size()); + for (const std::pair& Range : Ranges) + { + Result.Ranges.push_back(std::make_pair(Range.first - Start, Range.second)); } } - AddStats(0, Result.Bytes.GetSize(), Timer.GetElapsedTimeUs() * 1000); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + AddStats(0, + std::accumulate(Result.Ranges.begin(), + Result.Ranges.end(), + uint64_t(0), + [](uint64_t Current, const std::pair& Value) { return Current + Value.second; }), + Timer.GetElapsedTimeUs() * 1000); return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override { Stopwatch Timer; LoadAttachmentsResult Result; for (const IoHash& 
Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); if (ChunkResult.ErrorCode) { ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; @@ -310,20 +434,20 @@ public: return Result; } - virtual void Flush() override {} - private: LoadContainerResult LoadContainer(const std::string& Name) { - Stopwatch Timer; - LoadContainerResult Result; + LoadContainerResult Result; + + Stopwatch Timer; + auto _ = MakeGuard([&Result, &Timer]() { Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; }); + std::filesystem::path SourcePath = m_OutputPath; SourcePath.append(Name); if (!IsFile(SourcePath)) { Result.ErrorCode = gsl::narrow(HttpResponseCode::NotFound); Result.Reason = fmt::format("Failed loading oplog container from '{}'. Reason: 'The file does not exist'", SourcePath.string()); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } IoBuffer ContainerPayload; @@ -337,18 +461,16 @@ private: if (Result.ContainerObject = ValidateAndReadCompactBinaryObject(std::move(ContainerPayload), ValidateResult); ValidateResult != CbValidateError::None || !Result.ContainerObject) { - Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); - Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')", - SourcePath.string(), - ToString(ValidateResult)); - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; + Result.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("The file {} is not formatted as a compact binary object ('{}')", + SourcePath.string(), + ToString(ValidateResult)); return Result; } - Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; return Result; } - std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const + std::filesystem::path GetAttachmentBasePath(const IoHash& RawHash) const { ExtendablePathBuilder<128> ShardedPath; 
ShardedPath.Append(m_OutputPath.c_str()); @@ -367,6 +489,19 @@ private: return ShardedPath.ToPath(); } + static constexpr std::string_view BlobExtension = ".blob"; + static constexpr std::string_view MetaExtension = ".meta"; + + std::filesystem::path GetAttachmentPath(const IoHash& RawHash) + { + return GetAttachmentBasePath(RawHash).replace_extension(BlobExtension); + } + + std::filesystem::path GetAttachmentMetaPath(const IoHash& RawHash) + { + return GetAttachmentBasePath(RawHash).replace_extension(MetaExtension); + } + void AddStats(uint64_t UploadedBytes, uint64_t DownloadedBytes, uint64_t ElapsedNS) { m_SentBytes.fetch_add(UploadedBytes); diff --git a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp index 514484f30..5b456cb4c 100644 --- a/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/jupiterremoteprojectstore.cpp @@ -212,73 +212,64 @@ public: return Result; } - virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override { - ZEN_UNUSED(BlockHashes); + ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId); return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; } - virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override - { - return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; - } - - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { LoadAttachmentResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); - 
JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); - AddStats(GetResult); + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + AddStats(GetResult); - Result = {ConvertResult(GetResult), std::move(GetResult.Response)}; - if (GetResult.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", - m_JupiterClient->ServiceUrl(), - m_Namespace, - RawHash, - Result.Reason); - } + Result = {ConvertResult(GetResult), std::move(GetResult.Response)}; + if (GetResult.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", + m_JupiterClient->ServiceUrl(), + m_Namespace, + RawHash, + Result.Reason); } return Result; } virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, - std::span> Ranges, - ESourceMode SourceMode) override + std::span> Ranges) override { + ZEN_ASSERT(!Ranges.empty()); LoadAttachmentRangesResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); - JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); - AddStats(GetResult); + JupiterSession Session(m_JupiterClient->Logger(), m_JupiterClient->Client(), m_AllowRedirect); + JupiterResult GetResult = Session.GetCompressedBlob(m_Namespace, RawHash, m_TempFilePath); + AddStats(GetResult); - Result = LoadAttachmentRangesResult{ConvertResult(GetResult), std::move(GetResult.Response)}; - if (GetResult.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. 
Reason: '{}'", - m_JupiterClient->ServiceUrl(), - m_Namespace, - RawHash, - Result.Reason); - } - else - { - Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); - } + Result = LoadAttachmentRangesResult{ConvertResult(GetResult), std::move(GetResult.Response)}; + if (GetResult.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}. Reason: '{}'", + m_JupiterClient->ServiceUrl(), + m_Namespace, + RawHash, + Result.Reason); + } + else + { + Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); } return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override { LoadAttachmentsResult Result; for (const IoHash& Hash : RawHashes) { - LoadAttachmentResult ChunkResult = LoadAttachment(Hash, SourceMode); + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); if (ChunkResult.ErrorCode) { return LoadAttachmentsResult{ChunkResult}; @@ -290,8 +281,6 @@ public: return Result; } - virtual void Flush() override {} - private: LoadContainerResult LoadContainer(const IoHash& Key) { diff --git a/src/zenremotestore/projectstore/projectstoreoperations.cpp b/src/zenremotestore/projectstore/projectstoreoperations.cpp index becac3d4c..36dc4d868 100644 --- a/src/zenremotestore/projectstore/projectstoreoperations.cpp +++ b/src/zenremotestore/projectstore/projectstoreoperations.cpp @@ -426,19 +426,19 @@ ProjectStoreOperationDownloadAttachments::Execute() auto GetBuildBlob = [this](const IoHash& RawHash, const std::filesystem::path& OutputPath) { IoBuffer Payload; - if (m_Storage.BuildCacheStorage) + if (m_Storage.CacheStorage) { - Payload = m_Storage.BuildCacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); + Payload = m_Storage.CacheStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); } if (!Payload) { Payload = m_Storage.BuildStorage->GetBuildBlob(m_State.GetBuildId(), RawHash); - 
if (m_Storage.BuildCacheStorage && m_Options.PopulateCache) + if (m_Storage.CacheStorage && m_Options.PopulateCache) { - m_Storage.BuildCacheStorage->PutBuildBlob(m_State.GetBuildId(), - RawHash, - Payload.GetContentType(), - CompositeBuffer(SharedBuffer(Payload))); + m_Storage.CacheStorage->PutBuildBlob(m_State.GetBuildId(), + RawHash, + Payload.GetContentType(), + CompositeBuffer(SharedBuffer(Payload))); } } uint64_t PayloadSize = Payload.GetSize(); diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index c8c5f201d..d5c6286a8 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -124,14 +125,17 @@ namespace remotestore_impl { return OptionalContext->IsCancelled(); } - std::string GetStats(const RemoteProjectStore::Stats& Stats, uint64_t ElapsedWallTimeMS) + std::string GetStats(const RemoteProjectStore::Stats& Stats, + const BuildStorageCache::Statistics* OptionalCacheStats, + uint64_t ElapsedWallTimeMS) { - return fmt::format( - "Sent: {} ({}bits/s) Recv: {} ({}bits/s)", - NiceBytes(Stats.m_SentBytes), - NiceNum(ElapsedWallTimeMS > 0u ? static_cast((Stats.m_SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u), - NiceBytes(Stats.m_ReceivedBytes), - NiceNum(ElapsedWallTimeMS > 0u ? static_cast((Stats.m_ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u)); + uint64_t SentBytes = Stats.m_SentBytes + (OptionalCacheStats ? OptionalCacheStats->TotalBytesWritten.load() : 0); + uint64_t ReceivedBytes = Stats.m_ReceivedBytes + (OptionalCacheStats ? OptionalCacheStats->TotalBytesRead.load() : 0); + return fmt::format("Sent: {} ({}bits/s) Recv: {} ({}bits/s)", + NiceBytes(SentBytes), + NiceNum(ElapsedWallTimeMS > 0u ? static_cast((SentBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u), + NiceBytes(ReceivedBytes), + NiceNum(ElapsedWallTimeMS > 0u ? 
static_cast((ReceivedBytes * 8 * 1000) / ElapsedWallTimeMS) : 0u)); } void LogRemoteStoreStatsDetails(const RemoteProjectStore::Stats& Stats) @@ -269,12 +273,7 @@ namespace remotestore_impl { JobContext* m_OptionalContext; }; - void DownloadAndSaveBlockChunks(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, + void DownloadAndSaveBlockChunks(LoadOplogContext& Context, Latch& AttachmentsDownloadLatch, Latch& AttachmentsWriteLatch, AsyncRemoteResult& RemoteResult, @@ -285,10 +284,8 @@ namespace remotestore_impl { std::vector&& NeededChunkIndexes) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( - [&RemoteStore, - &ChunkStore, - &WorkerPool, + Context.NetworkWorkerPool.ScheduleWork( + [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &RemoteResult, @@ -296,9 +293,7 @@ namespace remotestore_impl { NeededChunkIndexes = std::move(NeededChunkIndexes), &Info, &LoadAttachmentsTimer, - &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext]() { + &DownloadStartMS]() { ZEN_TRACE_CPU("DownloadBlockChunks"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -317,16 +312,16 @@ namespace remotestore_impl { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks); + RemoteProjectStore::LoadAttachmentsResult Result = Context.RemoteStore.LoadAttachments(Chunks); if (Result.ErrorCode) { - ReportMessage(OptionalContext, + ReportMessage(Context.OptionalJobContext, fmt::format("Failed to load attachments with {} chunks ({}): {}", Chunks.size(), RemoteResult.GetError(), RemoteResult.GetErrorReason())); Info.MissingAttachmentCount.fetch_add(1); - if (IgnoreMissingAttachments) + if 
(Context.IgnoreMissingAttachments) { RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); } @@ -338,7 +333,7 @@ namespace remotestore_impl { uint64_t ChunkSize = It.second.GetCompressedSize(); Info.AttachmentBytesDownloaded.fetch_add(ChunkSize); } - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loaded {} bulk attachments in {}", Chunks.size(), NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000)))); @@ -347,8 +342,8 @@ namespace remotestore_impl { return; } AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&AttachmentsWriteLatch, &RemoteResult, &Info, &ChunkStore, Chunks = std::move(Result.Chunks)]() { + Context.WorkerPool.ScheduleWork( + [&AttachmentsWriteLatch, &RemoteResult, &Info, &Context, Chunks = std::move(Result.Chunks)]() { auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); if (RemoteResult.IsError()) { @@ -369,7 +364,9 @@ namespace remotestore_impl { WriteRawHashes.push_back(It.first); } std::vector InsertResults = - ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes, CidStore::InsertMode::kCopyOnly); + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, + WriteRawHashes, + CidStore::InsertMode::kCopyOnly); for (size_t Index = 0; Index < InsertResults.size(); Index++) { @@ -400,12 +397,7 @@ namespace remotestore_impl { WorkerThreadPool::EMode::EnableBacklog); }; - void DownloadAndSaveBlock(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, + void DownloadAndSaveBlock(LoadOplogContext& Context, Latch& AttachmentsDownloadLatch, Latch& AttachmentsWriteLatch, AsyncRemoteResult& RemoteResult, @@ -418,19 +410,14 @@ namespace remotestore_impl { uint32_t RetriesLeft) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( + 
Context.NetworkWorkerPool.ScheduleWork( [&AttachmentsDownloadLatch, &AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, + &Context, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, RetriesLeft, BlockHash = IoHash(BlockHash), &AllNeededPartialChunkHashesLookup, @@ -446,52 +433,65 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); - if (BlockResult.ErrorCode) + + IoBuffer BlobBuffer; + if (Context.OptionalCache) { - ReportMessage(OptionalContext, - fmt::format("Failed to download block attachment {} ({}): {}", - BlockHash, - RemoteResult.GetError(), - RemoteResult.GetErrorReason())); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) - { - RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); - } - return; + BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, BlockHash); } - if (RemoteResult.IsError()) + + if (!BlobBuffer) { - return; + RemoteProjectStore::LoadAttachmentResult BlockResult = Context.RemoteStore.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode) + { + ReportMessage(Context.OptionalJobContext, + fmt::format("Failed to download block attachment {} ({}): {}", + BlockHash, + BlockResult.Reason, + BlockResult.Text)); + Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + } + return; + } + if (RemoteResult.IsError()) + { + return; + } + BlobBuffer = std::move(BlockResult.Bytes); + ZEN_DEBUG("Loaded block attachment '{}' in {} ({})", + BlockHash, + NiceTimeSpanMs(static_cast(BlockResult.ElapsedSeconds * 1000)), + NiceBytes(BlobBuffer.Size())); + if 
(Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + BlockHash, + BlobBuffer.GetContentType(), + CompositeBuffer(SharedBuffer(BlobBuffer))); + } } - uint64_t BlockSize = BlockResult.Bytes.GetSize(); + uint64_t BlockSize = BlobBuffer.GetSize(); Info.AttachmentBlocksDownloaded.fetch_add(1); - ZEN_DEBUG("Loaded block attachment '{}' in {} ({})", - BlockHash, - NiceTimeSpanMs(static_cast(BlockResult.ElapsedSeconds * 1000)), - NiceBytes(BlockSize)); Info.AttachmentBlockBytesDownloaded.fetch_add(BlockSize); AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( + Context.WorkerPool.ScheduleWork( [&AttachmentsDownloadLatch, &AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, + &Context, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, RetriesLeft, BlockHash = IoHash(BlockHash), &AllNeededPartialChunkHashesLookup, ChunkDownloadedFlags, - Bytes = std::move(BlockResult.Bytes)]() { + Bytes = std::move(BlobBuffer)]() { auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); if (RemoteResult.IsError()) { @@ -569,7 +569,7 @@ namespace remotestore_impl { if (!WriteAttachmentBuffers.empty()) { std::vector Results = - ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); for (size_t Index = 0; Index < Results.size(); Index++) { const CidStore::InsertResult& Result = Results[Index]; @@ -598,14 +598,9 @@ namespace remotestore_impl { { if (RetriesLeft > 0) { - ReportMessage(OptionalContext, fmt::format("{}, retrying download", ErrorString)); - - return DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + ReportMessage(Context.OptionalJobContext, fmt::format("{}, retrying download", ErrorString)); + + return 
DownloadAndSaveBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -619,7 +614,7 @@ namespace remotestore_impl { } else { - ReportMessage(OptionalContext, ErrorString); + ReportMessage(Context.OptionalJobContext, ErrorString); RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), ErrorString, {}); return; } @@ -637,28 +632,29 @@ namespace remotestore_impl { catch (const std::exception& Ex) { RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), - fmt::format("Failed to block attachment {}", BlockHash), + fmt::format("Failed to download block attachment {}", BlockHash), Ex.what()); } }, WorkerThreadPool::EMode::EnableBacklog); }; - bool DownloadPartialBlock(RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - AsyncRemoteResult& RemoteResult, - DownloadInfo& Info, - double& DownloadTimeSeconds, - const ChunkBlockDescription& BlockDescription, - bool BlockExistsInCache, - std::span BlockRangeDescriptors, - size_t BlockRangeIndexStart, - size_t BlockRangeCount, + void DownloadPartialBlock(LoadOplogContext& Context, + AsyncRemoteResult& RemoteResult, + DownloadInfo& Info, + double& DownloadTimeSeconds, + const ChunkBlockDescription& BlockDescription, + bool BlockExistsInCache, + std::span BlockRangeDescriptors, + size_t BlockRangeIndexStart, + size_t BlockRangeCount, std::function> OffsetAndLengths)>&& OnDownloaded) { + ZEN_ASSERT(Context.StoreMaxRangeCountPerRequest != 0); + ZEN_ASSERT(BlockExistsInCache == false || Context.CacheMaxRangeCountPerRequest != 0); + std::vector> Ranges; Ranges.reserve(BlockRangeDescriptors.size()); for (size_t BlockRangeIndex = BlockRangeIndexStart; BlockRangeIndex < BlockRangeIndexStart + BlockRangeCount; BlockRangeIndex++) @@ -667,65 +663,104 @@ namespace remotestore_impl { Ranges.push_back(std::make_pair(BlockRange.RangeStart, BlockRange.RangeLength)); } - if (BlockExistsInCache) - { - 
RemoteProjectStore::LoadAttachmentRangesResult BlockResult = - RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, Ranges, RemoteProjectStore::ESourceMode::kCacheOnly); - DownloadTimeSeconds += BlockResult.ElapsedSeconds; - if (RemoteResult.IsError()) - { - return false; - } - if (!BlockResult.ErrorCode && BlockResult.Bytes) - { - if (BlockResult.Ranges.size() != Ranges.size()) - { - throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", - Ranges.size(), - BlockDescription.BlockHash, - BlockResult.Ranges.size())); - } - OnDownloaded(std::move(BlockResult.Bytes), BlockRangeIndexStart, BlockResult.Ranges); - return true; - } - } - - const size_t MaxRangesPerRequestToJupiter = RemoteProjectStore::MaxRangeCountPerRequest; - size_t SubBlockRangeCount = BlockRangeCount; size_t SubRangeCountComplete = 0; std::span> RangesSpan(Ranges); + while (SubRangeCountComplete < SubBlockRangeCount) { if (RemoteResult.IsError()) { break; } - size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, MaxRangesPerRequestToJupiter); + size_t SubRangeStartIndex = BlockRangeIndexStart + SubRangeCountComplete; + if (BlockExistsInCache) + { + ZEN_ASSERT(Context.OptionalCache); + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.CacheMaxRangeCountPerRequest); + + if (SubRangeCount == 1) + { + // Legacy single-range path, prefer that for max compatibility + + const std::pair SubRange = RangesSpan[SubRangeCountComplete]; + Stopwatch CacheTimer; + IoBuffer PayloadBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, + BlockDescription.BlockHash, + SubRange.first, + SubRange.second); + DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0; + if (RemoteResult.IsError()) + { + break; + } + if (PayloadBuffer) + { + OnDownloaded(std::move(PayloadBuffer), + SubRangeStartIndex, + std::vector>{std::make_pair(0u, SubRange.second)}); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + else + { + 
auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); + + Stopwatch CacheTimer; + BuildStorageCache::BuildBlobRanges RangeBuffers = + Context.OptionalCache->GetBuildBlobRanges(Context.CacheBuildId, BlockDescription.BlockHash, SubRanges); + DownloadTimeSeconds += CacheTimer.GetElapsedTimeMs() / 1000.0; + if (RemoteResult.IsError()) + { + break; + } + if (RangeBuffers.PayloadBuffer) + { + if (RangeBuffers.Ranges.empty()) + { + SubRangeCount = Ranges.size() - SubRangeCountComplete; + OnDownloaded(std::move(RangeBuffers.PayloadBuffer), + SubRangeStartIndex, + RangesSpan.subspan(SubRangeCountComplete, SubRangeCount)); + SubRangeCountComplete += SubRangeCount; + continue; + } + else if (RangeBuffers.Ranges.size() == SubRangeCount) + { + OnDownloaded(std::move(RangeBuffers.PayloadBuffer), SubRangeStartIndex, RangeBuffers.Ranges); + SubRangeCountComplete += SubRangeCount; + continue; + } + } + } + } + + size_t SubRangeCount = Min(BlockRangeCount - SubRangeCountComplete, Context.StoreMaxRangeCountPerRequest); auto SubRanges = RangesSpan.subspan(SubRangeCountComplete, SubRangeCount); RemoteProjectStore::LoadAttachmentRangesResult BlockResult = - RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, SubRanges, RemoteProjectStore::ESourceMode::kHostOnly); + Context.RemoteStore.LoadAttachmentRanges(BlockDescription.BlockHash, SubRanges); DownloadTimeSeconds += BlockResult.ElapsedSeconds; if (RemoteResult.IsError()) { - return false; + break; } if (BlockResult.ErrorCode || !BlockResult.Bytes) { - ReportMessage(OptionalContext, + ReportMessage(Context.OptionalJobContext, fmt::format("Failed to download {} ranges from block attachment '{}' ({}): {}", SubRanges.size(), BlockDescription.BlockHash, BlockResult.ErrorCode, BlockResult.Reason)); Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); - return false; 
+ break; } } else @@ -734,6 +769,18 @@ namespace remotestore_impl { { // Jupiter will ignore the ranges and send the whole payload if it fetches the payload from S3 // Use the whole payload for the remaining ranges + + if (Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + BlockDescription.BlockHash, + ZenContentType::kCompressedBinary, + CompositeBuffer(std::vector{BlockResult.Bytes})); + if (RemoteResult.IsError()) + { + break; + } + } SubRangeCount = Ranges.size() - SubRangeCountComplete; OnDownloaded(std::move(BlockResult.Bytes), SubRangeStartIndex, @@ -743,10 +790,13 @@ namespace remotestore_impl { { if (BlockResult.Ranges.size() != SubRanges.size()) { - throw std::runtime_error(fmt::format("Fetching {} ranges from {} resulted in {} ranges", - SubRanges.size(), - BlockDescription.BlockHash, - BlockResult.Ranges.size())); + RemoteResult.SetError(gsl::narrow(HttpResponseCode::InternalServerError), + fmt::format("Range response for block {} contains {} ranges, expected {} ranges", + BlockDescription.BlockHash, + BlockResult.Ranges.size(), + SubRanges.size()), + ""); + break; } OnDownloaded(std::move(BlockResult.Bytes), SubRangeStartIndex, BlockResult.Ranges); } @@ -754,15 +804,9 @@ namespace remotestore_impl { SubRangeCountComplete += SubRangeCount; } - return true; } - void DownloadAndSavePartialBlock(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, + void DownloadAndSavePartialBlock(LoadOplogContext& Context, Latch& AttachmentsDownloadLatch, Latch& AttachmentsWriteLatch, AsyncRemoteResult& RemoteResult, @@ -779,19 +823,14 @@ namespace remotestore_impl { uint32_t RetriesLeft) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( + Context.NetworkWorkerPool.ScheduleWork( [&AttachmentsDownloadLatch, &AttachmentsWriteLatch, - &ChunkStore, - 
&RemoteStore, - &NetworkWorkerPool, - &WorkerPool, + &Context, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, BlockDescription, BlockExistsInCache, BlockRangeDescriptors, @@ -811,10 +850,8 @@ namespace remotestore_impl { double DownloadElapsedSeconds = 0; uint64_t DownloadedBytes = 0; - bool Success = DownloadPartialBlock( - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, + DownloadPartialBlock( + Context, RemoteResult, Info, DownloadElapsedSeconds, @@ -833,19 +870,14 @@ namespace remotestore_impl { Info.AttachmentBlocksRangesDownloaded++; AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( + Context.WorkerPool.ScheduleWork( [&AttachmentsWriteLatch, - &ChunkStore, - &RemoteStore, - &NetworkWorkerPool, - &WorkerPool, + &Context, &AttachmentsDownloadLatch, &RemoteResult, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - IgnoreMissingAttachments, - OptionalContext, BlockDescription, BlockExistsInCache, BlockRangeDescriptors, @@ -945,14 +977,9 @@ namespace remotestore_impl { { if (RetriesLeft > 0) { - ReportMessage(OptionalContext, + ReportMessage(Context.OptionalJobContext, fmt::format("{}, retrying download", ErrorString)); - return DownloadAndSavePartialBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + return DownloadAndSavePartialBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -969,9 +996,9 @@ namespace remotestore_impl { RetriesLeft - 1); } - ReportMessage(OptionalContext, ErrorString); + ReportMessage(Context.OptionalJobContext, ErrorString); Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError(gsl::narrow(HttpResponseCode::NotFound), "Malformed chunk block", @@ -998,7 +1025,7 @@ namespace remotestore_impl { if (!WriteAttachmentBuffers.empty()) { std::vector Results = - 
ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); + Context.ChunkStore.AddChunks(WriteAttachmentBuffers, WriteRawHashes); for (size_t Index = 0; Index < Results.size(); Index++) { const CidStore::InsertResult& Result = Results[Index]; @@ -1037,7 +1064,7 @@ namespace remotestore_impl { }, WorkerThreadPool::EMode::EnableBacklog); }); - if (Success) + if (!RemoteResult.IsError()) { ZEN_DEBUG("Loaded {} ranges from block attachment '{}' in {} ({})", BlockRangeCount, @@ -1056,12 +1083,7 @@ namespace remotestore_impl { WorkerThreadPool::EMode::EnableBacklog); }; - void DownloadAndSaveAttachment(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - bool IgnoreMissingAttachments, - JobContext* OptionalContext, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, + void DownloadAndSaveAttachment(LoadOplogContext& Context, Latch& AttachmentsDownloadLatch, Latch& AttachmentsWriteLatch, AsyncRemoteResult& RemoteResult, @@ -1071,19 +1093,15 @@ namespace remotestore_impl { const IoHash& RawHash) { AttachmentsDownloadLatch.AddCount(1); - NetworkWorkerPool.ScheduleWork( - [&RemoteStore, - &ChunkStore, - &WorkerPool, + Context.NetworkWorkerPool.ScheduleWork( + [&Context, &RemoteResult, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, RawHash, &LoadAttachmentsTimer, &DownloadStartMS, - &Info, - IgnoreMissingAttachments, - OptionalContext]() { + &Info]() { ZEN_TRACE_CPU("DownloadAttachment"); auto _ = MakeGuard([&AttachmentsDownloadLatch] { AttachmentsDownloadLatch.CountDown(); }); @@ -1095,43 +1113,52 @@ namespace remotestore_impl { { uint64_t Unset = (std::uint64_t)-1; DownloadStartMS.compare_exchange_strong(Unset, LoadAttachmentsTimer.GetElapsedTimeMs()); - RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); - if (AttachmentResult.ErrorCode) + IoBuffer BlobBuffer; + if (Context.OptionalCache) { - ReportMessage(OptionalContext, - fmt::format("Failed to download large attachment {}: '{}', error code : 
{}", - RawHash, - AttachmentResult.Reason, - AttachmentResult.ErrorCode)); - Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + BlobBuffer = Context.OptionalCache->GetBuildBlob(Context.CacheBuildId, RawHash); + } + if (!BlobBuffer) + { + RemoteProjectStore::LoadAttachmentResult AttachmentResult = Context.RemoteStore.LoadAttachment(RawHash); + if (AttachmentResult.ErrorCode) { - RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + ReportMessage(Context.OptionalJobContext, + fmt::format("Failed to download large attachment {}: '{}', error code : {}", + RawHash, + AttachmentResult.Reason, + AttachmentResult.ErrorCode)); + Info.MissingAttachmentCount.fetch_add(1); + if (!Context.IgnoreMissingAttachments) + { + RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + } + return; + } + BlobBuffer = std::move(AttachmentResult.Bytes); + ZEN_DEBUG("Loaded large attachment '{}' in {} ({})", + RawHash, + NiceTimeSpanMs(static_cast(AttachmentResult.ElapsedSeconds * 1000)), + NiceBytes(BlobBuffer.GetSize())); + if (Context.OptionalCache && Context.PopulateCache) + { + Context.OptionalCache->PutBuildBlob(Context.CacheBuildId, + RawHash, + BlobBuffer.GetContentType(), + CompositeBuffer(SharedBuffer(BlobBuffer))); } - return; } - uint64_t AttachmentSize = AttachmentResult.Bytes.GetSize(); - ZEN_DEBUG("Loaded large attachment '{}' in {} ({})", - RawHash, - NiceTimeSpanMs(static_cast(AttachmentResult.ElapsedSeconds * 1000)), - NiceBytes(AttachmentSize)); - Info.AttachmentsDownloaded.fetch_add(1); if (RemoteResult.IsError()) { return; } + uint64_t AttachmentSize = BlobBuffer.GetSize(); + Info.AttachmentsDownloaded.fetch_add(1); Info.AttachmentBytesDownloaded.fetch_add(AttachmentSize); AttachmentsWriteLatch.AddCount(1); - WorkerPool.ScheduleWork( - [&AttachmentsWriteLatch, - &RemoteResult, - &Info, - &ChunkStore, - RawHash, - AttachmentSize, - Bytes = 
std::move(AttachmentResult.Bytes), - OptionalContext]() { + Context.WorkerPool.ScheduleWork( + [&Context, &AttachmentsWriteLatch, &RemoteResult, &Info, RawHash, AttachmentSize, Bytes = std::move(BlobBuffer)]() { ZEN_TRACE_CPU("WriteAttachment"); auto _ = MakeGuard([&AttachmentsWriteLatch] { AttachmentsWriteLatch.CountDown(); }); @@ -1141,7 +1168,7 @@ namespace remotestore_impl { } try { - CidStore::InsertResult InsertResult = ChunkStore.AddChunk(Bytes, RawHash); + CidStore::InsertResult InsertResult = Context.ChunkStore.AddChunk(Bytes, RawHash); if (InsertResult.New) { Info.AttachmentBytesStored.fetch_add(AttachmentSize); @@ -1557,7 +1584,9 @@ namespace remotestore_impl { uint64_t PartialTransferWallTimeMS = Timer.GetElapsedTimeMs(); ReportProgress(OptionalContext, "Saving attachments"sv, - fmt::format("{} remaining... {}", Remaining, GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), + fmt::format("{} remaining... {}", + Remaining, + GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, PartialTransferWallTimeMS)), AttachmentsToSave, Remaining); } @@ -1566,7 +1595,7 @@ namespace remotestore_impl { { ReportProgress(OptionalContext, "Saving attachments"sv, - fmt::format("{}", GetStats(RemoteStore.GetStats(), ElapsedTimeMS)), + fmt::format("{}", GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS)), AttachmentsToSave, 0); } @@ -1577,7 +1606,7 @@ namespace remotestore_impl { LargeAttachmentCountToUpload, BulkAttachmentCountToUpload, NiceTimeSpanMs(ElapsedTimeMS), - GetStats(RemoteStore.GetStats(), ElapsedTimeMS))); + GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, ElapsedTimeMS))); } } // namespace remotestore_impl @@ -2186,31 +2215,36 @@ BuildContainer(CidStore& ChunkStore, } ResolveAttachmentsLatch.CountDown(); - while (!ResolveAttachmentsLatch.Wait(1000)) { - ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + ptrdiff_t 
AttachmentCountToUseForProgress = ResolveAttachmentsLatch.Remaining(); + while (!ResolveAttachmentsLatch.Wait(1000)) { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); - while (!ResolveAttachmentsLatch.Wait(1000)) + ptrdiff_t Remaining = ResolveAttachmentsLatch.Remaining(); + if (remotestore_impl::IsCancelled(OptionalContext)) { - Remaining = ResolveAttachmentsLatch.Remaining(); - remotestore_impl::ReportProgress(OptionalContext, - "Resolving attachments"sv, - fmt::format("Aborting, {} attachments remaining...", Remaining), - UploadAttachments.size(), - Remaining); + RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); + while (!ResolveAttachmentsLatch.Wait(1000)) + { + Remaining = ResolveAttachmentsLatch.Remaining(); + remotestore_impl::ReportProgress(OptionalContext, + "Resolving attachments"sv, + fmt::format("Aborting, {} attachments remaining...", Remaining), + UploadAttachments.size(), + Remaining); + } + remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0); + return {}; } - remotestore_impl::ReportProgress(OptionalContext, "Resolving attachments"sv, "Aborted"sv, UploadAttachments.size(), 0); - return {}; + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress(OptionalContext, + "Resolving attachments"sv, + fmt::format("{} remaining...", Remaining), + AttachmentCountToUseForProgress, + Remaining); } - remotestore_impl::ReportProgress(OptionalContext, - "Resolving attachments"sv, - fmt::format("{} remaining...", Remaining), - UploadAttachments.size(), - Remaining); } if 
(UploadAttachments.size() > 0) { @@ -2598,12 +2632,14 @@ BuildContainer(CidStore& ChunkStore, 0); } - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and in {}", - ChunkAssembleCount, - TotalOpCount, - GeneratedBlockCount, - NiceTimeSpanMs(static_cast(Timer.GetElapsedTimeMs())))); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Built oplog and collected {} attachments from {} ops into {} blocks and {} loose attachments in {}", + ChunkAssembleCount, + TotalOpCount, + GeneratedBlockCount, + LargeChunkHashes.size(), + NiceTimeSpanMs(static_cast(Timer.GetElapsedTimeMs())))); if (remotestore_impl::IsCancelled(OptionalContext)) { @@ -3155,17 +3191,18 @@ SaveOplog(CidStore& ChunkStore, remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats()); - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}", - RemoteStoreInfo.ContainerName, - RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE", - NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), - NiceBytes(Info.OplogSizeBytes), - Info.AttachmentBlocksUploaded.load(), - NiceBytes(Info.AttachmentBlockBytesUploaded.load()), - Info.AttachmentsUploaded.load(), - NiceBytes(Info.AttachmentBytesUploaded.load()), - remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); + remotestore_impl::ReportMessage( + OptionalContext, + fmt::format("Saved oplog '{}' {} in {} ({}), Blocks: {} ({}), Attachments: {} ({}) {}", + RemoteStoreInfo.ContainerName, + RemoteResult.GetError() == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), + NiceBytes(Info.OplogSizeBytes), + Info.AttachmentBlocksUploaded.load(), + NiceBytes(Info.AttachmentBlockBytesUploaded.load()), + Info.AttachmentsUploaded.load(), + NiceBytes(Info.AttachmentBytesUploaded.load()), + remotestore_impl::GetStats(RemoteStore.GetStats(), /*OptionalCacheStats*/ nullptr, TransferWallTimeMS))); return Result; }; @@ -3234,6 +3271,11 @@ ParseOplogContainer( OpCount - OpsCompleteCount); } } + remotestore_impl::ReportProgress(OptionalContext, + "Scanning oplog"sv, + fmt::format("{} attachments found", NeededAttachments.size()), + OpCount, + OpCount - OpsCompleteCount); } { std::vector ReferencedAttachments(NeededAttachments.begin(), NeededAttachments.end()); @@ -3406,22 +3448,11 @@ SaveOplogContainer( } RemoteProjectStore::Result -LoadOplog(CidStore& ChunkStore, - RemoteProjectStore& RemoteStore, - ProjectStore::Oplog& Oplog, - WorkerThreadPool& NetworkWorkerPool, - WorkerThreadPool& WorkerPool, - bool ForceDownload, - bool IgnoreMissingAttachments, - bool CleanOplog, - EPartialBlockRequestMode PartialBlockRequestMode, - double HostLatencySec, - double CacheLatencySec, - JobContext* OptionalContext) +LoadOplog(LoadOplogContext&& Context) { using namespace std::literals; - std::unique_ptr LogOutput(std::make_unique(OptionalContext)); + std::unique_ptr LogOutput(std::make_unique(Context.OptionalJobContext)); remotestore_impl::DownloadInfo Info; @@ -3430,25 +3461,25 @@ LoadOplog(CidStore& ChunkStore, std::unordered_set Attachments; uint64_t BlockCountToDownload = 0; - RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = RemoteStore.GetInfo(); - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Loading oplog container '{}'", RemoteStoreInfo.ContainerName)); + RemoteProjectStore::RemoteStoreInfo RemoteStoreInfo = Context.RemoteStore.GetInfo(); + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loading oplog container '{}'", 
RemoteStoreInfo.ContainerName)); uint64_t TransferWallTimeMS = 0; Stopwatch LoadContainerTimer; - RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer(); + RemoteProjectStore::LoadContainerResult LoadContainerResult = Context.RemoteStore.LoadContainer(); TransferWallTimeMS += LoadContainerTimer.GetElapsedTimeMs(); if (LoadContainerResult.ErrorCode) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Failed to load oplog container: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode)); return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode, .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, .Reason = LoadContainerResult.Reason, .Text = LoadContainerResult.Text}; } - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Loaded container in {} ({})", NiceTimeSpanMs(static_cast(LoadContainerResult.ElapsedSeconds * 1000)), NiceBytes(LoadContainerResult.ContainerObject.GetSize()))); @@ -3462,12 +3493,12 @@ LoadOplog(CidStore& ChunkStore, Stopwatch LoadAttachmentsTimer; std::atomic_uint64_t DownloadStartMS = (std::uint64_t)-1; - auto HasAttachment = [&Oplog, &ChunkStore, ForceDownload](const IoHash& RawHash) { - if (ForceDownload) + auto HasAttachment = [&Context](const IoHash& RawHash) { + if (Context.ForceDownload) { return false; } - if (ChunkStore.ContainsChunk(RawHash)) + if (Context.ChunkStore.ContainsChunk(RawHash)) { return true; } @@ -3482,10 +3513,7 @@ LoadOplog(CidStore& ChunkStore, std::vector NeededBlockDownloads; - auto OnNeedBlock = [&RemoteStore, - &ChunkStore, - &NetworkWorkerPool, - &WorkerPool, + auto OnNeedBlock = [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &AttachmentCount, @@ -3494,9 +3522,8 @@ LoadOplog(CidStore& ChunkStore, &Info, &LoadAttachmentsTimer, &DownloadStartMS, - &NeededBlockDownloads, - IgnoreMissingAttachments, - 
OptionalContext](ThinChunkBlockDescription&& ThinBlockDescription, std::vector&& NeededChunkIndexes) { + &NeededBlockDownloads](ThinChunkBlockDescription&& ThinBlockDescription, + std::vector&& NeededChunkIndexes) { if (RemoteResult.IsError()) { return; @@ -3506,12 +3533,7 @@ LoadOplog(CidStore& ChunkStore, AttachmentCount.fetch_add(1); if (ThinBlockDescription.BlockHash == IoHash::Zero) { - DownloadAndSaveBlockChunks(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSaveBlockChunks(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3528,11 +3550,7 @@ LoadOplog(CidStore& ChunkStore, } }; - auto OnNeedAttachment = [&RemoteStore, - &Oplog, - &ChunkStore, - &NetworkWorkerPool, - &WorkerPool, + auto OnNeedAttachment = [&Context, &AttachmentsDownloadLatch, &AttachmentsWriteLatch, &RemoteResult, @@ -3540,9 +3558,7 @@ LoadOplog(CidStore& ChunkStore, &AttachmentCount, &LoadAttachmentsTimer, &DownloadStartMS, - &Info, - IgnoreMissingAttachments, - OptionalContext](const IoHash& RawHash) { + &Info](const IoHash& RawHash) { if (!Attachments.insert(RawHash).second) { return; @@ -3552,12 +3568,7 @@ LoadOplog(CidStore& ChunkStore, return; } AttachmentCount.fetch_add(1); - DownloadAndSaveAttachment(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSaveAttachment(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3570,11 +3581,11 @@ LoadOplog(CidStore& ChunkStore, std::vector FilesToDechunk; auto OnChunkedAttachment = [&FilesToDechunk](const ChunkedInfo& Chunked) { FilesToDechunk.push_back(Chunked); }; - auto OnReferencedAttachments = [&Oplog](std::span RawHashes) { Oplog.CaptureAddedAttachments(RawHashes); }; + auto OnReferencedAttachments = [&Context](std::span RawHashes) { Context.Oplog.CaptureAddedAttachments(RawHashes); }; // Make sure we retain any attachments we download before 
writing the oplog - Oplog.EnableUpdateCapture(); - auto _ = MakeGuard([&Oplog]() { Oplog.DisableUpdateCapture(); }); + Context.Oplog.EnableUpdateCapture(); + auto _ = MakeGuard([&Context]() { Context.Oplog.DisableUpdateCapture(); }); CbObject OplogSection; RemoteProjectStore::Result Result = ParseOplogContainer(LoadContainerResult.ContainerObject, @@ -3584,12 +3595,12 @@ LoadOplog(CidStore& ChunkStore, OnNeedAttachment, OnChunkedAttachment, OplogSection, - OptionalContext); + Context.OptionalJobContext); if (Result.ErrorCode != 0) { RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); } - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Parsed oplog in {}, found {} attachments, {} blocks and {} chunked files to download", NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), Attachments.size(), @@ -3613,11 +3624,12 @@ LoadOplog(CidStore& ChunkStore, std::vector DownloadedViaLegacyChunkFlag(AllNeededChunkHashes.size(), false); ChunkBlockAnalyser::BlockResult PartialBlocksResult; - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Fetching descriptions for {} blocks", BlockHashes.size())); + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Fetching descriptions for {} blocks", BlockHashes.size())); - RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions = RemoteStore.GetBlockDescriptions(BlockHashes); + RemoteProjectStore::GetBlockDescriptionsResult BlockDescriptions = + Context.RemoteStore.GetBlockDescriptions(BlockHashes, Context.OptionalCache, Context.CacheBuildId); - remotestore_impl::ReportMessage(OptionalContext, + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("GetBlockDescriptions took {}. 
Found {} blocks", NiceTimeSpanMs(uint64_t(BlockDescriptions.ElapsedSeconds * 1000)), BlockDescriptions.Blocks.size())); @@ -3636,12 +3648,7 @@ LoadOplog(CidStore& ChunkStore, if (FindIt == BlockDescriptions.Blocks.end()) { // Fall back to full download as we can't get enough information about the block - DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSaveBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3678,56 +3685,86 @@ LoadOplog(CidStore& ChunkStore, if (!AllNeededChunkHashes.empty()) { std::vector PartialBlockDownloadModes; - std::vector BlockExistsInCache; + std::vector BlockExistsInCache(BlocksWithDescription.size(), false); - if (PartialBlockRequestMode == EPartialBlockRequestMode::Off) + if (Context.PartialBlockRequestMode == EPartialBlockRequestMode::Off) { PartialBlockDownloadModes.resize(BlocksWithDescription.size(), ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); } else { - RemoteProjectStore::AttachmentExistsInCacheResult CacheExistsResult = - RemoteStore.AttachmentExistsInCache(BlocksWithDescription); - if (CacheExistsResult.ErrorCode != 0 || CacheExistsResult.HasBody.size() != BlocksWithDescription.size()) + if (Context.OptionalCache) { - BlockExistsInCache.resize(BlocksWithDescription.size(), false); + std::vector CacheExistsResult = + Context.OptionalCache->BlobsExists(Context.CacheBuildId, BlocksWithDescription); + if (CacheExistsResult.size() == BlocksWithDescription.size()) + { + for (size_t BlobIndex = 0; BlobIndex < CacheExistsResult.size(); BlobIndex++) + { + BlockExistsInCache[BlobIndex] = CacheExistsResult[BlobIndex].HasBody; + } + } + uint64_t FoundBlocks = + std::accumulate(BlockExistsInCache.begin(), + BlockExistsInCache.end(), + uint64_t(0u), + [](uint64_t Current, bool Exists) -> uint64_t { return Current + (Exists ? 
1 : 0); }); + if (FoundBlocks > 0) + { + remotestore_impl::ReportMessage( + Context.OptionalJobContext, + fmt::format("Found {} out of {} blocks in cache", FoundBlocks, BlockExistsInCache.size())); + } } - else + + ChunkBlockAnalyser::EPartialBlockDownloadMode CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + ChunkBlockAnalyser::EPartialBlockDownloadMode CachePartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + + switch (Context.PartialBlockRequestMode) { - BlockExistsInCache = std::move(CacheExistsResult.HasBody); + case EPartialBlockRequestMode::Off: + break; + case EPartialBlockRequestMode::ZenCacheOnly: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::Off; + break; + case EPartialBlockRequestMode::Mixed: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; + case EPartialBlockRequestMode::All: + CachePartialDownloadMode = Context.CacheMaxRangeCountPerRequest > 1 + ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed + : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange; + CloudPartialDownloadMode = Context.StoreMaxRangeCountPerRequest > 1 + ? 
ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange + : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange; + break; } PartialBlockDownloadModes.reserve(BlocksWithDescription.size()); - - for (bool ExistsInCache : BlockExistsInCache) + for (uint32_t BlockIndex = 0; BlockIndex < BlocksWithDescription.size(); BlockIndex++) { - if (PartialBlockRequestMode == EPartialBlockRequestMode::All) - { - PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact - : ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRange); - } - else if (PartialBlockRequestMode == EPartialBlockRequestMode::ZenCacheOnly) - { - PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::Exact - : ChunkBlockAnalyser::EPartialBlockDownloadMode::Off); - } - else if (PartialBlockRequestMode == EPartialBlockRequestMode::Mixed) - { - PartialBlockDownloadModes.push_back(ExistsInCache ? ChunkBlockAnalyser::EPartialBlockDownloadMode::MultiRangeHighSpeed - : ChunkBlockAnalyser::EPartialBlockDownloadMode::SingleRange); - } + const bool BlockExistInCache = BlockExistsInCache[BlockIndex]; + PartialBlockDownloadModes.push_back(BlockExistInCache ? 
CachePartialDownloadMode : CloudPartialDownloadMode); } } ZEN_ASSERT(PartialBlockDownloadModes.size() == BlocksWithDescription.size()); + ChunkBlockAnalyser PartialAnalyser( *LogOutput, BlockDescriptions.Blocks, - ChunkBlockAnalyser::Options{.IsQuiet = false, - .IsVerbose = false, - .HostLatencySec = HostLatencySec, - .HostHighSpeedLatencySec = CacheLatencySec, - .HostMaxRangeCountPerRequest = RemoteProjectStore::MaxRangeCountPerRequest}); + ChunkBlockAnalyser::Options{.IsQuiet = false, + .IsVerbose = false, + .HostLatencySec = Context.StoreLatencySec, + .HostHighSpeedLatencySec = Context.CacheLatencySec, + .HostMaxRangeCountPerRequest = Context.StoreMaxRangeCountPerRequest, + .HostHighSpeedMaxRangeCountPerRequest = Context.CacheMaxRangeCountPerRequest}); std::vector NeededBlocks = PartialAnalyser.GetNeeded(AllNeededPartialChunkHashesLookup, @@ -3736,12 +3773,7 @@ LoadOplog(CidStore& ChunkStore, PartialBlocksResult = PartialAnalyser.CalculatePartialBlockDownloads(NeededBlocks, PartialBlockDownloadModes); for (uint32_t FullBlockIndex : PartialBlocksResult.FullBlockIndexes) { - DownloadAndSaveBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSaveBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3765,12 +3797,7 @@ LoadOplog(CidStore& ChunkStore, RangeCount++; } - DownloadAndSavePartialBlock(ChunkStore, - RemoteStore, - IgnoreMissingAttachments, - OptionalContext, - NetworkWorkerPool, - WorkerPool, + DownloadAndSavePartialBlock(Context, AttachmentsDownloadLatch, AttachmentsWriteLatch, RemoteResult, @@ -3791,38 +3818,44 @@ LoadOplog(CidStore& ChunkStore, } AttachmentsDownloadLatch.CountDown(); - while (!AttachmentsDownloadLatch.Wait(1000)) { - ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + ptrdiff_t AttachmentCountToUseForProgress = AttachmentsDownloadLatch.Remaining(); + while 
(!AttachmentsDownloadLatch.Wait(1000)) { - if (!RemoteResult.IsError()) + ptrdiff_t Remaining = AttachmentsDownloadLatch.Remaining(); + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); + if (!RemoteResult.IsError()) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); + } + } + uint64_t PartialTransferWallTimeMS = TransferWallTimeMS; + if (DownloadStartMS != (uint64_t)-1) + { + PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load(); } - } - uint64_t PartialTransferWallTimeMS = TransferWallTimeMS; - if (DownloadStartMS != (uint64_t)-1) - { - PartialTransferWallTimeMS += LoadAttachmentsTimer.GetElapsedTimeMs() - DownloadStartMS.load(); - } - - uint64_t AttachmentsDownloaded = - Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); - uint64_t AttachmentBytesDownloaded = Info.AttachmentBlockBytesDownloaded.load() + Info.AttachmentBlockRangeBytesDownloaded.load() + - Info.AttachmentBytesDownloaded.load(); - remotestore_impl::ReportProgress(OptionalContext, - "Loading attachments"sv, - fmt::format("{} ({}) downloaded, {} ({}) stored, {} remaining. 
{}", - AttachmentsDownloaded, - NiceBytes(AttachmentBytesDownloaded), - Info.AttachmentsStored.load(), - NiceBytes(Info.AttachmentBytesStored.load()), - Remaining, - remotestore_impl::GetStats(RemoteStore.GetStats(), PartialTransferWallTimeMS)), - AttachmentCount.load(), - Remaining); + uint64_t AttachmentsDownloaded = + Info.AttachmentBlocksDownloaded.load() + Info.AttachmentBlocksRangesDownloaded.load() + Info.AttachmentsDownloaded.load(); + uint64_t AttachmentBytesDownloaded = Info.AttachmentBlockBytesDownloaded.load() + + Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load(); + + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress( + Context.OptionalJobContext, + "Loading attachments"sv, + fmt::format( + "{} ({}) downloaded, {} ({}) stored, {} remaining. {}", + AttachmentsDownloaded, + NiceBytes(AttachmentBytesDownloaded), + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Remaining, + remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, PartialTransferWallTimeMS)), + AttachmentCountToUseForProgress, + Remaining); + } } if (DownloadStartMS != (uint64_t)-1) { @@ -3831,58 +3864,58 @@ LoadOplog(CidStore& ChunkStore, if (AttachmentCount.load() > 0) { - remotestore_impl::ReportProgress(OptionalContext, - "Loading attachments"sv, - fmt::format("{}", remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS)), - AttachmentCount.load(), - 0); + remotestore_impl::ReportProgress( + Context.OptionalJobContext, + "Loading attachments"sv, + fmt::format("{}", remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS)), + AttachmentCount.load(), + 0); } AttachmentsWriteLatch.CountDown(); - while (!AttachmentsWriteLatch.Wait(1000)) { - ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + 
ptrdiff_t AttachmentCountToUseForProgress = AttachmentsWriteLatch.Remaining(); + while (!AttachmentsWriteLatch.Wait(1000)) { - if (!RemoteResult.IsError()) + ptrdiff_t Remaining = AttachmentsWriteLatch.Remaining(); + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) { - RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); + if (!RemoteResult.IsError()) + { + RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); + } } + AttachmentCountToUseForProgress = Max(Remaining, AttachmentCountToUseForProgress); + remotestore_impl::ReportProgress(Context.OptionalJobContext, + "Writing attachments"sv, + fmt::format("{} ({}), {} remaining.", + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Remaining), + AttachmentCountToUseForProgress, + Remaining); } - remotestore_impl::ReportProgress( - OptionalContext, - "Writing attachments"sv, - fmt::format("{} ({}), {} remaining.", Info.AttachmentsStored.load(), NiceBytes(Info.AttachmentBytesStored.load()), Remaining), - AttachmentCount.load(), - Remaining); } if (AttachmentCount.load() > 0) { - remotestore_impl::ReportProgress(OptionalContext, "Writing attachments", ""sv, AttachmentCount.load(), 0); + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Writing attachments", ""sv, AttachmentCount.load(), 0); } if (Result.ErrorCode == 0) { if (!FilesToDechunk.empty()) { - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size())); + remotestore_impl::ReportMessage(Context.OptionalJobContext, fmt::format("Dechunking {} attachments", FilesToDechunk.size())); Latch DechunkLatch(1); - std::filesystem::path TempFilePath = Oplog.TempPath(); + std::filesystem::path TempFilePath = Context.Oplog.TempPath(); for (const ChunkedInfo& Chunked : FilesToDechunk) { std::filesystem::path TempFileName = TempFilePath / Chunked.RawHash.ToHexString(); DechunkLatch.AddCount(1); - 
WorkerPool.ScheduleWork( - [&ChunkStore, - &DechunkLatch, - TempFileName, - &Chunked, - &RemoteResult, - IgnoreMissingAttachments, - &Info, - OptionalContext]() { + Context.WorkerPool.ScheduleWork( + [&Context, &DechunkLatch, TempFileName, &Chunked, &RemoteResult, &Info]() { ZEN_TRACE_CPU("DechunkAttachment"); auto _ = MakeGuard([&DechunkLatch, &TempFileName] { @@ -3916,16 +3949,16 @@ LoadOplog(CidStore& ChunkStore, for (std::uint32_t SequenceIndex : Chunked.ChunkSequence) { const IoHash& ChunkHash = Chunked.ChunkHashes[SequenceIndex]; - IoBuffer Chunk = ChunkStore.FindChunkByCid(ChunkHash); + IoBuffer Chunk = Context.ChunkStore.FindChunkByCid(ChunkHash); if (!Chunk) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Missing chunk {} for chunked attachment {}", ChunkHash, Chunked.RawHash)); // We only add 1 as the resulting missing count will be 1 for the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow(HttpResponseCode::NotFound), @@ -3943,7 +3976,7 @@ LoadOplog(CidStore& ChunkStore, if (RawHash != ChunkHash) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Mismatching raw hash {} for chunk {} for chunked attachment {}", RawHash, ChunkHash, @@ -3951,7 +3984,7 @@ LoadOplog(CidStore& ChunkStore, // We only add 1 as the resulting missing count will be 1 for the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow(HttpResponseCode::NotFound), @@ -3988,14 +4021,14 @@ LoadOplog(CidStore& ChunkStore, })) { remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Failed to decompress chunk {} for chunked attachment {}", ChunkHash, Chunked.RawHash)); // We only add 1 as the resulting missing count will be 1 for 
the dechunked file Info.MissingAttachmentCount.fetch_add(1); - if (!IgnoreMissingAttachments) + if (!Context.IgnoreMissingAttachments) { RemoteResult.SetError( gsl::narrow(HttpResponseCode::NotFound), @@ -4019,18 +4052,17 @@ LoadOplog(CidStore& ChunkStore, } uint64_t TmpBufferSize = TmpBuffer.GetSize(); CidStore::InsertResult InsertResult = - ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace); + Context.ChunkStore.AddChunk(TmpBuffer, Chunked.RawHash, CidStore::InsertMode::kMayBeMovedInPlace); if (InsertResult.New) { Info.AttachmentBytesStored.fetch_add(TmpBufferSize); Info.AttachmentsStored.fetch_add(1); } - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Dechunked attachment {} ({}) in {}", - Chunked.RawHash, - NiceBytes(Chunked.RawSize), - NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); + ZEN_INFO("Dechunked attachment {} ({}) in {}", + Chunked.RawHash, + NiceBytes(Chunked.RawSize), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); } catch (const std::exception& Ex) { @@ -4046,54 +4078,58 @@ LoadOplog(CidStore& ChunkStore, while (!DechunkLatch.Wait(1000)) { ptrdiff_t Remaining = DechunkLatch.Remaining(); - if (remotestore_impl::IsCancelled(OptionalContext)) + if (remotestore_impl::IsCancelled(Context.OptionalJobContext)) { if (!RemoteResult.IsError()) { RemoteResult.SetError(gsl::narrow(HttpResponseCode::OK), "Operation cancelled", ""); remotestore_impl::ReportMessage( - OptionalContext, + Context.OptionalJobContext, fmt::format("Aborting ({}): {}", RemoteResult.GetError(), RemoteResult.GetErrorReason())); } } - remotestore_impl::ReportProgress(OptionalContext, + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Dechunking attachments"sv, fmt::format("{} remaining...", Remaining), FilesToDechunk.size(), Remaining); } - remotestore_impl::ReportProgress(OptionalContext, "Dechunking attachments"sv, ""sv, FilesToDechunk.size(), 0); + remotestore_impl::ReportProgress(Context.OptionalJobContext, "Dechunking 
attachments"sv, ""sv, FilesToDechunk.size(), 0); } Result = RemoteResult.ConvertResult(); } if (Result.ErrorCode == 0) { - if (CleanOplog) + if (Context.CleanOplog) { - RemoteStore.Flush(); - if (!Oplog.Reset()) + if (Context.OptionalCache) + { + Context.OptionalCache->Flush(100, [](intptr_t) { return /*DontWaitForPendingOperation*/ false; }); + } + if (!Context.Oplog.Reset()) { Result = RemoteProjectStore::Result{.ErrorCode = gsl::narrow(HttpResponseCode::InternalServerError), .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0, - .Reason = fmt::format("Failed to clean existing oplog '{}'", Oplog.OplogId())}; - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason)); + .Reason = fmt::format("Failed to clean existing oplog '{}'", Context.Oplog.OplogId())}; + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("Aborting ({}): {}", Result.ErrorCode, Result.Reason)); } } if (Result.ErrorCode == 0) { - remotestore_impl::WriteOplogSection(Oplog, OplogSection, OptionalContext); + remotestore_impl::WriteOplogSection(Context.Oplog, OplogSection, Context.OptionalJobContext); } } Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.0; - remotestore_impl::LogRemoteStoreStatsDetails(RemoteStore.GetStats()); + remotestore_impl::LogRemoteStoreStatsDetails(Context.RemoteStore.GetStats()); { std::string DownloadDetails; RemoteProjectStore::ExtendedStats ExtendedStats; - if (RemoteStore.GetExtendedStats(ExtendedStats)) + if (Context.RemoteStore.GetExtendedStats(ExtendedStats)) { if (!ExtendedStats.m_ReceivedBytesPerSource.empty()) { @@ -4112,7 +4148,8 @@ LoadOplog(CidStore& ChunkStore, Total += It.second; } - remotestore_impl::ReportMessage(OptionalContext, fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView())); + remotestore_impl::ReportMessage(Context.OptionalJobContext, + fmt::format("Downloaded {} ({})", NiceBytes(Total), SB.ToView())); } } } @@ -4122,25 +4159,26 @@ 
LoadOplog(CidStore& ChunkStore, uint64_t TotalBytesDownloaded = Info.OplogSizeBytes + Info.AttachmentBlockBytesDownloaded.load() + Info.AttachmentBlockRangeBytesDownloaded.load() + Info.AttachmentBytesDownloaded.load(); - remotestore_impl::ReportMessage(OptionalContext, - fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), BlockRanges: {} ({}), Attachments: {} " - "({}), Total: {} ({}), Stored: {} ({}), Missing: {} {}", - RemoteStoreInfo.ContainerName, - Result.ErrorCode == 0 ? "SUCCESS" : "FAILURE", - NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), - NiceBytes(Info.OplogSizeBytes), - Info.AttachmentBlocksDownloaded.load(), - NiceBytes(Info.AttachmentBlockBytesDownloaded.load()), - Info.AttachmentBlocksRangesDownloaded.load(), - NiceBytes(Info.AttachmentBlockRangeBytesDownloaded.load()), - Info.AttachmentsDownloaded.load(), - NiceBytes(Info.AttachmentBytesDownloaded.load()), - TotalDownloads, - NiceBytes(TotalBytesDownloaded), - Info.AttachmentsStored.load(), - NiceBytes(Info.AttachmentBytesStored.load()), - Info.MissingAttachmentCount.load(), - remotestore_impl::GetStats(RemoteStore.GetStats(), TransferWallTimeMS))); + remotestore_impl::ReportMessage( + Context.OptionalJobContext, + fmt::format("Loaded oplog '{}' {} in {} ({}), Blocks: {} ({}), BlockRanges: {} ({}), Attachments: {} " + "({}), Total: {} ({}), Stored: {} ({}), Missing: {} {}", + RemoteStoreInfo.ContainerName, + Result.ErrorCode == 0 ? 
"SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast(Result.ElapsedSeconds * 1000.0)), + NiceBytes(Info.OplogSizeBytes), + Info.AttachmentBlocksDownloaded.load(), + NiceBytes(Info.AttachmentBlockBytesDownloaded.load()), + Info.AttachmentBlocksRangesDownloaded.load(), + NiceBytes(Info.AttachmentBlockRangeBytesDownloaded.load()), + Info.AttachmentsDownloaded.load(), + NiceBytes(Info.AttachmentBytesDownloaded.load()), + TotalDownloads, + NiceBytes(TotalBytesDownloaded), + Info.AttachmentsStored.load(), + NiceBytes(Info.AttachmentBytesStored.load()), + Info.MissingAttachmentCount.load(), + remotestore_impl::GetStats(Context.RemoteStore.GetStats(), Context.OptionalCacheStats, TransferWallTimeMS))); return Result; } @@ -4237,6 +4275,26 @@ namespace projectstore_testutils { return Result; } + class TestJobContext : public JobContext + { + public: + explicit TestJobContext(int& OpIndex) : m_OpIndex(OpIndex) {} + virtual bool IsCancelled() const { return false; } + virtual void ReportMessage(std::string_view Message) { ZEN_INFO("Job {}: {}", m_OpIndex, Message); } + virtual void ReportProgress(std::string_view CurrentOp, std::string_view Details, ptrdiff_t TotalCount, ptrdiff_t RemainingCount) + { + ZEN_INFO("Job {}: Op '{}'{} {}/{}", + m_OpIndex, + CurrentOp, + Details.empty() ? 
"" : fmt::format(" {}", Details), + TotalCount - RemainingCount, + TotalCount); + } + + private: + int& m_OpIndex; + }; + } // namespace projectstore_testutils TEST_SUITE_BEGIN("remotestore.projectstore"); @@ -4334,66 +4392,708 @@ TEST_CASE_TEMPLATE("project.store.export", false, nullptr); - CHECK(ExportResult.ErrorCode == 0); + REQUIRE(ExportResult.ErrorCode == 0); Ref OplogImport = Project->NewOplog("oplog2", {}); CHECK(OplogImport); - RemoteProjectStore::Result ImportResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ false, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ false, - EPartialBlockRequestMode::Mixed, - /*HostLatencySec*/ -1.0, - /*CacheLatencySec*/ -1.0, - nullptr); + int OpJobIndex = 0; + TestJobContext OpJobContext(OpJobIndex); + + RemoteProjectStore::Result ImportResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportForceResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ true, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ false, - EPartialBlockRequestMode::Mixed, - /*HostLatencySec*/ -1.0, - /*CacheLatencySec*/ -1.0, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportForceResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = true, + .IgnoreMissingAttachments = false, + .CleanOplog = 
false, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportForceResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportCleanResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ false, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ true, - EPartialBlockRequestMode::Mixed, - /*HostLatencySec*/ -1.0, - /*CacheLatencySec*/ -1.0, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportCleanResult = LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = true, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportCleanResult.ErrorCode == 0); - - RemoteProjectStore::Result ImportForceCleanResult = LoadOplog(CidStore, - *RemoteStore, - *OplogImport, - NetworkPool, - WorkerPool, - /*Force*/ true, - /*IgnoreMissingAttachments*/ false, - /*CleanOplog*/ true, - EPartialBlockRequestMode::Mixed, - /*HostLatencySec*/ -1.0, - /*CacheLatencySec*/ -1.0, - nullptr); + OpJobIndex++; + + RemoteProjectStore::Result ImportForceCleanResult = + LoadOplog(LoadOplogContext{.ChunkStore = CidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = Oid::Zero, + .Oplog = *OplogImport, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = true, + .IgnoreMissingAttachments = false, + .CleanOplog = true, + .PartialBlockRequestMode = EPartialBlockRequestMode::Mixed, + .OptionalJobContext = &OpJobContext}); CHECK(ImportForceCleanResult.ErrorCode == 0); + OpJobIndex++; +} + +// Common oplog setup used by the two tests below. 
+// Returns a FileRemoteStore backed by ExportDir that has been populated with a SaveOplog call. +// Keeps the test data identical to project.store.export so the two test suites exercise the same blocks/attachments. +static RemoteProjectStore::Result +SetupExportStore(CidStore& CidStore, + ProjectStore::Project& Project, + WorkerThreadPool& NetworkPool, + WorkerThreadPool& WorkerPool, + const std::filesystem::path& ExportDir, + std::shared_ptr& OutRemoteStore) +{ + using namespace projectstore_testutils; + using namespace std::literals; + + Ref Oplog = Project.NewOplog("oplog_export", {}); + if (!Oplog) + { + return RemoteProjectStore::Result{.ErrorCode = -1}; + } + + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list{77}))); + Oplog->AppendNewOplogEntry( + CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list{7123, 583, 690, 99}))); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list{55, 122}))); + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage( + Oid::NewOid(), + CreateAttachments(std::initializer_list{256u * 1024u, 92u * 1024u}, OodleCompressionLevel::None))); + + FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 64u * 1024, + .MaxChunksPerBlock = 1000, + .MaxChunkEmbedSize = 32 * 1024u, + .ChunkFileSizeLimit = 64u * 1024u}, + /*.FolderPath =*/ExportDir, + /*.Name =*/std::string("oplog_export"), + /*.OptionalBaseName =*/std::string(), + /*.ForceDisableBlocks =*/false, + /*.ForceEnableTempBlocks =*/false}; + + OutRemoteStore = CreateFileRemoteStore(Log(), Options); + return SaveOplog(CidStore, + *OutRemoteStore, + Project, + *Oplog, + NetworkPool, + WorkerPool, + Options.MaxBlockSize, + Options.MaxChunksPerBlock, + Options.MaxChunkEmbedSize, + Options.ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ 
false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); +} + +// Creates an export store with a single oplog entry that packs six 512 KB chunks into one +// ~3 MB block (MaxBlockSize = 8 MB). The resulting block slack (~1.5 MB) far exceeds the +// 512 KB threshold that ChunkBlockAnalyser requires before it will consider partial-block +// downloads instead of full-block downloads. +// +// This function is self-contained: it creates its own GcManager, CidStore, ProjectStore and +// Project internally so that each call is independent of any outer test context. After +// SaveOplog returns, all persistent data lives on disk inside ExportDir and the caller can +// freely query OutRemoteStore without holding any references to the internal context. +static RemoteProjectStore::Result +SetupPartialBlockExportStore(WorkerThreadPool& NetworkPool, + WorkerThreadPool& WorkerPool, + const std::filesystem::path& ExportDir, + std::shared_ptr& OutRemoteStore) +{ + using namespace projectstore_testutils; + using namespace std::literals; + + // Self-contained CAS and project store. Subdirectories of ExportDir keep everything + // together without relying on the outer TEST_CASE's ExportCidStore / ExportProject. 
+ GcManager LocalGc; + CidStore LocalCidStore(LocalGc); + CidStoreConfiguration LocalCidConfig = {.RootDirectory = ExportDir / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + LocalCidStore.Initialize(LocalCidConfig); + + std::filesystem::path LocalProjectBasePath = ExportDir / "proj"; + ProjectStore LocalProjectStore(LocalCidStore, LocalProjectBasePath, LocalGc, ProjectStore::Configuration{}); + Ref LocalProject(LocalProjectStore.NewProject(LocalProjectBasePath / "p"sv, + "p"sv, + (ExportDir / "root").string(), + (ExportDir / "engine").string(), + (ExportDir / "game").string(), + (ExportDir / "game" / "game.uproject").string())); + + Ref Oplog = LocalProject->NewOplog("oplog_partial_block", {}); + if (!Oplog) + { + return RemoteProjectStore::Result{.ErrorCode = -1}; + } + + // Six 512 KB chunks with OodleCompressionLevel::None so the compressed size stays large + // and the block genuinely exceeds the 512 KB slack threshold. + Oplog->AppendNewOplogEntry(CreateBulkDataOplogPackage( + Oid::NewOid(), + CreateAttachments(std::initializer_list{512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u, 512u * 1024u}, + OodleCompressionLevel::None))); + + // MaxChunkEmbedSize must be larger than the compressed size of each 512 KB chunk + // (OodleCompressionLevel::None → compressed ≈ raw ≈ 512 KB). With the legacy + // 32 KB limit all six chunks would become loose large attachments and no block would + // be created, so we use the production default of 1.5 MB instead. 
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = 8u * 1024u * 1024u, + .MaxChunksPerBlock = 1000, + .MaxChunkEmbedSize = RemoteStoreOptions::DefaultMaxChunkEmbedSize, + .ChunkFileSizeLimit = 64u * 1024u * 1024u}, + /*.FolderPath =*/ExportDir, + /*.Name =*/std::string("oplog_partial_block"), + /*.OptionalBaseName =*/std::string(), + /*.ForceDisableBlocks =*/false, + /*.ForceEnableTempBlocks =*/false}; + OutRemoteStore = CreateFileRemoteStore(Log(), Options); + return SaveOplog(LocalCidStore, + *OutRemoteStore, + *LocalProject, + *Oplog, + NetworkPool, + WorkerPool, + Options.MaxBlockSize, + Options.MaxChunksPerBlock, + Options.MaxChunkEmbedSize, + Options.ChunkFileSizeLimit, + /*EmbedLooseFiles*/ true, + /*ForceUpload*/ false, + /*IgnoreMissingAttachments*/ false, + /*OptionalContext*/ nullptr); +} + +// Returns the first block hash that has at least MinChunkCount chunks, or a zero IoHash +// if no qualifying block exists in Store, the oplog lists no blocks, or the container / block descriptions cannot be loaded. +static IoHash +FindBlockWithMultipleChunks(RemoteProjectStore& Store, size_t MinChunkCount) +{ + RemoteProjectStore::LoadContainerResult ContainerResult = Store.LoadContainer(); + if (ContainerResult.ErrorCode != 0) + { + return {}; + } + std::vector BlockHashes = GetBlockHashesFromOplog(ContainerResult.ContainerObject); + if (BlockHashes.empty()) + { + return {}; + } + RemoteProjectStore::GetBlockDescriptionsResult Descriptions = Store.GetBlockDescriptions(BlockHashes, nullptr, Oid{}); + if (Descriptions.ErrorCode != 0) + { + return {}; + } + for (const ChunkBlockDescription& Desc : Descriptions.Blocks) + { + if (Desc.ChunkRawHashes.size() >= MinChunkCount) + { + return Desc.BlockHash; + } + } + return {}; +} + +// Loads BlockHash from Source and inserts every even-indexed chunk (0, 2, 4, …) into +// TargetCidStore. 
Odd-indexed chunks are left absent so that when an import is run +// against the same block, HasAttachment returns false for three non-adjacent positions +// — the minimum needed to exercise the multi-range partial-block download paths. +static void +SeedCidStoreWithAlternateChunks(CidStore& TargetCidStore, RemoteProjectStore& Source, const IoHash& BlockHash) +{ + RemoteProjectStore::LoadAttachmentResult BlockResult = Source.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode != 0 || !BlockResult.Bytes) + { + return; + } + + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(BlockResult.Bytes), RawHash, RawSize); + if (!Compressed) + { + return; + } + CompositeBuffer BlockPayload = Compressed.DecompressToComposite(); + if (!BlockPayload) + { + return; + } + + uint32_t ChunkIndex = 0; + uint64_t HeaderSize = 0; + IterateChunkBlock( + BlockPayload.Flatten(), + [&TargetCidStore, &ChunkIndex](CompressedBuffer&& Chunk, const IoHash& AttachmentHash) { + if (ChunkIndex % 2 == 0) + { + IoBuffer ChunkData = Chunk.GetCompressed().Flatten().AsIoBuffer(); + TargetCidStore.AddChunk(ChunkData, AttachmentHash); + } + ++ChunkIndex; + }, + HeaderSize); +} + +TEST_CASE("project.store.import.context_settings") +{ + using namespace std::literals; + using namespace projectstore_testutils; + + ScopedTemporaryDirectory TempDir; + ScopedTemporaryDirectory ExportDir; + + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + // Export-side CAS and project store: used only by SetupExportStore to build the remote store + // payload. Kept separate from the import side so the two CAS instances are disjoint. 
+ GcManager ExportGc; + CidStore ExportCidStore(ExportGc); + CidStoreConfiguration ExportCidConfig = {.RootDirectory = TempDir.Path() / "export_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ExportCidStore.Initialize(ExportCidConfig); + + std::filesystem::path ExportBasePath = TempDir.Path() / "export_projectstore"; + ProjectStore ExportProjectStore(ExportCidStore, ExportBasePath, ExportGc, ProjectStore::Configuration{}); + Ref ExportProject(ExportProjectStore.NewProject(ExportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + uint32_t NetworkWorkerCount = Max(GetHardwareConcurrency() / 4u, 2u); + uint32_t WorkerCount = (NetworkWorkerCount < GetHardwareConcurrency()) ? Max(GetHardwareConcurrency() - NetworkWorkerCount, 4u) : 4u; + WorkerThreadPool WorkerPool(WorkerCount); + WorkerThreadPool NetworkPool(NetworkWorkerCount); + + std::shared_ptr RemoteStore; + RemoteProjectStore::Result ExportResult = + SetupExportStore(ExportCidStore, *ExportProject, NetworkPool, WorkerPool, ExportDir.Path(), RemoteStore); + REQUIRE(ExportResult.ErrorCode == 0); + + // Import-side CAS and project store: starts empty, mirroring a fresh machine that has never + // downloaded the data. HasAttachment() therefore returns false for every chunk, so the import + // genuinely contacts the remote store without needing ForceDownload on the populate pass. 
+ GcManager ImportGc; + CidStore ImportCidStore(ImportGc); + CidStoreConfiguration ImportCidConfig = {.RootDirectory = TempDir.Path() / "import_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + ImportCidStore.Initialize(ImportCidConfig); + + std::filesystem::path ImportBasePath = TempDir.Path() / "import_projectstore"; + ProjectStore ImportProjectStore(ImportCidStore, ImportBasePath, ImportGc, ProjectStore::Configuration{}); + Ref ImportProject(ImportProjectStore.NewProject(ImportBasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + + const Oid CacheBuildId = Oid::NewOid(); + BuildStorageCache::Statistics CacheStats; + std::unique_ptr Cache = CreateInMemoryBuildStorageCache(256u, CacheStats); + auto ResetCacheStats = [&]() { + CacheStats.TotalBytesRead = 0; + CacheStats.TotalBytesWritten = 0; + CacheStats.TotalRequestCount = 0; + CacheStats.TotalRequestTimeUs = 0; + CacheStats.TotalExecutionTimeUs = 0; + CacheStats.PeakSentBytes = 0; + CacheStats.PeakReceivedBytes = 0; + CacheStats.PeakBytesPerSec = 0; + CacheStats.PutBlobCount = 0; + CacheStats.PutBlobByteCount = 0; + }; + + int OpJobIndex = 0; + + TestJobContext OpJobContext(OpJobIndex); + + // Helper: run a LoadOplog against the import-side CAS/project with the given context knobs. + // Each call creates a fresh oplog so repeated calls within one SUBCASE don't short-circuit on + // already-present data. 
+ auto DoImport = [&](BuildStorageCache* OptCache, + EPartialBlockRequestMode Mode, + double StoreLatency, + uint64_t StoreRanges, + double CacheLatency, + uint64_t CacheRanges, + bool PopulateCache, + bool ForceDownload) -> RemoteProjectStore::Result { + Ref ImportOplog = ImportProject->NewOplog(fmt::format("import_{}", OpJobIndex++), {}); + return LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *RemoteStore, + .OptionalCache = OptCache, + .CacheBuildId = CacheBuildId, + .Oplog = *ImportOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = ForceDownload, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = Mode, + .PopulateCache = PopulateCache, + .StoreLatencySec = StoreLatency, + .StoreMaxRangeCountPerRequest = StoreRanges, + .CacheLatencySec = CacheLatency, + .CacheMaxRangeCountPerRequest = CacheRanges, + .OptionalJobContext = &OpJobContext}); + }; + + // Shorthand: Mode=All, low latency, 128 ranges for both store and cache. + auto ImportAll = [&](BuildStorageCache* OptCache, bool Populate, bool Force) { + return DoImport(OptCache, EPartialBlockRequestMode::All, 0.001, 128u, 0.001, 128u, Populate, Force); + }; + + SUBCASE("mode_off_no_cache") + { + // Baseline: no partial block requests, no cache. + RemoteProjectStore::Result R = + DoImport(nullptr, EPartialBlockRequestMode::Off, -1.0, (uint64_t)-1, -1.0, (uint64_t)-1, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_all_multirange_cloud_no_cache") + { + // StoreMaxRangeCountPerRequest > 1 → MultiRange cloud path. + RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 128u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_all_singlerange_cloud_no_cache") + { + // StoreMaxRangeCountPerRequest == 1 → SingleRange cloud path. 
+ RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::All, 0.001, 1u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("mode_mixed_high_latency_no_cache") + { + // High store latency encourages range merging; Mixed uses SingleRange for cloud, Off for cache. + RemoteProjectStore::Result R = DoImport(nullptr, EPartialBlockRequestMode::Mixed, 0.1, 128u, -1.0, 0u, false, false); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("cache_populate_and_hit") + { + // First import: ImportCidStore is empty so all blocks are downloaded from the remote store + // and written to the cache. + RemoteProjectStore::Result PopulateResult = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(PopulateResult.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + + // Re-import with ForceDownload=true: all chunks are now in ImportCidStore but Force overrides + // HasAttachment() so the download logic re-runs and serves blocks from the cache instead of + // the remote store. + ResetCacheStats(); + RemoteProjectStore::Result HitResult = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true); + CHECK(HitResult.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount == 0); + // TotalRequestCount covers both full-blob cache hits and partial-range cache hits. + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("cache_no_populate_flag") + { + // Cache is provided but PopulateCache=false: blocks are downloaded to ImportCidStore but + // nothing should be written to the cache. + RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/false); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount == 0); + } + + SUBCASE("mode_zencacheonly_cache_multirange") + { + // Pre-populate the cache via a plain import, then re-import with ZenCacheOnly + + // CacheMaxRangeCountPerRequest=128. 
With 100% of chunks needed, all blocks go to + // FullBlockIndexes and GetBuildBlob (full blob) is called from the cache. + // CacheMaxRangeCountPerRequest > 1 would route partial downloads through GetBuildBlobRanges + // if the analyser ever emits BlockRanges entries. + RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 128u, false, true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("mode_zencacheonly_cache_singlerange") + { + // Pre-populate the cache, then re-import with ZenCacheOnly + CacheMaxRangeCountPerRequest=1. + // With 100% of chunks needed the analyser sends all blocks to FullBlockIndexes (full-block + // download path), which calls GetBuildBlob with no range offset — a full-blob cache hit. + // The single-range vs multi-range distinction only matters for the partial-block (BlockRanges) + // path, which is not reached when all chunks are needed. + RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = DoImport(Cache.get(), EPartialBlockRequestMode::ZenCacheOnly, 0.1, 128u, 0.001, 1u, false, true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("mode_all_cache_and_cloud_multirange") + { + // Pre-populate cache; All mode uses multi-range for both the cache and cloud paths. 
+ RemoteProjectStore::Result Populate = ImportAll(Cache.get(), /*PopulateCache=*/true, /*Force=*/false); + CHECK(Populate.ErrorCode == 0); + ResetCacheStats(); + + RemoteProjectStore::Result R = ImportAll(Cache.get(), /*PopulateCache=*/false, /*Force=*/true); + CHECK(R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("partial_block_cloud_multirange") + { + // Export store with 6 × 512 KB chunks packed into one ~3 MB block. + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + // Seeding even-indexed chunks (0, 2, 4) leaves odd ones (1, 3, 5) absent in + // ImportCidStore. Three non-adjacent needed positions → three BlockRangeDescriptors. + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash); + + // StoreMaxRangeCountPerRequest=128 → all three ranges sent in one LoadAttachmentRanges call. 
+ Ref PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_multi_{}", OpJobIndex++), {}); + RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = CacheBuildId, + .Oplog = *PartialOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = -1.0, + .CacheMaxRangeCountPerRequest = 0u, + .OptionalJobContext = &OpJobContext}); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("partial_block_cloud_singlerange") + { + // Same block layout as partial_block_cloud_multirange but StoreMaxRangeCountPerRequest=1. + // DownloadPartialBlock issues one LoadAttachmentRanges call per range. + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + SeedCidStoreWithAlternateChunks(ImportCidStore, *PartialRemoteStore, BlockHash); + + Ref PartialOplog = ImportProject->NewOplog(fmt::format("partial_cloud_single_{}", OpJobIndex++), {}); + RemoteProjectStore::Result R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = nullptr, + .CacheBuildId = CacheBuildId, + .Oplog = *PartialOplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = false, + 
.StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 1u, + .CacheLatencySec = -1.0, + .CacheMaxRangeCountPerRequest = 0u, + .OptionalJobContext = &OpJobContext}); + CHECK(R.ErrorCode == 0); + } + + SUBCASE("partial_block_cache_multirange") + { + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + + // Phase 1: ImportCidStore starts empty → full block download from remote → PutBuildBlob + // populates the cache. + { + Ref Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p1_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase1Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = true, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase1R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + } + ResetCacheStats(); + + // Phase 2: fresh CidStore with only even-indexed chunks seeded. + // HasAttachment returns false for odd chunks (1, 3, 5) → three BlockRangeDescriptors. + // Block is in cache from Phase 1 → cache partial path. + // CacheMaxRangeCountPerRequest=128 → SubRangeCount=3 > 1 → GetBuildBlobRanges. 
+ GcManager Phase2Gc; + CidStore Phase2CidStore(Phase2Gc); + CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + Phase2CidStore.Initialize(Phase2CidConfig); + SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); + + Ref Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_multi_p2_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase2Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase2R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } + + SUBCASE("partial_block_cache_singlerange") + { + ScopedTemporaryDirectory PartialExportDir; + std::shared_ptr PartialRemoteStore; + RemoteProjectStore::Result ExportR = + SetupPartialBlockExportStore(NetworkPool, WorkerPool, PartialExportDir.Path(), PartialRemoteStore); + REQUIRE(ExportR.ErrorCode == 0); + + IoHash BlockHash = FindBlockWithMultipleChunks(*PartialRemoteStore, 4u); + CHECK(BlockHash != IoHash::Zero); + + // Phase 1: full block download from remote into cache. 
+ { + Ref Phase1Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p1_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase1R = LoadOplog(LoadOplogContext{.ChunkStore = ImportCidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase1Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::All, + .PopulateCache = true, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 128u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase1R.ErrorCode == 0); + CHECK(CacheStats.PutBlobCount > 0); + } + ResetCacheStats(); + + // Phase 2: fresh CidStore with only even-indexed chunks seeded. + // CacheMaxRangeCountPerRequest=1 → SubRangeCount=Min(3,1)=1 → GetBuildBlob with range + // offset (single-range legacy cache path), called once per needed chunk range. 
+ GcManager Phase2Gc; + CidStore Phase2CidStore(Phase2Gc); + CidStoreConfiguration Phase2CidConfig = {.RootDirectory = TempDir.Path() / "partial_cas_single", + .TinyValueThreshold = 1024, + .HugeValueThreshold = 4096}; + Phase2CidStore.Initialize(Phase2CidConfig); + SeedCidStoreWithAlternateChunks(Phase2CidStore, *PartialRemoteStore, BlockHash); + + Ref Phase2Oplog = ImportProject->NewOplog(fmt::format("partial_cache_single_p2_{}", OpJobIndex++), {}); + RemoteProjectStore::Result Phase2R = LoadOplog(LoadOplogContext{.ChunkStore = Phase2CidStore, + .RemoteStore = *PartialRemoteStore, + .OptionalCache = Cache.get(), + .CacheBuildId = CacheBuildId, + .Oplog = *Phase2Oplog, + .NetworkWorkerPool = NetworkPool, + .WorkerPool = WorkerPool, + .ForceDownload = false, + .IgnoreMissingAttachments = false, + .CleanOplog = false, + .PartialBlockRequestMode = EPartialBlockRequestMode::ZenCacheOnly, + .PopulateCache = false, + .StoreLatencySec = 0.001, + .StoreMaxRangeCountPerRequest = 128u, + .CacheLatencySec = 0.001, + .CacheMaxRangeCountPerRequest = 1u, + .OptionalJobContext = &OpJobContext}); + CHECK(Phase2R.ErrorCode == 0); + CHECK(CacheStats.TotalRequestCount > 0); + } } TEST_SUITE_END(); diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp index ef82c45e0..115d6438d 100644 --- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp @@ -157,59 +157,56 @@ public: return Result; } - virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes, ESourceMode SourceMode) override + virtual LoadAttachmentsResult LoadAttachments(const std::vector& RawHashes) override { LoadAttachmentsResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); + std::string LoadRequest = fmt::format("/{}/oplog/{}/rpc"sv, m_Project, m_Oplog); - CbObject 
Request; + CbObject Request; + { + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "getchunks"sv); + RequestWriter.BeginObject("Request"sv); { - CbObjectWriter RequestWriter; - RequestWriter.AddString("method"sv, "getchunks"sv); - RequestWriter.BeginObject("Request"sv); + RequestWriter.BeginArray("Chunks"sv); { - RequestWriter.BeginArray("Chunks"sv); + for (const IoHash& RawHash : RawHashes) { - for (const IoHash& RawHash : RawHashes) + RequestWriter.BeginObject(); { - RequestWriter.BeginObject(); - { - RequestWriter.AddHash("RawHash", RawHash); - } - RequestWriter.EndObject(); + RequestWriter.AddHash("RawHash", RawHash); } + RequestWriter.EndObject(); } - RequestWriter.EndArray(); // "chunks" } - RequestWriter.EndObject(); - Request = RequestWriter.Save(); + RequestWriter.EndArray(); // "chunks" } + RequestWriter.EndObject(); + Request = RequestWriter.Save(); + } - HttpClient::Response Response = m_Client.Post(LoadRequest, Request, HttpClient::Accept(ZenContentType::kCbPackage)); - AddStats(Response); + HttpClient::Response Response = m_Client.Post(LoadRequest, Request, HttpClient::Accept(ZenContentType::kCbPackage)); + AddStats(Response); - Result = LoadAttachmentsResult{ConvertResult(Response)}; - if (Result.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. Reason: '{}'", - RawHashes.size(), - m_ProjectStoreUrl, - m_Project, - m_Oplog, - Result.Reason); - } - else + Result = LoadAttachmentsResult{ConvertResult(Response)}; + if (Result.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching {} oplog attachments from {}/{}/{}. 
Reason: '{}'", + RawHashes.size(), + m_ProjectStoreUrl, + m_Project, + m_Oplog, + Result.Reason); + } + else + { + CbPackage Package = Response.AsPackage(); + std::span Attachments = Package.GetAttachments(); + Result.Chunks.reserve(Attachments.size()); + for (const CbAttachment& Attachment : Attachments) { - CbPackage Package = Response.AsPackage(); - std::span Attachments = Package.GetAttachments(); - Result.Chunks.reserve(Attachments.size()); - for (const CbAttachment& Attachment : Attachments) - { - Result.Chunks.emplace_back( - std::pair{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); - } + Result.Chunks.emplace_back( + std::pair{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); } } return Result; @@ -253,75 +250,64 @@ public: return GetKnownBlocksResult{{.ErrorCode = static_cast(HttpResponseCode::NoContent)}}; } - virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes) override + virtual GetBlockDescriptionsResult GetBlockDescriptions(std::span BlockHashes, + BuildStorageCache* OptionalCache, + const Oid& CacheBuildId) override { - ZEN_UNUSED(BlockHashes); + ZEN_UNUSED(BlockHashes, OptionalCache, CacheBuildId); return GetBlockDescriptionsResult{Result{.ErrorCode = int(HttpResponseCode::NotFound)}}; } - virtual AttachmentExistsInCacheResult AttachmentExistsInCache(std::span RawHashes) override - { - return AttachmentExistsInCacheResult{Result{.ErrorCode = 0}, std::vector(RawHashes.size(), false)}; - } - - virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash, ESourceMode SourceMode) override + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override { LoadAttachmentResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); - HttpClient::Response Response = - m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); - 
AddStats(Response); + std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); + HttpClient::Response Response = + m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); + AddStats(Response); - Result = LoadAttachmentResult{ConvertResult(Response)}; - if (Result.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", - m_ProjectStoreUrl, - m_Project, - m_Oplog, - RawHash, - Result.Reason); - } - Result.Bytes = Response.ResponsePayload; - Result.Bytes.MakeOwned(); + Result = LoadAttachmentResult{ConvertResult(Response)}; + if (Result.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", + m_ProjectStoreUrl, + m_Project, + m_Oplog, + RawHash, + Result.Reason); } + Result.Bytes = Response.ResponsePayload; + Result.Bytes.MakeOwned(); return Result; } virtual LoadAttachmentRangesResult LoadAttachmentRanges(const IoHash& RawHash, - std::span> Ranges, - ESourceMode SourceMode) override + std::span> Ranges) override { + ZEN_ASSERT(!Ranges.empty()); LoadAttachmentRangesResult Result; - if (SourceMode != ESourceMode::kCacheOnly) - { - std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); - HttpClient::Response Response = - m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); - AddStats(Response); + std::string LoadRequest = fmt::format("/{}/oplog/{}/{}"sv, m_Project, m_Oplog, RawHash); + HttpClient::Response Response = + m_Client.Download(LoadRequest, m_TempFilePath, HttpClient::Accept(ZenContentType::kCompressedBinary)); + AddStats(Response); - Result = LoadAttachmentRangesResult{ConvertResult(Response)}; - if (Result.ErrorCode) - { - Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. 
Reason: '{}'", - m_ProjectStoreUrl, - m_Project, - m_Oplog, - RawHash, - Result.Reason); - } - else - { - Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); - } + Result = LoadAttachmentRangesResult{ConvertResult(Response)}; + if (Result.ErrorCode) + { + Result.Reason = fmt::format("Failed fetching oplog attachment from {}/{}/{}/{}. Reason: '{}'", + m_ProjectStoreUrl, + m_Project, + m_Oplog, + RawHash, + Result.Reason); + } + else + { + Result.Ranges = std::vector>(Ranges.begin(), Ranges.end()); } return Result; } - virtual void Flush() override {} - private: void AddStats(const HttpClient::Response& Result) { diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index 459e044eb..38d97765b 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -177,6 +177,14 @@ HttpBuildStoreService::GetBlobRequest(HttpRouterRequest& Req) uint64_t RangeLength = RangeView["length"sv].AsUInt64(); OffsetAndLengthPairs.push_back(std::make_pair(RangeOffset, RangeLength)); } + if (OffsetAndLengthPairs.size() > MaxRangeCountPerRequestSupported) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Number of ranges ({}) for blob request exceeds maximum range count {}", + OffsetAndLengthPairs.size(), + MaxRangeCountPerRequestSupported)); + } } if (OffsetAndLengthPairs.empty()) { @@ -661,6 +669,11 @@ HttpBuildStoreService::HandleStatusRequest(HttpServerRequest& Request) ZEN_TRACE_CPU("HttpBuildStoreService::Status"); CbObjectWriter Cbo; Cbo << "ok" << true; + Cbo.BeginObject("capabilities"); + { + Cbo << "maxrangecountperrequest" << MaxRangeCountPerRequestSupported; + } + Cbo.EndObject(); // capabilities Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); } diff --git a/src/zenserver/storage/buildstore/httpbuildstore.h b/src/zenserver/storage/buildstore/httpbuildstore.h index 
e10986411..5fa7cd642 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.h +++ b/src/zenserver/storage/buildstore/httpbuildstore.h @@ -45,6 +45,8 @@ private: inline LoggerRef Log() { return m_Log; } + static constexpr uint32_t MaxRangeCountPerRequestSupported = 256u; + LoggerRef m_Log; void PutBlobRequest(HttpRouterRequest& Req); diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index 2b5474d00..0ec6faea3 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -13,7 +13,12 @@ #include #include #include +#include #include +#include +#include +#include +#include #include #include #include @@ -244,8 +249,22 @@ namespace { { std::shared_ptr Store; std::string Description; - double HostLatencySec = -1.0; - double CacheLatencySec = -1.0; + double LatencySec = -1.0; + uint64_t MaxRangeCountPerRequest = 1; + + struct Cache + { + std::unique_ptr Http; + std::unique_ptr Cache; + Oid BuildsId = Oid::Zero; + std::string Description; + double LatencySec = -1.0; + uint64_t MaxRangeCountPerRequest = 1; + BuildStorageCache::Statistics Stats; + bool Populate = false; + }; + + std::unique_ptr OptionalCache; }; CreateRemoteStoreResult CreateRemoteStore(LoggerRef InLog, @@ -262,9 +281,7 @@ namespace { using namespace std::literals; - std::shared_ptr RemoteStore; - double HostLatencySec = -1.0; - double CacheLatencySec = -1.0; + CreateRemoteStoreResult Result; if (CbObjectView File = Params["file"sv].AsObjectView(); File) { @@ -282,6 +299,9 @@ namespace { bool ForceDisableBlocks = File["disableblocks"sv].AsBool(false); bool ForceEnableTempBlocks = File["enabletempblocks"sv].AsBool(false); + Result.LatencySec = 0; + Result.MaxRangeCountPerRequest = 1; + FileRemoteStoreOptions Options = { RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, FolderPath, @@ -289,7 
+309,7 @@ namespace { std::string(OptionalBaseName), ForceDisableBlocks, ForceEnableTempBlocks}; - RemoteStore = CreateFileRemoteStore(Log(), Options); + Result.Store = CreateFileRemoteStore(Log(), Options); } if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud) @@ -367,21 +387,32 @@ namespace { bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false); bool AssumeHttp2 = Cloud["assumehttp2"sv].AsBool(false); - JupiterRemoteStoreOptions Options = { - RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, - Url, - std::string(Namespace), - std::string(Bucket), - Key, - BaseKey, - std::string(OpenIdProvider), - AccessToken, - AuthManager, - OidcExePath, - ForceDisableBlocks, - ForceDisableTempBlocks, - AssumeHttp2}; - RemoteStore = CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true); + if (JupiterEndpointTestResult TestResult = TestJupiterEndpoint(Url, AssumeHttp2, /*Verbose*/ false); TestResult.Success) + { + Result.LatencySec = TestResult.LatencySeconds; + Result.MaxRangeCountPerRequest = TestResult.MaxRangeCountPerRequest; + + JupiterRemoteStoreOptions Options = { + RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + Url, + std::string(Namespace), + std::string(Bucket), + Key, + BaseKey, + std::string(OpenIdProvider), + AccessToken, + AuthManager, + OidcExePath, + ForceDisableBlocks, + ForceDisableTempBlocks, + AssumeHttp2}; + Result.Store = + CreateJupiterRemoteStore(Log(), Options, TempFilePath, /*Quiet*/ false, /*Unattended*/ false, /*Hidden*/ true); + } + else + { + return {nullptr, fmt::format("Unable to connect to jupiter host '{}'", Url)}; + } } if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen) @@ -397,12 +428,13 @@ namespace { { return {nullptr, "Missing oplog"}; } + ZenRemoteStoreOptions Options = { RemoteStoreOptions{.MaxBlockSize = 
MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, std::string(Url), std::string(Project), std::string(Oplog)}; - RemoteStore = CreateZenRemoteStore(Log(), Options, TempFilePath); + Result.Store = CreateZenRemoteStore(Log(), Options, TempFilePath); } if (CbObjectView Builds = Params["builds"sv].AsObjectView(); Builds) @@ -475,11 +507,76 @@ namespace { MemoryView MetaDataSection = Builds["metadata"sv].AsBinaryView(); IoBuffer MetaData(IoBuffer::Wrap, MetaDataSection.GetData(), MetaDataSection.GetSize()); + auto EnsureHttps = [](const std::string& Host, std::string_view PreferredProtocol) { + if (!Host.empty() && Host.find("://"sv) == std::string::npos) + { + // Assume https URL + return fmt::format("{}://{}"sv, PreferredProtocol, Host); + } + return Host; + }; + + Host = EnsureHttps(Host, "https"); + OverrideHost = EnsureHttps(OverrideHost, "https"); + ZenHost = EnsureHttps(ZenHost, "http"); + + std::function TokenProvider; + if (!OpenIdProvider.empty()) + { + TokenProvider = httpclientauth::CreateFromOpenIdProvider(AuthManager, OpenIdProvider); + } + else if (!AccessToken.empty()) + { + TokenProvider = httpclientauth::CreateFromStaticToken(AccessToken); + } + else if (!OidcExePath.empty()) + { + if (auto TokenProviderMaybe = httpclientauth::CreateFromOidcTokenExecutable(OidcExePath, + Host.empty() ? 
OverrideHost : Host, + /*Quiet*/ false, + /*Unattended*/ false, + /*Hidden*/ true); + TokenProviderMaybe) + { + TokenProvider = TokenProviderMaybe.value(); + } + } + + if (!TokenProvider) + { + TokenProvider = httpclientauth::CreateFromDefaultOpenIdProvider(AuthManager); + } + + BuildStorageResolveResult ResolveResult; + { + HttpClientSettings ClientSettings{.LogCategory = "httpbuildsclient", + .AccessTokenProvider = TokenProvider, + .AssumeHttp2 = AssumeHttp2, + .AllowResume = true, + .RetryCount = 2}; + + std::unique_ptr Output(CreateStandardLogOutput(Log())); + + try + { + ResolveResult = ResolveBuildStorage(*Output, + ClientSettings, + Host, + OverrideHost, + ZenHost, + ZenCacheResolveMode::Discovery, + /*Verbose*/ false); + } + catch (const std::exception& Ex) + { + return {nullptr, fmt::format("Failed resolving storage host and cache. Reason: '{}'", Ex.what())}; + } + } + Result.LatencySec = ResolveResult.Cloud.LatencySec; + Result.MaxRangeCountPerRequest = ResolveResult.Cloud.Caps.MaxRangeCountPerRequest; + BuildsRemoteStoreOptions Options = { RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunksPerBlock = 1000, .MaxChunkEmbedSize = MaxChunkEmbedSize}, - Host, - OverrideHost, - ZenHost, std::string(Namespace), std::string(Bucket), BuildId, @@ -489,30 +586,43 @@ namespace { OidcExePath, ForceDisableBlocks, ForceDisableTempBlocks, - AssumeHttp2, - PopulateCache, MetaData, MaximumInMemoryDownloadSize}; - RemoteStore = CreateJupiterBuildsRemoteStore(Log(), - Options, - TempFilePath, - /*Quiet*/ false, - /*Unattended*/ false, - /*Hidden*/ true, - GetTinyWorkerPool(EWorkloadType::Background), - HostLatencySec, - CacheLatencySec); + Result.Store = CreateJupiterBuildsRemoteStore(Log(), ResolveResult, std::move(TokenProvider), Options, TempFilePath); + + if (!ResolveResult.Cache.Address.empty()) + { + Result.OptionalCache = std::make_unique(); + + HttpClientSettings CacheClientSettings{.LogCategory = "httpcacheclient", + .ConnectTimeout = 
std::chrono::milliseconds{3000}, + .Timeout = std::chrono::milliseconds{30000}, + .AssumeHttp2 = ResolveResult.Cache.AssumeHttp2, + .AllowResume = true, + .RetryCount = 0, + .MaximumInMemoryDownloadSize = MaximumInMemoryDownloadSize}; + + Result.OptionalCache->Http = std::make_unique(ResolveResult.Cache.Address, CacheClientSettings); + Result.OptionalCache->Cache = CreateZenBuildStorageCache(*Result.OptionalCache->Http, + Result.OptionalCache->Stats, + Namespace, + Bucket, + TempFilePath, + GetTinyWorkerPool(EWorkloadType::Background)); + Result.OptionalCache->BuildsId = BuildId; + Result.OptionalCache->LatencySec = ResolveResult.Cache.LatencySec; + Result.OptionalCache->MaxRangeCountPerRequest = ResolveResult.Cache.Caps.MaxRangeCountPerRequest; + Result.OptionalCache->Populate = PopulateCache; + Result.OptionalCache->Description = + fmt::format("[zenserver] {} namespace {} bucket {}", ResolveResult.Cache.Address, Namespace, Bucket); + } } - - if (!RemoteStore) + if (!Result.Store) { return {nullptr, "Unknown remote store type"}; } - return CreateRemoteStoreResult{.Store = std::move(RemoteStore), - .Description = "", - .HostLatencySec = HostLatencySec, - .CacheLatencySec = CacheLatencySec}; + return Result; } std::pair ConvertResult(const RemoteProjectStore::Result& Result) @@ -2679,38 +2789,36 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) EPartialBlockRequestMode PartialBlockRequestMode = PartialBlockRequestModeFromString(Params["partialblockrequestmode"sv].AsString("true")); - CreateRemoteStoreResult RemoteStoreResult = CreateRemoteStore(Log(), - Params, - m_AuthMgr, - MaxBlockSize, - MaxChunkEmbedSize, - GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory), - Oplog->TempPath()); + std::shared_ptr RemoteStoreResult = + std::make_shared(CreateRemoteStore(Log(), + Params, + m_AuthMgr, + MaxBlockSize, + MaxChunkEmbedSize, + GetMaxMemoryBufferSize(MaxBlockSize, BoostWorkerMemory), + Oplog->TempPath())); - if (RemoteStoreResult.Store == nullptr) + 
if (RemoteStoreResult->Store == nullptr) { - return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult.Description); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, RemoteStoreResult->Description); } - std::shared_ptr RemoteStore = std::move(RemoteStoreResult.Store); - RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); JobId JobId = m_JobQueue.QueueJob( fmt::format("Import oplog '{}/{}'", Project->Identifier, Oplog->OplogId()), [this, - ChunkStore = &m_CidStore, - ActualRemoteStore = std::move(RemoteStore), + RemoteStoreResult = std::move(RemoteStoreResult), Oplog, Force, IgnoreMissingAttachments, CleanOplog, PartialBlockRequestMode, - HostLatencySec = RemoteStoreResult.HostLatencySec, - CacheLatencySec = RemoteStoreResult.CacheLatencySec, BoostWorkerCount](JobContext& Context) { - Context.ReportMessage(fmt::format("Loading oplog '{}/{}' from {}", - Oplog->GetOuterProjectIdentifier(), - Oplog->OplogId(), - ActualRemoteStore->GetInfo().Description)); + Context.ReportMessage( + fmt::format("Loading oplog '{}/{}'\n Host: {}\n Cache: {}", + Oplog->GetOuterProjectIdentifier(), + Oplog->OplogId(), + RemoteStoreResult->Store->GetInfo().Description, + RemoteStoreResult->OptionalCache ? 
RemoteStoreResult->OptionalCache->Description : "")); Ref Workers = GetThreadWorkers(BoostWorkerCount, /*SingleThreaded*/ false); @@ -2718,19 +2826,26 @@ HttpProjectService::HandleRpcRequest(HttpRouterRequest& Req) WorkerThreadPool& NetworkWorkerPool = Workers->GetNetworkPool(); Context.ReportMessage(fmt::format("{}", Workers->GetWorkersInfo())); - - RemoteProjectStore::Result Result = LoadOplog(m_CidStore, - *ActualRemoteStore, - *Oplog, - NetworkWorkerPool, - WorkerPool, - Force, - IgnoreMissingAttachments, - CleanOplog, - PartialBlockRequestMode, - HostLatencySec, - CacheLatencySec, - &Context); + RemoteProjectStore::Result Result = LoadOplog(LoadOplogContext{ + .ChunkStore = m_CidStore, + .RemoteStore = *RemoteStoreResult->Store, + .OptionalCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Cache.get() : nullptr, + .CacheBuildId = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->BuildsId : Oid::Zero, + .OptionalCacheStats = RemoteStoreResult->OptionalCache ? &RemoteStoreResult->OptionalCache->Stats : nullptr, + .Oplog = *Oplog, + .NetworkWorkerPool = NetworkWorkerPool, + .WorkerPool = WorkerPool, + .ForceDownload = Force, + .IgnoreMissingAttachments = IgnoreMissingAttachments, + .CleanOplog = CleanOplog, + .PartialBlockRequestMode = PartialBlockRequestMode, + .PopulateCache = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->Populate : false, + .StoreLatencySec = RemoteStoreResult->LatencySec, + .StoreMaxRangeCountPerRequest = RemoteStoreResult->MaxRangeCountPerRequest, + .CacheLatencySec = RemoteStoreResult->OptionalCache ? RemoteStoreResult->OptionalCache->LatencySec : -1.0, + .CacheMaxRangeCountPerRequest = + RemoteStoreResult->OptionalCache ? 
RemoteStoreResult->OptionalCache->MaxRangeCountPerRequest : 0, + .OptionalJobContext = &Context}); auto Response = ConvertResult(Result); ZEN_INFO("LoadOplog: Status: {} '{}'", ToString(Response.first), Response.second); if (!IsHttpSuccessCode(Response.first)) -- cgit v1.2.3 From b37b34ea6ad906f54e8104526e77ba66aed997da Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Mon, 9 Mar 2026 17:43:08 +0100 Subject: Dashboard overhaul, compute integration (#814) - **Frontend dashboard overhaul**: Unified compute/main dashboards into a single shared UI. Added new pages for cache, projects, metrics, sessions, info (build/runtime config, system stats). Added live-update via WebSockets with pause control, sortable detail tables, themed styling. Refactored compute/hub/orchestrator pages into modular JS. - **HTTP server fixes and stats**: Fixed http.sys local-only fallback when default port is in use, implemented root endpoint redirect for http.sys, fixed Linux/Mac port reuse. Added /stats endpoint exposing HTTP server metrics (bytes transferred, request rates). Added WebSocket stats tracking. - **OTEL/diagnostics hardening**: Improved OTLP HTTP exporter with better error handling and resilience. Extended diagnostics services configuration. - **Session management**: Added new sessions service with HTTP endpoints for registering, updating, querying, and removing sessions. Includes session log file support. This is still WIP. - **CLI subcommand support**: Added support for commands with subcommands in the zen CLI tool, with improved command dispatch. - **Misc**: Exposed CPU usage/hostname to frontend, fixed JS compact binary float32/float64 decoding, limited projects displayed on front page to 25 sorted by last access, added vscode:// link support. Also contains some fixes from TSAN analysis. 
--- src/zen/zen.cpp | 78 ++ src/zen/zen.h | 37 + src/zenbase/include/zenbase/refcount.h | 3 + src/zencore/include/zencore/logging/tracesink.h | 4 + src/zencore/include/zencore/sentryintegration.h | 1 + src/zencore/include/zencore/system.h | 2 + src/zencore/logging/tracesink.cpp | 4 + src/zencore/sentryintegration.cpp | 19 +- src/zencore/system.cpp | 60 +- src/zencore/testutils.cpp | 2 +- src/zenhttp/httpserver.cpp | 34 + src/zenhttp/include/zenhttp/httpclient.h | 9 + src/zenhttp/include/zenhttp/httpserver.h | 73 +- src/zenhttp/include/zenhttp/httpstats.h | 47 +- src/zenhttp/monitoring/httpstats.cpp | 195 ++++- src/zenhttp/servers/httpasio.cpp | 104 +-- src/zenhttp/servers/httpplugin.cpp | 2 + src/zenhttp/servers/httpsys.cpp | 113 ++- src/zenhttp/servers/wsasio.cpp | 18 +- src/zenhttp/servers/wsasio.h | 8 +- src/zenhttp/servers/wshttpsys.cpp | 23 +- src/zenhttp/servers/wshttpsys.h | 5 +- src/zenhttp/servers/wstest.cpp | 3 +- .../include/zenremotestore/jupiter/jupiterhost.h | 1 + .../projectstore/remoteprojectstore.cpp | 5 +- src/zenserver/compute/computeserver.cpp | 4 +- src/zenserver/compute/computeserver.h | 1 - src/zenserver/diag/diagsvcs.cpp | 31 + src/zenserver/diag/diagsvcs.h | 15 +- src/zenserver/diag/otlphttp.cpp | 59 +- src/zenserver/diag/otlphttp.h | 13 +- src/zenserver/frontend/html.zip | Bin 319315 -> 406051 bytes src/zenserver/frontend/html/banner.js | 338 +++++++++ src/zenserver/frontend/html/compute/banner.js | 321 -------- src/zenserver/frontend/html/compute/compute.html | 327 +++------ src/zenserver/frontend/html/compute/hub.html | 154 +--- src/zenserver/frontend/html/compute/nav.js | 79 -- .../frontend/html/compute/orchestrator.html | 205 +----- src/zenserver/frontend/html/index.html | 3 + src/zenserver/frontend/html/nav.js | 79 ++ src/zenserver/frontend/html/pages/cache.js | 690 ++++++++++++++++++ src/zenserver/frontend/html/pages/compute.js | 693 ++++++++++++++++++ src/zenserver/frontend/html/pages/entry.js | 4 +- 
src/zenserver/frontend/html/pages/hub.js | 122 ++++ src/zenserver/frontend/html/pages/info.js | 261 +++++++ src/zenserver/frontend/html/pages/map.js | 4 +- src/zenserver/frontend/html/pages/metrics.js | 232 ++++++ src/zenserver/frontend/html/pages/oplog.js | 2 +- src/zenserver/frontend/html/pages/orchestrator.js | 405 +++++++++++ src/zenserver/frontend/html/pages/page.js | 69 +- src/zenserver/frontend/html/pages/project.js | 2 +- src/zenserver/frontend/html/pages/projects.js | 447 ++++++++++++ src/zenserver/frontend/html/pages/sessions.js | 61 ++ src/zenserver/frontend/html/pages/start.js | 327 ++++++--- src/zenserver/frontend/html/pages/stat.js | 2 +- src/zenserver/frontend/html/pages/tree.js | 2 +- src/zenserver/frontend/html/pages/zcache.js | 8 +- src/zenserver/frontend/html/theme.js | 116 +++ src/zenserver/frontend/html/util/compactbinary.js | 4 +- src/zenserver/frontend/html/util/friendly.js | 21 + src/zenserver/frontend/html/util/widgets.js | 138 +++- src/zenserver/frontend/html/zen.css | 809 +++++++++++++++++---- src/zenserver/hub/zenhubserver.cpp | 2 + src/zenserver/sessions/httpsessions.cpp | 264 +++++++ src/zenserver/sessions/httpsessions.h | 55 ++ src/zenserver/sessions/sessions.cpp | 150 ++++ src/zenserver/sessions/sessions.h | 83 +++ .../storage/buildstore/httpbuildstore.cpp | 12 +- src/zenserver/storage/buildstore/httpbuildstore.h | 5 +- .../storage/cache/httpstructuredcache.cpp | 137 +++- src/zenserver/storage/cache/httpstructuredcache.h | 11 +- .../storage/projectstore/httpprojectstore.cpp | 12 +- .../storage/projectstore/httpprojectstore.h | 5 +- .../storage/workspaces/httpworkspaces.cpp | 12 +- src/zenserver/storage/workspaces/httpworkspaces.h | 5 +- src/zenserver/storage/zenstorageserver.cpp | 16 +- src/zenserver/storage/zenstorageserver.h | 4 +- src/zenserver/zenserver.cpp | 103 ++- src/zenserver/zenserver.h | 17 +- src/zenstore/cache/cachedisklayer.cpp | 4 +- src/zenstore/projectstore.cpp | 7 + src/zentelemetry/include/zentelemetry/stats.h | 1 
+ 82 files changed, 6399 insertions(+), 1404 deletions(-) create mode 100644 src/zenserver/frontend/html/banner.js delete mode 100644 src/zenserver/frontend/html/compute/banner.js delete mode 100644 src/zenserver/frontend/html/compute/nav.js create mode 100644 src/zenserver/frontend/html/nav.js create mode 100644 src/zenserver/frontend/html/pages/cache.js create mode 100644 src/zenserver/frontend/html/pages/compute.js create mode 100644 src/zenserver/frontend/html/pages/hub.js create mode 100644 src/zenserver/frontend/html/pages/info.js create mode 100644 src/zenserver/frontend/html/pages/metrics.js create mode 100644 src/zenserver/frontend/html/pages/orchestrator.js create mode 100644 src/zenserver/frontend/html/pages/projects.js create mode 100644 src/zenserver/frontend/html/pages/sessions.js create mode 100644 src/zenserver/frontend/html/theme.js create mode 100644 src/zenserver/sessions/httpsessions.cpp create mode 100644 src/zenserver/sessions/httpsessions.h create mode 100644 src/zenserver/sessions/sessions.cpp create mode 100644 src/zenserver/sessions/sessions.h (limited to 'src') diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 7f7afa322..9a466da2e 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -194,6 +194,84 @@ ZenCmdBase::GetSubCommand(cxxopts::Options&, return argc; } +ZenSubCmdBase::ZenSubCmdBase(std::string_view Name, std::string_view Description) +: m_SubOptions(std::string(Name), std::string(Description)) +{ + m_SubOptions.add_options()("h,help", "Print help"); +} + +void +ZenCmdWithSubCommands::AddSubCommand(ZenSubCmdBase& SubCmd) +{ + m_SubCommands.push_back(&SubCmd); +} + +bool +ZenCmdWithSubCommands::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOptions*/) +{ + return true; +} + +void +ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + std::vector SubOptionPtrs; + SubOptionPtrs.reserve(m_SubCommands.size()); + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + 
SubOptionPtrs.push_back(&SubCmd->SubOptions()); + } + + cxxopts::Options* MatchedSubOption = nullptr; + std::vector SubCommandArguments; + int ParentArgc = GetSubCommand(Options(), argc, argv, SubOptionPtrs, MatchedSubOption, SubCommandArguments); + + if (!ParseOptions(Options(), ParentArgc, argv)) + { + return; + } + + if (MatchedSubOption == nullptr) + { + ExtendableStringBuilder<128> VerbList; + for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands) + { + if (!First) + { + VerbList.Append(", "); + } + VerbList.Append(SubCmd->SubOptions().program()); + First = false; + } + throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), Options().help()); + } + + ZenSubCmdBase* MatchedSubCmd = nullptr; + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + if (&SubCmd->SubOptions() == MatchedSubOption) + { + MatchedSubCmd = SubCmd; + break; + } + } + ZEN_ASSERT(MatchedSubCmd != nullptr); + + // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand + // works without requiring parent options like --hosturl to be populated. 
+ if (!ParseOptions(*MatchedSubOption, gsl::narrow(SubCommandArguments.size()), SubCommandArguments.data())) + { + return; + } + + if (!OnParentOptionsParsed(GlobalOptions)) + { + return; + } + + MatchedSubCmd->Run(GlobalOptions); +} + static ReturnCode GetReturnCodeFromHttpResult(const HttpClientError& Ex) { diff --git a/src/zen/zen.h b/src/zen/zen.h index e3481beea..06e5356a6 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -79,4 +79,41 @@ class CacheStoreCommand : public ZenCmdBase virtual ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; } }; +// Base for individual subcommands +class ZenSubCmdBase +{ +public: + ZenSubCmdBase(std::string_view Name, std::string_view Description); + virtual ~ZenSubCmdBase() = default; + cxxopts::Options& SubOptions() { return m_SubOptions; } + virtual void Run(const ZenCliOptions& GlobalOptions) = 0; + +protected: + cxxopts::Options m_SubOptions; +}; + +// Base for commands that host subcommands - handles all dispatch boilerplate +class ZenCmdWithSubCommands : public ZenCmdBase +{ +public: + void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) final; + +protected: + void AddSubCommand(ZenSubCmdBase& SubCmd); + virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions); + +private: + std::vector m_SubCommands; +}; + +class CacheStoreCmdWithSubCommands : public ZenCmdWithSubCommands +{ + ZenCmdCategory& CommandCategory() const override { return g_CacheStoreCategory; } +}; + +class StorageCmdWithSubCommands : public ZenCmdWithSubCommands +{ + ZenCmdCategory& CommandCategory() const override { return g_StorageCategory; } +}; + } // namespace zen diff --git a/src/zenbase/include/zenbase/refcount.h b/src/zenbase/include/zenbase/refcount.h index 40ad7bca5..08bc6ae54 100644 --- a/src/zenbase/include/zenbase/refcount.h +++ b/src/zenbase/include/zenbase/refcount.h @@ -51,6 +51,9 @@ private: * NOTE: Unlike RefCounted, this class deletes the derived type when the reference count 
reaches zero. * It has no virtual destructor, so it's important that you either don't derive from it further, * or ensure that the derived class has a virtual destructor. + * + * This class is useful when you want to avoid adding a vtable to a class just to implement + * reference counting. */ template diff --git a/src/zencore/include/zencore/logging/tracesink.h b/src/zencore/include/zencore/logging/tracesink.h index e63d838b4..785c51e10 100644 --- a/src/zencore/include/zencore/logging/tracesink.h +++ b/src/zencore/include/zencore/logging/tracesink.h @@ -6,6 +6,8 @@ namespace zen::logging { +#if ZEN_WITH_TRACE + /** * A logging sink that forwards log messages to the trace system. * @@ -20,4 +22,6 @@ public: void SetFormatter(std::unique_ptr InFormatter) override; }; +#endif + } // namespace zen::logging diff --git a/src/zencore/include/zencore/sentryintegration.h b/src/zencore/include/zencore/sentryintegration.h index a4e33d69e..27e5a8a82 100644 --- a/src/zencore/include/zencore/sentryintegration.h +++ b/src/zencore/include/zencore/sentryintegration.h @@ -40,6 +40,7 @@ public: }; void Initialize(const Config& Conf, const std::string& CommandLine); + void Close(); void LogStartupInformation(); static void ClearCaches(); diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index fecbe2dbe..a67999e52 100644 --- a/src/zencore/include/zencore/system.h +++ b/src/zencore/include/zencore/system.h @@ -14,6 +14,7 @@ class CbWriter; std::string GetMachineName(); std::string_view GetOperatingSystemName(); +std::string GetOperatingSystemVersion(); std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos" std::string_view GetCpuName(); @@ -28,6 +29,7 @@ struct SystemMetrics uint64_t AvailVirtualMemoryMiB = 0; uint64_t PageFileMiB = 0; uint64_t AvailPageFileMiB = 0; + uint64_t UptimeSeconds = 0; }; /// Extended metrics that include CPU usage percentage, which requires diff --git a/src/zencore/logging/tracesink.cpp 
b/src/zencore/logging/tracesink.cpp index e3533327b..8a6f4e40c 100644 --- a/src/zencore/logging/tracesink.cpp +++ b/src/zencore/logging/tracesink.cpp @@ -7,6 +7,8 @@ #include #include +#if ZEN_WITH_TRACE + namespace zen::logging { UE_TRACE_CHANNEL_DEFINE(LogChannel) @@ -86,3 +88,5 @@ TraceSink::SetFormatter(std::unique_ptr /*InFormatter*/) } } // namespace zen::logging + +#endif diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index e39b8438d..8d087e8c6 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -128,6 +128,8 @@ namespace zen { # if ZEN_USE_SENTRY ZEN_DEFINE_LOG_CATEGORY_STATIC(LogSentry, "sentry-sdk"); +static std::atomic s_SentryLogEnabled{true}; + static void SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) { @@ -147,14 +149,15 @@ SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[may } // SentryLogFunction can be called before the logging system is initialized - // (during sentry_init which runs before InitializeLogging). Fall back to - // console logging when the category logger is not yet available. + // (during sentry_init which runs before InitializeLogging), or after it has + // been shut down (during sentry_close on a background worker thread). Fall + // back to console logging when the category logger is not available. 
// // Since we want to default to WARN level but this runs before logging has // been configured, we ignore the callbacks for DEBUG/INFO explicitly here // which means users don't see every possible log message if they're trying // to configure the levels using --log-debug=sentry-sdk - if (!TheDefaultLogger) + if (!TheDefaultLogger || !s_SentryLogEnabled.load(std::memory_order_acquire)) { switch (Level) { @@ -211,12 +214,22 @@ SentryIntegration::SentryIntegration() } SentryIntegration::~SentryIntegration() +{ + Close(); +} + +void +SentryIntegration::Close() { if (m_IsInitialized && m_SentryErrorCode == 0) { logging::SetErrorLog(""); m_SentryAssert.reset(); + // Disable spdlog forwarding before sentry_close() since its background + // worker thread may still log during shutdown via SentryLogFunction + s_SentryLogEnabled.store(false, std::memory_order_release); sentry_close(); + m_IsInitialized = false; } } diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index 833d3c04b..141450b84 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -135,6 +136,8 @@ GetSystemMetrics() Metrics.AvailPageFileMiB = MemStatus.ullAvailPageFile / 1024 / 1024; } + Metrics.UptimeSeconds = GetTickCount64() / 1000; + return Metrics; } #elif ZEN_PLATFORM_LINUX @@ -226,6 +229,17 @@ GetSystemMetrics() Metrics.VirtualMemoryMiB = Metrics.SystemMemoryMiB; Metrics.AvailVirtualMemoryMiB = Metrics.AvailSystemMemoryMiB; + // System uptime + if (FILE* UptimeFile = fopen("/proc/uptime", "r")) + { + double UptimeSec = 0; + if (fscanf(UptimeFile, "%lf", &UptimeSec) == 1) + { + Metrics.UptimeSeconds = static_cast(UptimeSec); + } + fclose(UptimeFile); + } + // Parse /proc/meminfo for swap/page file information Metrics.PageFileMiB = 0; Metrics.AvailPageFileMiB = 0; @@ -318,6 +332,18 @@ GetSystemMetrics() Metrics.PageFileMiB = SwapUsage.xsu_total / 1024 / 1024; Metrics.AvailPageFileMiB = (SwapUsage.xsu_total - 
SwapUsage.xsu_used) / 1024 / 1024; + // System uptime via boot time + { + struct timeval BootTime + { + }; + Size = sizeof(BootTime); + if (sysctlbyname("kern.boottime", &BootTime, &Size, nullptr, 0) == 0) + { + Metrics.UptimeSeconds = static_cast(time(nullptr) - BootTime.tv_sec); + } + } + return Metrics; } #else @@ -574,6 +600,38 @@ GetOperatingSystemName() return ZEN_PLATFORM_NAME; } +std::string +GetOperatingSystemVersion() +{ +#if ZEN_PLATFORM_WINDOWS + // Use RtlGetVersion to avoid the compatibility shim that GetVersionEx applies + using RtlGetVersionFn = LONG(WINAPI*)(PRTL_OSVERSIONINFOW); + RTL_OSVERSIONINFOW OsVer{.dwOSVersionInfoSize = sizeof(OsVer)}; + if (auto Fn = (RtlGetVersionFn)GetProcAddress(GetModuleHandleW(L"ntdll.dll"), "RtlGetVersion")) + { + Fn(&OsVer); + } + return fmt::format("Windows {}.{} Build {}", OsVer.dwMajorVersion, OsVer.dwMinorVersion, OsVer.dwBuildNumber); +#elif ZEN_PLATFORM_LINUX + struct utsname Info + { + }; + if (uname(&Info) == 0) + { + return fmt::format("{} {}", Info.sysname, Info.release); + } + return "Linux"; +#elif ZEN_PLATFORM_MAC + char OsVersion[64] = ""; + size_t Size = sizeof(OsVersion); + if (sysctlbyname("kern.osproductversion", OsVersion, &Size, nullptr, 0) == 0) + { + return fmt::format("macOS {}", OsVersion); + } + return "macOS"; +#endif +} + std::string_view GetRuntimePlatformName() { @@ -608,7 +666,7 @@ Describe(const SystemMetrics& Metrics, CbWriter& Writer) Writer << "cpu_count" << Metrics.CpuCount << "core_count" << Metrics.CoreCount << "lp_count" << Metrics.LogicalProcessorCount << "total_memory_mb" << Metrics.SystemMemoryMiB << "avail_memory_mb" << Metrics.AvailSystemMemoryMiB << "total_virtual_mb" << Metrics.VirtualMemoryMiB << "avail_virtual_mb" << Metrics.AvailVirtualMemoryMiB << "total_pagefile_mb" << Metrics.PageFileMiB - << "avail_pagefile_mb" << Metrics.AvailPageFileMiB; + << "avail_pagefile_mb" << Metrics.AvailPageFileMiB << "uptime_seconds" << Metrics.UptimeSeconds; } void diff --git 
a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp index 5bc2841ae..0cd3f8121 100644 --- a/src/zencore/testutils.cpp +++ b/src/zencore/testutils.cpp @@ -46,7 +46,7 @@ ScopedTemporaryDirectory::~ScopedTemporaryDirectory() IoBuffer CreateRandomBlob(uint64_t Size) { - static FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1}; + thread_local FastRandom Rand{.Seed = 0x7CEBF54E45B9F5D1}; return CreateRandomBlob(Rand, Size); }; diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index d798c46d9..9bae95690 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -1094,6 +1095,39 @@ HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) OnSetHttpRequestFilter(RequestFilter); } +CbObject +HttpServer::CollectStats() +{ + CbObjectWriter Cbo; + + metrics::EmitSnapshot("requests", m_RequestMeter, Cbo); + + Cbo.BeginObject("bytes"); + { + Cbo << "received" << GetTotalBytesReceived(); + Cbo << "sent" << GetTotalBytesSent(); + } + Cbo.EndObject(); + + Cbo.BeginObject("websockets"); + { + Cbo << "active_connections" << GetActiveWebSocketConnectionCount(); + Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed); + Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed); + Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed); + Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpServer::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index bec4984db..1bb36a298 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ 
b/src/zenhttp/include/zenhttp/httpclient.h @@ -118,6 +118,15 @@ private: class HttpClientBase; +/** HTTP Client + * + * This is safe for use on multiple threads simultaneously, as each + * instance maintains an internal connection pool and will synchronize + * access to it as needed. + * + * Uses libcurl under the hood. We currently only use HTTP 1.1 features. + * + */ class HttpClient { public: diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index c1152dc3e..0e1714669 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -13,6 +13,8 @@ #include #include +#include + #include #include #include @@ -203,12 +205,34 @@ private: int m_UriPrefixLength = 0; }; +struct IHttpStatsProvider +{ + /** Handle an HTTP stats request, writing the response directly. + * Implementations may inspect query parameters on the request + * to include optional detailed breakdowns. + */ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; + + /** Return the provider's current stats as a CbObject snapshot. + * Used by the WebSocket push thread to broadcast live updates + * without requiring an HttpServerRequest. Providers that do + * not override this will be skipped in WebSocket broadcasts. 
+ */ + virtual CbObject CollectStats() { return {}; } +}; + +struct IHttpStatsService +{ + virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; + virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; +}; + /** HTTP server * * Implements the main event loop to service HTTP requests, and handles routing * requests to the appropriate handler as registered via RegisterService */ -class HttpServer : public RefCounted +class HttpServer : public RefCounted, public IHttpStatsProvider { public: void RegisterService(HttpService& Service); @@ -235,10 +259,46 @@ public: virtual uint64_t GetTotalBytesReceived() const { return 0; } virtual uint64_t GetTotalBytesSent() const { return 0; } + /** Mark that a request has been handled. Called by server implementations. */ + void MarkRequest() { m_RequestMeter.Mark(); } + + /** Set a default redirect path for root requests */ + void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; } + + std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } + + /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } + void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } + uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } + + /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. 
*/ + void OnWebSocketFrameReceived(uint64_t Bytes) + { + m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); + m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed); + } + void OnWebSocketFrameSent(uint64_t Bytes) + { + m_WsFramesSent.fetch_add(1, std::memory_order_relaxed); + m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed); + } + + // IHttpStatsProvider + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + private: std::vector m_KnownServices; int m_EffectivePort = 0; std::string m_ExternalHost; + metrics::Meter m_RequestMeter; + std::string m_DefaultRedirect; + std::atomic m_ActiveWebSocketConnections{0}; + std::atomic m_WsFramesReceived{0}; + std::atomic m_WsFramesSent{0}; + std::atomic m_WsBytesReceived{0}; + std::atomic m_WsBytesSent{0}; virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; @@ -456,17 +516,6 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref& PackageHandlerRef); -struct IHttpStatsProvider -{ - virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; -}; - -struct IHttpStatsService -{ - virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; - virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; -}; - void http_forcelink(); // internal void websocket_forcelink(); // internal diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index e6fea6765..460315faf 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -3,23 +3,50 @@ #pragma once #include +#include #include +#include +#include #include +#include +#include +#include + +ZEN_THIRD_PARTY_INCLUDES_START +#include +#include +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -class HttpStatsService : public HttpService, public 
IHttpStatsService +class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - HttpStatsService(); + /// Construct without an io_context — optionally uses a dedicated push thread + /// for WebSocket stats broadcasting. + explicit HttpStatsService(bool EnableWebSockets = false); + + /// Construct with an external io_context — uses an asio timer instead + /// of a dedicated thread for WebSocket stats broadcasting. + /// The caller must ensure the io_context outlives this service and that + /// its run loop is active. + HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true); + ~HttpStatsService(); + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; + // IWebSocketHandler + void OnWebSocketOpen(Ref Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + private: LoggerRef m_Log; HttpRequestRouter m_Router; @@ -28,6 +55,22 @@ private: RwLock m_Lock; std::map m_Providers; + + // WebSocket push + RwLock m_WsConnectionsLock; + std::vector> m_WsConnections; + std::atomic m_PushEnabled{false}; + + void BroadcastStats(); + + // Thread-based push (when no io_context is provided) + std::thread m_PushThread; + Event m_PushEvent; + void PushThreadFunction(); + + // Timer-based push (when an io_context is provided) + std::unique_ptr m_PushTimer; + void EnqueuePushTimer(); }; } // namespace zen diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index b097a0d3f..2370def0c 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp 
@@ -3,15 +3,57 @@ #include "zenhttp/httpstats.h" #include +#include +#include +#include namespace zen { -HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats")) { + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); + } +} + +HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats")) +{ + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushTimer = std::make_unique(IoContext); + EnqueuePushTimer(); + } } HttpStatsService::~HttpStatsService() { + Shutdown(); +} + +void +HttpStatsService::Shutdown() +{ + if (!m_PushEnabled.exchange(false)) + { + return; + } + + if (m_PushTimer) + { + m_PushTimer->cancel(); + m_PushTimer.reset(); + } + + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + m_PushThread.join(); + } + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); } const char* @@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro void HttpStatsService::HandleRequest(HttpServerRequest& Request) { + ZEN_TRACE_CPU("HttpStatsService::HandleRequest"); using namespace std::literals; std::string_view Key = Request.RelativeUri(); @@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpStatsService::OnWebSocketOpen(Ref Connection) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); + ZEN_INFO("Stats WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Send initial state immediately + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + } +} + +void +HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const 
WebSocketMessage& /*Msg*/) +{ + // No client-to-server messages expected +} + +void +HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose"); + ZEN_INFO("Stats WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +////////////////////////////////////////////////////////////////////////// +// +// Stats broadcast +// + +void +HttpStatsService::BroadcastStats() +{ + ZEN_TRACE_CPU("HttpStatsService::BroadcastStats"); + std::vector> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + return; + } + + // Collect stats from all providers + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + + bool First = true; + { + RwLock::SharedLockScope _(m_Lock); + for (auto& [Id, Provider] : m_Providers) + { + CbObject Stats = Provider->CollectStats(); + if (!Stats) + { + continue; + } + + if (!First) + { + JsonBuilder.Append(","); + } + First = false; + + // Emit as "provider_id": { ... 
} + JsonBuilder.Append("\""); + JsonBuilder.Append(Id); + JsonBuilder.Append("\":"); + + ExtendableStringBuilder<2048> StatsJson; + Stats.ToJson(StatsJson); + JsonBuilder.Append(StatsJson.ToView()); + } + } + + JsonBuilder.Append("}"); + + std::string_view Json = JsonBuilder.ToView(); + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Thread-based push (fallback when no io_context) +// + +void +HttpStatsService::PushThreadFunction() +{ + SetCurrentThreadName("stats_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(5000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + BroadcastStats(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Timer-based push (when io_context is provided) +// + +void +HttpStatsService::EnqueuePushTimer() +{ + if (!m_PushTimer) + { + return; + } + + m_PushTimer->expires_after(std::chrono::seconds(5)); + m_PushTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; + } + + BroadcastStats(); + EnqueuePushTimer(); + }); +} + } // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 2cf051d14..f5178ebe8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -531,6 +531,8 @@ public: std::atomic m_TotalBytesReceived{0}; std::atomic m_TotalBytesSent{0}; + + HttpServer* m_HttpServer = nullptr; }; /** @@ -949,6 +951,7 @@ private: void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop); void CloseConnection(); + void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); HttpAsioServerImpl& m_Server; asio::streambuf 
m_RequestBuffer; @@ -1166,6 +1169,38 @@ HttpServerConnection::CloseConnection() } } +void +HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) +{ + ExtendableStringBuilder<256> ResponseBuilder; + ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; + if (!Headers.empty()) + { + ResponseBuilder << Headers; + } + if (!m_RequestData.IsKeepAlive()) + { + ResponseBuilder << "Connection: close\r\n"; + } + ResponseBuilder << "\r\n"; + if (!Body.empty()) + { + ResponseBuilder << Body; + } + auto ResponseView = ResponseBuilder.ToView(); + IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); + auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); + asio::async_write( + *m_Socket.get(), + Buffer, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + }); +} + void HttpServerConnection::HandleRequest() { @@ -1204,7 +1239,9 @@ HttpServerConnection::HandleRequest() return; } - Ref WsConn(new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler)); + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + Ref WsConn( + new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); Ref WsConnRef(WsConn.Get()); WsHandler->OnWebSocketOpen(std::move(WsConnRef)); @@ -1241,6 +1278,8 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); + m_Server.m_HttpServer->MarkRequest(); + auto RemoteEndpoint = m_Socket->remote_endpoint(); bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); @@ -1378,51 +1417,24 @@ HttpServerConnection::HandleRequest() } } - if (m_RequestData.RequestVerb() == HttpVerb::kHead) + // If a default redirect is configured and the request is for the root path, send 
a 302 + std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect(); + if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty())) { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + ExtendableStringBuilder<128> Headers; + Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n"; + SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView()); + } + else if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + SendInlineResponse(RequestNumber, "404 NOT FOUND"sv); } else { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + SendInlineResponse(RequestNumber, + "404 NOT FOUND"sv, + "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv, + "No suitable route found"sv); } } @@ -1448,8 +1460,11 @@ struct HttpAcceptor m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); #else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); - 
m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false)); + // Allow binding to a port in TIME_WAIT so the server can restart immediately + // after a previous instance exits. On Linux this does not allow two processes + // to actively listen on the same port simultaneously. + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true)); #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); @@ -2092,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config) : m_InitialConfig(Config) , m_Impl(std::make_unique()) { + m_Impl->m_HttpServer = this; ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); } diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 021b941bd..4bf8c61bb 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -378,6 +378,8 @@ HttpPluginConnectionHandler::HandleRequest() { ZEN_TRACE_CPU("http_plugin::HandleRequest"); + m_Server->MarkRequest(); + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); const HttpVerb RequestVerb = Request.RequestVerb(); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index cf639c114..dfe6bb6aa 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -451,6 +451,8 @@ public: inline uint16_t GetResponseCode() const { return m_ResponseCode; } inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + private: eastl::fixed_vector m_HttpDataChunks; uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes @@ -460,6 +462,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector m_DataBuffers; + std::string 
m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span Blobs); }; @@ -715,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -916,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! 
IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1083,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1092,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); return 0; } @@ -1116,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1138,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1173,17 +1195,18 @@ HttpSysServer::InitializeServer(int BasePort) const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG 
InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1191,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1211,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort) if (m_BaseUris.empty()) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1229,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), 
Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1241,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1273,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } @@ -1295,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort) ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); } - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. 
We should investigate how to get more recent SDK releases - // into the build - -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) - { - ZEN_DEBUG("HTTP3 is available"); - } - else - { - ZEN_DEBUG("HTTP3 is NOT available"); - } -# endif - return EffectivePort; } @@ -1695,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -2245,8 +2278,12 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT if (SendResult == NO_ERROR) { - Ref WsConn( - new WsHttpSysConnection(RequestQueueHandle, RequestId, *WsHandler, Transaction().Iocp())); + Transaction().Server().OnWebSocketConnectionOpened(); + Ref WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); Ref WsConnRef(WsConn.Get()); WsHandler->OnWebSocketOpen(std::move(WsConnRef)); @@ -2255,7 +2292,7 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return nullptr; } - ZEN_WARN("WebSocket 101 send failed: {}", SendResult); + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); // WebSocket upgrade failed — return nullptr since ServerRequest() // was never populated (no InvokeRequestHandler call) @@ -2330,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty())) + { + auto* 
Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 3e31b58bc..b2543277a 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -4,6 +4,7 @@ #include "wsframecodec.h" #include +#include namespace zen::asio_http { @@ -16,15 +17,20 @@ WsLog() ////////////////////////////////////////////////////////////////////////// -WsAsioConnection::WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler) +WsAsioConnection::WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler, HttpServer* Server) : m_Socket(std::move(Socket)) , m_Handler(Handler) +, m_HttpServer(Server) { } WsAsioConnection::~WsAsioConnection() { m_IsOpen.store(false); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } } void @@ -101,6 +107,11 @@ WsAsioConnection::ProcessReceivedData() m_ReadBuffer.consume(Frame.BytesConsumed); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + switch (Frame.Opcode) { case WebSocketOpcode::kText: @@ -219,6 +230,11 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) void WsAsioConnection::EnqueueWrite(std::vector Frame) { + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + bool ShouldFlush = false; m_WriteLock.WithExclusiveLock([&] { diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h index d8ffdc00a..e8bb3b1d2 100644 --- a/src/zenhttp/servers/wsasio.h +++ b/src/zenhttp/servers/wsasio.h @@ -14,6 +14,10 @@ ZEN_THIRD_PARTY_INCLUDES_END #include #include +namespace zen { +class HttpServer; +} // namespace zen + namespace zen::asio_http { /** @@ -27,10 +31,11 @@ namespace zen::asio_http { * connection alive for the duration of 
the async operation. The service layer * also holds a Ref. */ + class WsAsioConnection : public WebSocketConnection { public: - WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler); + WsAsioConnection(std::unique_ptr Socket, IWebSocketHandler& Handler, HttpServer* Server); ~WsAsioConnection() override; /** @@ -58,6 +63,7 @@ private: std::unique_ptr m_Socket; IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; asio::streambuf m_ReadBuffer; RwLock m_WriteLock; diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index 3408b64b3..af320172d 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -7,6 +7,7 @@ # include "wsframecodec.h" # include +# include namespace zen { @@ -19,11 +20,16 @@ WsHttpSysLog() ////////////////////////////////////////////////////////////////////////// -WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp) +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, + HTTP_REQUEST_ID RequestId, + IWebSocketHandler& Handler, + PTP_IO Iocp, + HttpServer* Server) : m_RequestQueueHandle(RequestQueueHandle) , m_RequestId(RequestId) , m_Handler(Handler) , m_Iocp(Iocp) +, m_HttpServer(Server) , m_ReadBuffer(8192) { m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; @@ -40,6 +46,11 @@ WsHttpSysConnection::~WsHttpSysConnection() { Disconnect(); } + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } } void @@ -174,6 +185,11 @@ WsHttpSysConnection::ProcessReceivedData() // Remove consumed bytes m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + switch (Frame.Opcode) { case WebSocketOpcode::kText: @@ -250,6 +266,11 @@ WsHttpSysConnection::ProcessReceivedData() void 
WsHttpSysConnection::EnqueueWrite(std::vector Frame) { + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + bool ShouldFlush = false; { diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h index d854289e0..6015e3873 100644 --- a/src/zenhttp/servers/wshttpsys.h +++ b/src/zenhttp/servers/wshttpsys.h @@ -19,6 +19,8 @@ namespace zen { +class HttpServer; + /** * WebSocket connection over an http.sys opaque-mode connection * @@ -37,7 +39,7 @@ namespace zen { class WsHttpSysConnection : public WebSocketConnection { public: - WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp); + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server); ~WsHttpSysConnection() override; /** @@ -75,6 +77,7 @@ private: HTTP_REQUEST_ID m_RequestId; IWebSocketHandler& m_Handler; PTP_IO m_Iocp; + HttpServer* m_HttpServer; // Tagged OVERLAPPED contexts for concurrent read and write HttpSysIoContext m_ReadIoContext{}; diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index fd023c490..2134e4ff1 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -767,13 +767,12 @@ namespace { void OnWsMessage(const WebSocketMessage& Msg) override { - m_MessageCount.fetch_add(1); - if (Msg.Opcode == WebSocketOpcode::kText) { std::string_view Text(static_cast(Msg.Payload.Data()), Msg.Payload.Size()); m_LastMessage = std::string(Text); } + m_MessageCount.fetch_add(1); } void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override diff --git a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h index bb41f9efc..caf7ecd28 100644 --- a/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h +++ b/src/zenremotestore/include/zenremotestore/jupiter/jupiterhost.h @@ -2,6 
+2,7 @@ #pragma once +#include #include #include #include diff --git a/src/zenremotestore/projectstore/remoteprojectstore.cpp b/src/zenremotestore/projectstore/remoteprojectstore.cpp index d5c6286a8..4796b3f2a 100644 --- a/src/zenremotestore/projectstore/remoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/remoteprojectstore.cpp @@ -2447,14 +2447,13 @@ BuildContainer(CidStore& ChunkStore, AsyncOnBlock, RemoteResult); ComposedBlocks++; + // Worker will set Blocks[BlockIndex] = Block (including ChunkRawHashes) under shared lock } else { ZEN_INFO("Bulk group {} attachments", ChunkCount); OnBlockChunks(std::move(ChunksInBlock)); - } - { - // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + // We can share the lock as we are not resizing the vector and only touch our own index RwLock::SharedLockScope _(BlocksLock); Blocks[BlockIndex].ChunkRawHashes = std::move(ChunkRawHashes); } diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index 802d06caf..c64f081b3 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -419,6 +419,8 @@ ZenComputeServer::Cleanup() m_IoRunner.join(); } + ShutdownServices(); + if (m_Http) { m_Http->Close(); @@ -570,8 +572,6 @@ ZenComputeServer::RegisterServices(const ZenComputeServerConfig& ServerConfig) ZEN_TRACE_CPU("ZenComputeServer::RegisterServices"); ZEN_UNUSED(ServerConfig); - m_Http->RegisterService(m_StatsService); - if (m_ApiService) { m_Http->RegisterService(*m_ApiService); diff --git a/src/zenserver/compute/computeserver.h b/src/zenserver/compute/computeserver.h index e4a6b01d5..8f4edc0f0 100644 --- a/src/zenserver/compute/computeserver.h +++ b/src/zenserver/compute/computeserver.h @@ -129,7 +129,6 @@ public: void Cleanup(); private: - HttpStatsService m_StatsService; GcManager m_GcManager; GcScheduler m_GcScheduler{m_GcManager}; std::unique_ptr m_CidStore; diff --git 
a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp index 5fa81ff9f..dd4b8956c 100644 --- a/src/zenserver/diag/diagsvcs.cpp +++ b/src/zenserver/diag/diagsvcs.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -51,6 +52,36 @@ HttpHealthService::HttpHealthService() Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; + Writer << "Port"sv << m_HealthInfo.Port; + Writer << "Pid"sv << m_HealthInfo.Pid; + Writer << "IsDedicated"sv << m_HealthInfo.IsDedicated; + Writer << "StartTimeMs"sv << m_HealthInfo.StartTimeMs; + } + + Writer.BeginObject("RuntimeConfig"sv); + for (const auto& Opt : m_HealthInfo.RuntimeConfig) + { + Writer << Opt.first << Opt.second; + } + Writer.EndObject(); + + Writer.BeginObject("BuildConfig"sv); + for (const auto& Opt : m_HealthInfo.BuildOptions) + { + Writer << Opt.first << Opt.second; + } + Writer.EndObject(); + + Writer << "Hostname"sv << GetMachineName(); + Writer << "Platform"sv << GetRuntimePlatformName(); + Writer << "Arch"sv << GetCpuName(); + Writer << "OS"sv << GetOperatingSystemVersion(); + + { + auto Metrics = GetSystemMetrics(); + Writer.BeginObject("System"sv); + Describe(Metrics, Writer); + Writer.EndObject(); } HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h index 8cc869c83..87ce80b3c 100644 --- a/src/zenserver/diag/diagsvcs.h +++ b/src/zenserver/diag/diagsvcs.h @@ -6,6 +6,7 @@ #include #include +#include ////////////////////////////////////////////////////////////////////////// @@ -89,10 +90,16 @@ private: struct HealthServiceInfo { - std::filesystem::path DataRoot; - std::filesystem::path AbsLogPath; - std::string HttpServerClass; - std::string BuildVersion; + std::filesystem::path DataRoot; + std::filesystem::path AbsLogPath; + std::string HttpServerClass; + std::string 
BuildVersion; + int Port = 0; + int Pid = 0; + bool IsDedicated = false; + int64_t StartTimeMs = 0; + std::vector> BuildOptions; + std::vector> RuntimeConfig; }; /** Health monitoring endpoint diff --git a/src/zenserver/diag/otlphttp.cpp b/src/zenserver/diag/otlphttp.cpp index 1434c9331..d6e24cbe3 100644 --- a/src/zenserver/diag/otlphttp.cpp +++ b/src/zenserver/diag/otlphttp.cpp @@ -10,11 +10,18 @@ #include #include +#include + #if ZEN_WITH_OTEL namespace zen::logging { ////////////////////////////////////////////////////////////////////////// +// +// Important note: in general we cannot use ZEN_WARN/ZEN_ERROR etc in this +// file as it could cause recursive logging calls when we attempt to log +// errors from the OTLP HTTP client itself. +// OtelHttpProtobufSink::OtelHttpProtobufSink(const std::string_view& Uri) : m_OtelHttp(Uri) { @@ -35,15 +42,45 @@ OtelHttpProtobufSink::~OtelHttpProtobufSink() otel::SetTraceRecorder({}); } +void +OtelHttpProtobufSink::CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept +{ + if (!Result.IsSuccess()) + { + uint32_t PrevFailures = m_ConsecutivePostFailures.fetch_add(1); + if (PrevFailures < kMaxReportedFailures) + { + fprintf(stderr, "OtelHttpProtobufSink: %s\n", Result.ErrorMessage(Endpoint).c_str()); + if (PrevFailures + 1 == kMaxReportedFailures) + { + fprintf(stderr, "OtelHttpProtobufSink: suppressing further export errors\n"); + } + } + } + else + { + m_ConsecutivePostFailures.store(0); + } +} + void OtelHttpProtobufSink::RecordSpans(zen::otel::TraceId Trace, std::span Spans) { - std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans); + try + { + std::string Data = m_Encoder.FormatOtelTrace(Trace, Spans); + + IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; + Payload.SetContentType(ZenContentType::kProtobuf); - IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; - Payload.SetContentType(ZenContentType::kProtobuf); + HttpClient::Response Result = 
m_OtelHttp.Post("/v1/traces", Payload); - auto Result = m_OtelHttp.Post("/v1/traces", Payload); + CheckPostResult(Result, "POST /v1/traces"); + } + catch (const std::exception& Ex) + { + fprintf(stderr, "OtelHttpProtobufSink: exception exporting traces: %s\n", Ex.what()); + } } void @@ -55,22 +92,20 @@ OtelHttpProtobufSink::TraceRecorder::RecordSpans(zen::otel::TraceId Trace, std:: void OtelHttpProtobufSink::Log(const LogMessage& Msg) { + try { std::string Data = m_Encoder.FormatOtelProtobuf(Msg); IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; Payload.SetContentType(ZenContentType::kProtobuf); - auto Result = m_OtelHttp.Post("/v1/logs", Payload); - } + HttpClient::Response Result = m_OtelHttp.Post("/v1/logs", Payload); + CheckPostResult(Result, "POST /v1/logs"); + } + catch (const std::exception& Ex) { - std::string Data = m_Encoder.FormatOtelMetrics(); - - IoBuffer Payload{IoBuffer::Wrap, Data.data(), Data.size()}; - Payload.SetContentType(ZenContentType::kProtobuf); - - auto Result = m_OtelHttp.Post("/v1/metrics", Payload); + fprintf(stderr, "OtelHttpProtobufSink: exception exporting logs: %s\n", Ex.what()); } } void diff --git a/src/zenserver/diag/otlphttp.h b/src/zenserver/diag/otlphttp.h index 8254af04d..64b3dbc87 100644 --- a/src/zenserver/diag/otlphttp.h +++ b/src/zenserver/diag/otlphttp.h @@ -9,6 +9,8 @@ #include #include +#include + #if ZEN_WITH_OTEL namespace zen::logging { @@ -36,6 +38,7 @@ private: virtual void SetFormatter(std::unique_ptr) override {} void RecordSpans(zen::otel::TraceId Trace, std::span Spans); + void CheckPostResult(const HttpClient::Response& Result, const char* Endpoint) noexcept; // This is just a thin wrapper to call back into the sink while participating in // reference counting from the OTEL trace back-end @@ -53,9 +56,13 @@ private: OtelHttpProtobufSink* m_Sink; }; - HttpClient m_OtelHttp; - OtlpEncoder m_Encoder; - Ref m_TraceRecorder; + static constexpr uint32_t kMaxReportedFailures = 5; + + RwLock m_Lock; + 
std::atomic m_ConsecutivePostFailures{0}; + HttpClient m_OtelHttp; + OtlpEncoder m_Encoder; + Ref m_TraceRecorder; }; } // namespace zen::logging diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip index c167cc70e..84472ff08 100644 Binary files a/src/zenserver/frontend/html.zip and b/src/zenserver/frontend/html.zip differ diff --git a/src/zenserver/frontend/html/banner.js b/src/zenserver/frontend/html/banner.js new file mode 100644 index 000000000..2e878dedf --- /dev/null +++ b/src/zenserver/frontend/html/banner.js @@ -0,0 +1,338 @@ +/** + * zen-banner.js — Zen dashboard banner Web Component + * + * Usage: + * + * + * + * + * + * + * Attributes: + * variant "full" (default) | "compact" + * cluster-status "nominal" (default) | "degraded" | "offline" + * load 0–100 integer, shown as a percentage (default: hidden) + * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator") + * subtitle text after "ZEN" in the wordmark (default: "COMPUTE") + */ + +class ZenBanner extends HTMLElement { + + static get observedAttributes() { + return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle', 'logo-src']; + } + + attributeChangedCallback() { + if (this.shadowRoot) this._render(); + } + + connectedCallback() { + if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); + this._render(); + } + + // ───────────────────────────────────────────── + // Derived values + // ───────────────────────────────────────────── + + get _variant() { return this.getAttribute('variant') || 'full'; } + get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); } + get _load() { return this.getAttribute('load'); } // null → hidden + get _tagline() { return this.getAttribute('tagline'); } // null → default + get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE" + get _logoSrc() { return this.getAttribute('logo-src'); } // null → inline SVG + + get _statusColor() { + return { nominal: 
'#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] ?? '#7ecfb8'; + } + + get _statusLabel() { + return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 'NOMINAL'; + } + + get _loadColor() { + const v = parseInt(this._load, 10); + if (isNaN(v)) return '#7ecfb8'; + if (v >= 85) return '#c0504d'; + if (v >= 60) return '#d4a84b'; + return '#7ecfb8'; + } + + // ───────────────────────────────────────────── + // Render + // ───────────────────────────────────────────── + + _render() { + const compact = this._variant === 'compact'; + this.shadowRoot.innerHTML = ` + + ${this._html(compact)} + `; + } + + // ───────────────────────────────────────────── + // CSS + // ───────────────────────────────────────────── + + _css(compact) { + const height = compact ? '60px' : '100px'; + const padding = compact ? '0 24px' : '0 32px'; + const gap = compact ? '16px' : '24px'; + const markSize = compact ? '34px' : '52px'; + const divH = compact ? '32px' : '48px'; + const nameSize = compact ? '15px' : '22px'; + const tagSize = compact ? 
'9px' : '11px'; + const sc = this._statusColor; + const lc = this._loadColor; + + return ` + @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap'); + + *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + + :host { + display: block; + font-family: 'Space Mono', monospace; + } + + .banner { + width: 100%; + height: ${height}; + background: var(--theme_g3, #0b0d10); + border: 1px solid var(--theme_g2, #1e2330); + border-radius: 6px; + display: flex; + align-items: center; + padding: ${padding}; + gap: ${gap}; + position: relative; + overflow: hidden; + text-decoration: none; + color: inherit; + cursor: pointer; + } + + /* scan-line texture */ + .banner::before { + content: ''; + position: absolute; + inset: 0; + background: repeating-linear-gradient( + 0deg, + transparent, transparent 3px, + rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px + ); + pointer-events: none; + } + + /* ambient glow */ + .banner::after { + content: ''; + position: absolute; + right: -60px; + top: 50%; + transform: translateY(-50%); + width: 280px; + height: 280px; + background: radial-gradient(circle, rgba(130,200,180,0.06) 0%, transparent 70%); + pointer-events: none; + } + + .logo-mark { + flex-shrink: 0; + width: ${markSize}; + height: ${markSize}; + } + + .logo-mark svg, .logo-mark img { width: 100%; height: 100%; object-fit: contain; } + + .divider { + width: 1px; + height: ${divH}; + background: linear-gradient(to bottom, transparent, var(--theme_g2, #2a3040), transparent); + flex-shrink: 0; + } + + .text-block { + display: flex; + flex-direction: column; + gap: 4px; + } + + .wordmark { + font-weight: 700; + font-size: ${nameSize}; + letter-spacing: 0.12em; + color: var(--theme_bright, #e8e4dc); + text-transform: uppercase; + line-height: 1; + } + + .wordmark span { color: #7ecfb8; } + + .tagline { + font-family: 'Noto Serif JP', serif; + font-weight: 300; + font-size: ${tagSize}; + 
letter-spacing: 0.3em; + color: var(--theme_faint, #4a5a68); + text-transform: uppercase; + } + + .spacer { flex: 1; } + + /* ── right-side decorative circuit ── */ + .circuit { flex-shrink: 0; opacity: 0.22; } + + /* ── status cluster ── */ + .status-cluster { + display: flex; + flex-direction: column; + align-items: flex-end; + gap: 6px; + } + + .status-row { + display: flex; + align-items: center; + gap: 8px; + } + + .status-lbl { + font-size: 9px; + letter-spacing: 0.18em; + color: var(--theme_faint, #3a4555); + text-transform: uppercase; + } + + .pill { + display: flex; + align-items: center; + gap: 5px; + border-radius: 20px; + padding: 2px 10px; + font-size: 10px; + letter-spacing: 0.1em; + } + + .pill.cluster { + color: ${sc}; + background: color-mix(in srgb, ${sc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${sc} 28%, transparent); + } + + .pill.load-pill { + color: ${lc}; + background: color-mix(in srgb, ${lc} 8%, transparent); + border: 1px solid color-mix(in srgb, ${lc} 28%, transparent); + } + + .dot { + width: 5px; + height: 5px; + border-radius: 50%; + animation: pulse 2.4s ease-in-out infinite; + } + + .dot.cluster { background: ${sc}; } + .dot.load-dot { background: ${lc}; animation-delay: 0.5s; } + + @keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.25; } + } + `; + } + + // ───────────────────────────────────────────── + // HTML template + // ───────────────────────────────────────────── + + _html(compact) { + const loadAttr = this._load; + const hasCluster = !compact && this.hasAttribute('cluster-status'); + const hasLoad = !compact && loadAttr !== null; + const showRight = hasCluster || hasLoad; + + const circuit = showRight ? ` + + + + + + + + ` : ''; + + const clusterRow = hasCluster ? ` +
+ Cluster +
+
+ ${this._statusLabel} +
+
` : ''; + + const loadRow = hasLoad ? ` +
+ Load +
+
+ ${parseInt(loadAttr, 10)} % +
+
` : ''; + + const rightSide = showRight ? ` + ${circuit} +
+ ${clusterRow} + ${loadRow} +
+ ` : ''; + + return ` + + `; + } + + // ───────────────────────────────────────────── + // SVG logo mark + // ───────────────────────────────────────────── + + _logoMark() { + const src = this._logoSrc; + if (src) { + return `zen`; + } + return ` + + + + + + + + + + + + + + + + + + `; + } +} + +customElements.define('zen-banner', ZenBanner); diff --git a/src/zenserver/frontend/html/compute/banner.js b/src/zenserver/frontend/html/compute/banner.js deleted file mode 100644 index 61c7ce21f..000000000 --- a/src/zenserver/frontend/html/compute/banner.js +++ /dev/null @@ -1,321 +0,0 @@ -/** - * zen-banner.js — Zen Compute dashboard banner Web Component - * - * Usage: - * - * - * - * - * - * - * Attributes: - * variant "full" (default) | "compact" - * cluster-status "nominal" (default) | "degraded" | "offline" - * load 0–100 integer, shown as a percentage (default: hidden) - * tagline custom tagline text (default: "Orchestrator Overview" / "Orchestrator") - * subtitle text after "ZEN" in the wordmark (default: "COMPUTE") - */ - -class ZenBanner extends HTMLElement { - - static get observedAttributes() { - return ['variant', 'cluster-status', 'load', 'tagline', 'subtitle']; - } - - attributeChangedCallback() { - if (this.shadowRoot) this._render(); - } - - connectedCallback() { - if (!this.shadowRoot) this.attachShadow({ mode: 'open' }); - this._render(); - } - - // ───────────────────────────────────────────── - // Derived values - // ───────────────────────────────────────────── - - get _variant() { return this.getAttribute('variant') || 'full'; } - get _status() { return (this.getAttribute('cluster-status') || 'nominal').toLowerCase(); } - get _load() { return this.getAttribute('load'); } // null → hidden - get _tagline() { return this.getAttribute('tagline'); } // null → default - get _subtitle() { return this.getAttribute('subtitle'); } // null → "COMPUTE" - - get _statusColor() { - return { nominal: '#7ecfb8', degraded: '#d4a84b', offline: '#c0504d' }[this._status] 
?? '#7ecfb8'; - } - - get _statusLabel() { - return { nominal: 'NOMINAL', degraded: 'DEGRADED', offline: 'OFFLINE' }[this._status] ?? 'NOMINAL'; - } - - get _loadColor() { - const v = parseInt(this._load, 10); - if (isNaN(v)) return '#7ecfb8'; - if (v >= 85) return '#c0504d'; - if (v >= 60) return '#d4a84b'; - return '#7ecfb8'; - } - - // ───────────────────────────────────────────── - // Render - // ───────────────────────────────────────────── - - _render() { - const compact = this._variant === 'compact'; - this.shadowRoot.innerHTML = ` - - ${this._html(compact)} - `; - } - - // ───────────────────────────────────────────── - // CSS - // ───────────────────────────────────────────── - - _css(compact) { - const height = compact ? '60px' : '100px'; - const padding = compact ? '0 24px' : '0 32px'; - const gap = compact ? '16px' : '24px'; - const markSize = compact ? '34px' : '52px'; - const divH = compact ? '32px' : '48px'; - const nameSize = compact ? '15px' : '22px'; - const tagSize = compact ? 
'9px' : '11px'; - const sc = this._statusColor; - const lc = this._loadColor; - - return ` - @import url('https://fonts.googleapis.com/css2?family=Noto+Serif+JP:wght@300;400&family=Space+Mono:wght@400;700&display=swap'); - - *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } - - :host { - display: block; - font-family: 'Space Mono', monospace; - } - - .banner { - width: 100%; - height: ${height}; - background: #0b0d10; - border: 1px solid #1e2330; - border-radius: 6px; - display: flex; - align-items: center; - padding: ${padding}; - gap: ${gap}; - position: relative; - overflow: hidden; - } - - /* scan-line texture */ - .banner::before { - content: ''; - position: absolute; - inset: 0; - background: repeating-linear-gradient( - 0deg, - transparent, transparent 3px, - rgba(255,255,255,0.012) 3px, rgba(255,255,255,0.012) 4px - ); - pointer-events: none; - } - - /* ambient glow */ - .banner::after { - content: ''; - position: absolute; - right: -60px; - top: 50%; - transform: translateY(-50%); - width: 280px; - height: 280px; - background: radial-gradient(circle, rgba(130,200,180,0.06) 0%, transparent 70%); - pointer-events: none; - } - - .logo-mark { - flex-shrink: 0; - width: ${markSize}; - height: ${markSize}; - } - - .logo-mark svg { width: 100%; height: 100%; } - - .divider { - width: 1px; - height: ${divH}; - background: linear-gradient(to bottom, transparent, #2a3040, transparent); - flex-shrink: 0; - } - - .text-block { - display: flex; - flex-direction: column; - gap: 4px; - } - - .wordmark { - font-weight: 700; - font-size: ${nameSize}; - letter-spacing: 0.12em; - color: #e8e4dc; - text-transform: uppercase; - line-height: 1; - } - - .wordmark span { color: #7ecfb8; } - - .tagline { - font-family: 'Noto Serif JP', serif; - font-weight: 300; - font-size: ${tagSize}; - letter-spacing: 0.3em; - color: #4a5a68; - text-transform: uppercase; - } - - .spacer { flex: 1; } - - /* ── right-side decorative circuit ── */ - .circuit { flex-shrink: 
0; opacity: 0.22; } - - /* ── status cluster ── */ - .status-cluster { - display: flex; - flex-direction: column; - align-items: flex-end; - gap: 6px; - } - - .status-row { - display: flex; - align-items: center; - gap: 8px; - } - - .status-lbl { - font-size: 9px; - letter-spacing: 0.18em; - color: #3a4555; - text-transform: uppercase; - } - - .pill { - display: flex; - align-items: center; - gap: 5px; - border-radius: 20px; - padding: 2px 10px; - font-size: 10px; - letter-spacing: 0.1em; - } - - .pill.cluster { - color: ${sc}; - background: color-mix(in srgb, ${sc} 8%, transparent); - border: 1px solid color-mix(in srgb, ${sc} 28%, transparent); - } - - .pill.load-pill { - color: ${lc}; - background: color-mix(in srgb, ${lc} 8%, transparent); - border: 1px solid color-mix(in srgb, ${lc} 28%, transparent); - } - - .dot { - width: 5px; - height: 5px; - border-radius: 50%; - animation: pulse 2.4s ease-in-out infinite; - } - - .dot.cluster { background: ${sc}; } - .dot.load-dot { background: ${lc}; animation-delay: 0.5s; } - - @keyframes pulse { - 0%, 100% { opacity: 1; } - 50% { opacity: 0.25; } - } - `; - } - - // ───────────────────────────────────────────── - // HTML template - // ───────────────────────────────────────────── - - _html(compact) { - const loadAttr = this._load; - const showStatus = !compact; - - const rightSide = showStatus ? ` - - - - - - - - - -
-
- Cluster -
-
- ${this._statusLabel} -
-
- ${loadAttr !== null ? ` -
- Load -
-
- ${parseInt(loadAttr, 10)} % -
-
` : ''} -
- ` : ''; - - return ` - - `; - } - - // ───────────────────────────────────────────── - // SVG logo mark - // ───────────────────────────────────────────── - - _svgMark() { - return ` - - - - - - - - - - - - - - - - - - `; - } -} - -customElements.define('zen-banner', ZenBanner); diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html index 1e101d839..66c20175f 100644 --- a/src/zenserver/frontend/html/compute/compute.html +++ b/src/zenserver/frontend/html/compute/compute.html @@ -5,101 +5,13 @@ Zen Compute Dashboard - - + + + + -
- +
+ + Home Node Orchestrator @@ -369,15 +226,15 @@ -