diff options
| author | zousar <[email protected]> | 2026-02-17 21:17:02 -0700 |
|---|---|---|
| committer | zousar <[email protected]> | 2026-02-17 21:17:02 -0700 |
| commit | 33922d451adc6375d7964bd685916a85086299ef (patch) | |
| tree | bf773d33a2ff147523cab7a063242862ff8bfb6d /src/zenhttp | |
| parent | Dependencies table doesn't reflow the entries page (diff) | |
| parent | add http server root password protection (#757) (diff) | |
| download | zen-33922d451adc6375d7964bd685916a85086299ef.tar.xz zen-33922d451adc6375d7964bd685916a85086299ef.zip | |
Merge branch 'main' into zs/web-ui-improvements
Diffstat (limited to 'src/zenhttp')
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 91 | ||||
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 3 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/cprutils.h | 4 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpserver.h | 24 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/security/passwordsecurity.h | 38 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h | 51 | ||||
| -rw-r--r-- | src/zenhttp/security/passwordsecurity.cpp | 164 | ||||
| -rw-r--r-- | src/zenhttp/security/passwordsecurityfilter.cpp | 56 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 21 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.cpp | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.cpp | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpparser.h | 3 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 18 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 51 |
15 files changed, 362 insertions, 170 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 16729ce38..d3b59df2b 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -25,6 +25,7 @@ # include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zenhttp/security/passwordsecurityfilter.h> # include "servers/httpasio.h" # include "servers/httpsys.h" @@ -662,6 +663,96 @@ TEST_CASE("httpclient.requestfilter") } } +TEST_CASE("httpclient.password") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + SUBCASE("usernamepassword") + { + CbObjectWriter Writer; + { + Writer.BeginObject("basic"); + { + Writer << "username"sv + << "me"; + Writer << "password"sv + << "456123789"; + } + Writer.EndObject(); + Writer << "protect-machine-local-requests" << true; + } + + PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save()); + + PasswordHttpFilter MyFilter(PasswordFilterOptions); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response ForbiddenResponse = Client.Get("/test/yo"); + CHECK(!ForbiddenResponse.IsSuccess()); + CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden); + + HttpClient::Response WithBasicResponse = + Client.Get("/test/yo", + std::pair<std::string, std::string>("Authorization", + fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password))); + CHECK(WithBasicResponse.IsSuccess()); + AsioServer->SetHttpRequestFilter(nullptr); + } + AsioServer->RequestExit(); + } +} void httpclient_forcelink() { diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index d8367fcb2..f2fe4738f 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -1317,7 +1317,8 @@ TEST_CASE("http.common") TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } - virtual bool IsLocalMachineRequest() const override { return false; } + virtual bool IsLocalMachineRequest() const override { return false; } + virtual std::string_view GetAuthorizationHeader() const override { return {}; } virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override { diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index a988346e0..c252a5d99 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -66,10 +66,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Json); } else @@ -82,10 +82,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Body.GetText()); } } diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 60f6bc9f2..350532126 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -39,7 +39,7 @@ public: // Synchronous operations [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } + [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } [[nodiscard]] inline HttpService& Service() const { return m_Service; } @@ -81,6 +81,18 @@ public: inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + inline void SetLogRequest(bool ShouldLog) + { + if (ShouldLog) + { + m_Flags |= kLogRequest; + } + else + { + m_Flags &= ~kLogRequest; + } + } + inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); } /** Read POST/PUT payload for request body, which is always available without delay */ @@ -89,7 +101,8 @@ public: CbObject ReadPayloadObject(); CbPackage ReadPayloadPackage(); - virtual bool IsLocalMachineRequest() const = 0; + virtual bool IsLocalMachineRequest() const = 0; + virtual std::string_view GetAuthorizationHeader() const = 0; /** Respond with payload @@ -119,6 +132,7 @@ protected: kSuppressBody = 1 << 1, kHaveRequestId = 1 << 2, kHaveSessionId = 1 << 3, + kLogRequest = 1 << 4, }; mutable uint32_t m_Flags = 0; @@ -149,8 +163,10 @@ public: virtual void OnRequestComplete() = 0; }; -struct IHttpRequestFilter +class IHttpRequestFilter { +public: + virtual ~IHttpRequestFilter() {} enum class Result { Forbidden, @@ -254,7 +270,7 @@ public: inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} ~HttpRouterRequest() = default; HttpRouterRequest(const HttpRouterRequest&) = delete; diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h index 026c2865b..6b2b548a6 100644 --- a/src/zenhttp/include/zenhttp/security/passwordsecurity.h +++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h @@ -10,43 +10,29 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -struct PasswordSecurityConfiguration -{ - std::string Password; // "password" - bool ProtectMachineLocalRequests = false; // "protect-machine-local-requests" - std::vector<std::string> UnprotectedUris; // "unprotected-urls" -}; - class PasswordSecurity { public: - PasswordSecurity(const PasswordSecurityConfiguration& Config); + struct Configuration + { + std::string Password; + bool ProtectMachineLocalRequests = false; + std::vector<std::string> UnprotectedUris; + }; + + explicit PasswordSecurity(const Configuration& Config); [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; } [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; } - [[nodiscard]] bool IsUnprotectedUri(std::string_view Uri) const; + [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const; - bool IsAllowed(std::string_view Password, std::string_view Uri, bool IsMachineLocalRequest); + bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest); private: - const PasswordSecurityConfiguration m_Config; - tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUrlHashes; + const Configuration m_Config; + tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes; }; -/** - * Expected format (Json) - * { - * "password\": \"1234\", - * "protect-machine-local-requests\": false, - * "unprotected-urls\": [ - * "/health\", - * "/health/info\", - * "/health/version\" - * ] - * } - */ -PasswordSecurityConfiguration ReadPasswordSecurityConfiguration(CbObjectView ConfigObject); - void passwordsecurity_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h new file mode 100644 index 000000000..c098f05ad --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h @@ -0,0 +1,51 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> +#include <zenhttp/security/passwordsecurity.h> + +namespace zen { + +class PasswordHttpFilter : public IHttpRequestFilter +{ +public: + static constexpr std::string_view TypeName = "password"; + + struct Configuration + { + PasswordSecurity::Configuration PasswordConfig; + std::string AuthenticationTypeString; + }; + + /** + * Expected format (Json) + * { + * "password": { # "Authorization: Basic <username:password base64 encoded>" style + * "username": "<username>", + * "password": "<password>" + * }, + * "protect-machine-local-requests": false, + * "unprotected-uris": [ + * "/health/", + * "/health/info", + * "/health/version" + * ] + * } + */ + static Configuration ReadConfiguration(CbObjectView Config); + + explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config) + : m_PasswordSecurity(Config.PasswordConfig) + , m_AuthenticationTypeString(Config.AuthenticationTypeString) + { + } + + virtual Result FilterRequest(HttpServerRequest& Request) override; + +private: + PasswordSecurity m_PasswordSecurity; + const std::string m_AuthenticationTypeString; +}; + +} // namespace zen diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp index 37be9a018..a8fb9c3f5 100644 --- a/src/zenhttp/security/passwordsecurity.cpp +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -13,13 +13,13 @@ namespace zen { using namespace std::literals; -PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) : m_Config(Config) +PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config) { - m_UnprotectedUrlHashes.reserve(m_Config.UnprotectedUris.size()); + m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size()); for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++) { const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index]; - if (auto Result = m_UnprotectedUrlHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) + if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) { throw std::runtime_error(fmt::format( "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')", @@ -32,35 +32,30 @@ PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) } bool -PasswordSecurity::IsUnprotectedUri(std::string_view Uri) const +PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const { if (!m_Config.UnprotectedUris.empty()) { - uint32_t UriHash = HashStringDjb2(Uri); - if (auto It = m_UnprotectedUrlHashes.find(UriHash); It != m_UnprotectedUrlHashes.end()) + uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri}); + if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end()) { - if (m_Config.UnprotectedUris[It->second] == Uri) + const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second]; + if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length()) { - return true; + if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri) + { + return true; + } } } } return false; } -PasswordSecurityConfiguration -ReadPasswordSecurityConfiguration(CbObjectView ConfigObject) -{ - return PasswordSecurityConfiguration{ - .Password = std::string(ConfigObject["password"sv].AsString()), - .ProtectMachineLocalRequests = ConfigObject["protect-machine-local-requests"sv].AsBool(), - .UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-urls"sv, ConfigObject)}; -} - bool -PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, bool IsMachineLocalRequest) +PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest) { - if (IsUnprotectedUri(Uri)) + if (IsUnprotectedUri(BaseUri, RelativeUri)) { return true; } @@ -81,119 +76,74 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, b #if ZEN_WITH_TESTS -TEST_CASE("passwordsecurity.readconfig") -{ - auto ReadConfigJson = [](std::string_view Json) { - std::string JsonError; - CbObject Config = LoadCompactBinaryFromJson(Json, JsonError).AsObject(); - REQUIRE(JsonError.empty()); - return Config; - }; - - { - PasswordSecurityConfiguration EmptyConfig = ReadPasswordSecurityConfiguration(CbObject()); - CHECK(EmptyConfig.Password.empty()); - CHECK(!EmptyConfig.ProtectMachineLocalRequests); - CHECK(EmptyConfig.UnprotectedUris.empty()); - } - - { - const std::string_view SimpleConfigJson = - "{\n" - " \"password\": \"1234\"\n" - "}"; - PasswordSecurityConfiguration SimpleConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(SimpleConfigJson)); - CHECK(SimpleConfig.Password == "1234"); - CHECK(!SimpleConfig.ProtectMachineLocalRequests); - CHECK(SimpleConfig.UnprotectedUris.empty()); - } - - { - const std::string_view ComplexConfigJson = - "{\n" - " \"password\": \"1234\",\n" - " \"protect-machine-local-requests\": true,\n" - " \"unprotected-urls\": [\n" - " \"/health\",\n" - " \"/health/info\",\n" - " \"/health/version\"\n" - " ]\n" - "}"; - PasswordSecurityConfiguration ComplexConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(ComplexConfigJson)); - CHECK(ComplexConfig.Password == "1234"); - CHECK(ComplexConfig.ProtectMachineLocalRequests); - CHECK(ComplexConfig.UnprotectedUris == std::vector<std::string>({"/health", "/health/info", "/health/version"})); - } -} - TEST_CASE("passwordsecurity.allowanything") { PasswordSecurity Anything({}); - CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); } TEST_CASE("passwordsecurity.allowalllocal") { PasswordSecurity AllLocal({.Password = "123456"}); - CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); } TEST_CASE("passwordsecurity.allowonlypassword") { PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true}); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.allowsomeexternaluris") { PasswordSecurity AllLocal( {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); - CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.allowsomelocaluris") { PasswordSecurity AllLocal( {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true)); - CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true)); - CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false)); - CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } TEST_CASE("passwordsecurity.conflictingunprotecteduris") diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp new file mode 100644 index 000000000..87d8cc275 --- /dev/null +++ b/src/zenhttp/security/passwordsecurityfilter.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/security/passwordsecurityfilter.h" + +#include <zencore/base64.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> + +namespace zen { + +using namespace std::literals; + +PasswordHttpFilter::Configuration +PasswordHttpFilter::ReadConfiguration(CbObjectView Config) +{ + Configuration Result; + if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType) + { + Result.AuthenticationTypeString = "Basic "; + std::string_view Username = PasswordType["username"sv].AsString(); + std::string_view Password = PasswordType["password"sv].AsString(); + std::string UsernamePassword = fmt::format("{}:{}", Username, Password); + Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length()))); + Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()), + uint32_t(UsernamePassword.size()), + const_cast<char*>(Result.PasswordConfig.Password.data())); + } + Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool(); + Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config); + return Result; +} + +IHttpRequestFilter::Result +PasswordHttpFilter::FilterRequest(HttpServerRequest& Request) +{ + std::string_view Password; + std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); + size_t AuthorizationHeaderLength = AuthorizationHeader.length(); + if (AuthorizationHeaderLength > m_AuthenticationTypeString.length()) + { + if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) + { + Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + } + } + + bool IsAllowed = + m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest()); + if (IsAllowed) + { + return Result::Accepted; + } + return Result::Forbidden; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 230aac6a8..1c0ebef90 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -147,7 +147,7 @@ inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // Logger.set_level(spdlog::level::trace); + // Logger.SetLogLevel(logging::level::Trace); return Logger; } @@ -542,7 +542,8 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -1264,6 +1265,11 @@ HttpServerConnection::HandleRequest() if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) { + if (Request.ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength())); + } + // Transmit the response if (m_RequestData.RequestVerb() == HttpVerb::kHead) @@ -1742,6 +1748,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const return m_IsLocalMachineRequest; } +std::string_view +HttpAsioServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + IoBuffer HttpAsioServerRequest::ReadPayload() { @@ -1959,8 +1971,8 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + + return RequestFilter->FilterRequest(Request); } } // namespace zen::asio_http @@ -2075,6 +2087,7 @@ HttpAsioServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 850d7d6b9..310ac9dc0 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -82,6 +82,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index db360c5fb..9bb7ef3bc 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -57,6 +57,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 93094e21b..be5befcd2 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -19,6 +19,7 @@ static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); ////////////////////////////////////////////////////////////////////////// // @@ -154,6 +155,10 @@ HttpRequestParser::ParseCurrentHeader() { m_ContentTypeHeaderIndex = CurrentHeaderIndex; } + else if (HeaderHash == HashAuthorization) + { + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + } else if (HeaderHash == HashSession) { m_SessionId = Oid::TryFromHexString(HeaderValue); @@ -357,6 +362,7 @@ HttpRequestParser::ResetState() m_AcceptHeaderIndex = -1; m_ContentTypeHeaderIndex = -1; m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 0d2664ec5..ff56ca970 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -46,6 +46,8 @@ struct HttpRequestParser std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); } + std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + private: struct HeaderRange { @@ -83,6 +85,7 @@ private: int8_t m_AcceptHeaderIndex; int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; + int8_t m_AuthorizationHeaderIndex; HttpVerb m_RequestVerb; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 4219dc292..8564826d6 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -147,10 +147,10 @@ public: HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection - virtual bool IsLocalMachineRequest() const /* override*/ { return false; } - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual std::string_view GetAuthorizationHeader() const override; + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -636,6 +636,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest() { } +std::string_view +HttpPluginServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + Oid HttpPluginServerRequest::ParseSessionId() const { @@ -831,6 +837,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -932,8 +939,7 @@ HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + return RequestFilter->FilterRequest(Request); } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4df4cd079..14896c803 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -72,6 +72,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In OutString.Append("unknown"); } +class HttpSysServerRequest; + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -102,7 +104,7 @@ public: inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } - IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request); private: int InitializeServer(int BasePort); @@ -319,7 +321,8 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const; + virtual bool IsLocalMachineRequest() const; + virtual std::string_view GetAuthorizationHeader() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -702,21 +705,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) HTTP_CACHE_POLICY CachePolicy; - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.Policy = HttpCachePolicyNocache; CachePolicy.SecondsToLive = 0; // Initial response API call - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags + &HttpResponse, // HttpResponse + &CachePolicy, // CachePolicy + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); m_IsInitialResponse = false; } @@ -724,9 +728,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { // Subsequent response API calls - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // BytesSent @@ -1351,7 +1355,6 @@ HttpSysServer::OnRun(bool IsInteractive) bool ShutdownRequested = false; do { - // int WaitTimeout = -1; int WaitTimeout = 100; if (IsInteractive) @@ -1364,6 +1367,7 @@ HttpSysServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -1861,6 +1865,14 @@ HttpSysServerRequest::IsLocalMachineRequest() const } } +std::string_view +HttpSysServerRequest::GetAuthorizationHeader() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; + return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength); +} + IoBuffer HttpSysServerRequest::ReadPayload() { @@ -2270,7 +2282,7 @@ HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) } IHttpRequestFilter::Result -HttpSysServer::FilterRequest(HttpServerRequest& Request) +HttpSysServer::FilterRequest(HttpSysServerRequest& Request) { if (!m_HttpRequestFilter.load()) { @@ -2282,8 +2294,7 @@ HttpSysServer::FilterRequest(HttpServerRequest& Request) { return IHttpRequestFilter::Result::Accepted; } - IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); - return FilterResult; + return RequestFilter->FilterRequest(Request); } Ref<HttpServer> |