aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
authorDan Engelbrecht <[email protected]>2026-02-17 14:00:53 +0100
committerGitHub Enterprise <[email protected]>2026-02-17 14:00:53 +0100
commit5e1e23e209eec75a396c18f8eee3d93a9e196bfc (patch)
tree31b2b3938468aacdb0621e8b932cb9e9738ee918 /src/zenhttp
parentmisc fixes brought over from sb/proto (#759) (diff)
downloadzen-5e1e23e209eec75a396c18f8eee3d93a9e196bfc.tar.xz
zen-5e1e23e209eec75a396c18f8eee3d93a9e196bfc.zip
add http server root password protection (#757)
- Feature: Added `--security-config-path` option to zenserver to configure security settings - Expects a path to a .json file - Default is an empty path resulting in no extra security settings and legacy behavior - Current support is a top level filter of incoming http requests restricted to the `password` type - `password` type will check the `Authorization` header and match it to the selected authorization strategy - Currently the security settings is very basic and configured to a fixed username+password at startup { "http" { "root": { "filter": { "type": "password", "config": { "password": { "username": "<username>", "password": "<password>" }, "protect-machine-local-requests": false, "unprotected-uris": [ "/health/", "/health/info", "/health/version" ] } } } } }
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/httpclient.cpp91
-rw-r--r--src/zenhttp/httpserver.cpp3
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h7
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurity.h38
-rw-r--r--src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h51
-rw-r--r--src/zenhttp/security/passwordsecurity.cpp164
-rw-r--r--src/zenhttp/security/passwordsecurityfilter.cpp56
-rw-r--r--src/zenhttp/servers/httpasio.cpp14
-rw-r--r--src/zenhttp/servers/httpmulti.cpp1
-rw-r--r--src/zenhttp/servers/httpnull.cpp1
-rw-r--r--src/zenhttp/servers/httpparser.cpp6
-rw-r--r--src/zenhttp/servers/httpparser.h3
-rw-r--r--src/zenhttp/servers/httpplugin.cpp18
-rw-r--r--src/zenhttp/servers/httpsys.cpp21
14 files changed, 324 insertions, 150 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 16729ce38..d3b59df2b 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -25,6 +25,7 @@
# include <zencore/scopeguard.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
+# include <zenhttp/security/passwordsecurityfilter.h>
# include "servers/httpasio.h"
# include "servers/httpsys.h"
@@ -662,6 +663,96 @@ TEST_CASE("httpclient.requestfilter")
}
}
+TEST_CASE("httpclient.password")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ SUBCASE("usernamepassword")
+ {
+ CbObjectWriter Writer;
+ {
+ Writer.BeginObject("basic");
+ {
+ Writer << "username"sv
+ << "me";
+ Writer << "password"sv
+ << "456123789";
+ }
+ Writer.EndObject();
+ Writer << "protect-machine-local-requests" << true;
+ }
+
+ PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save());
+
+ PasswordHttpFilter MyFilter(PasswordFilterOptions);
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response ForbiddenResponse = Client.Get("/test/yo");
+ CHECK(!ForbiddenResponse.IsSuccess());
+ CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ HttpClient::Response WithBasicResponse =
+ Client.Get("/test/yo",
+ std::pair<std::string, std::string>("Authorization",
+ fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password)));
+ CHECK(WithBasicResponse.IsSuccess());
+ AsioServer->SetHttpRequestFilter(nullptr);
+ }
+ AsioServer->RequestExit();
+ }
+}
void
httpclient_forcelink()
{
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index d8367fcb2..f2fe4738f 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -1317,7 +1317,8 @@ TEST_CASE("http.common")
TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; }
virtual IoBuffer ReadPayload() override { return IoBuffer(); }
- virtual bool IsLocalMachineRequest() const override { return false; }
+ virtual bool IsLocalMachineRequest() const override { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override { return {}; }
virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override
{
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index cbac06cb6..350532126 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -101,7 +101,8 @@ public:
CbObject ReadPayloadObject();
CbPackage ReadPayloadPackage();
- virtual bool IsLocalMachineRequest() const = 0;
+ virtual bool IsLocalMachineRequest() const = 0;
+ virtual std::string_view GetAuthorizationHeader() const = 0;
/** Respond with payload
@@ -162,8 +163,10 @@ public:
virtual void OnRequestComplete() = 0;
};
-struct IHttpRequestFilter
+class IHttpRequestFilter
{
+public:
+ virtual ~IHttpRequestFilter() {}
enum class Result
{
Forbidden,
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
index 026c2865b..6b2b548a6 100644
--- a/src/zenhttp/include/zenhttp/security/passwordsecurity.h
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h
@@ -10,43 +10,29 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-struct PasswordSecurityConfiguration
-{
- std::string Password; // "password"
- bool ProtectMachineLocalRequests = false; // "protect-machine-local-requests"
- std::vector<std::string> UnprotectedUris; // "unprotected-urls"
-};
-
class PasswordSecurity
{
public:
- PasswordSecurity(const PasswordSecurityConfiguration& Config);
+ struct Configuration
+ {
+ std::string Password;
+ bool ProtectMachineLocalRequests = false;
+ std::vector<std::string> UnprotectedUris;
+ };
+
+ explicit PasswordSecurity(const Configuration& Config);
[[nodiscard]] inline std::string_view Password() const { return m_Config.Password; }
[[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; }
- [[nodiscard]] bool IsUnprotectedUri(std::string_view Uri) const;
+ [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const;
- bool IsAllowed(std::string_view Password, std::string_view Uri, bool IsMachineLocalRequest);
+ bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest);
private:
- const PasswordSecurityConfiguration m_Config;
- tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUrlHashes;
+ const Configuration m_Config;
+ tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes;
};
-/**
- * Expected format (Json)
- * {
- * "password\": \"1234\",
- * "protect-machine-local-requests\": false,
- * "unprotected-urls\": [
- * "/health\",
- * "/health/info\",
- * "/health/version\"
- * ]
- * }
- */
-PasswordSecurityConfiguration ReadPasswordSecurityConfiguration(CbObjectView ConfigObject);
-
void passwordsecurity_forcelink(); // internal
} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
new file mode 100644
index 000000000..c098f05ad
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h
@@ -0,0 +1,51 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <zenhttp/security/passwordsecurity.h>
+
+namespace zen {
+
+class PasswordHttpFilter : public IHttpRequestFilter
+{
+public:
+ static constexpr std::string_view TypeName = "password";
+
+ struct Configuration
+ {
+ PasswordSecurity::Configuration PasswordConfig;
+ std::string AuthenticationTypeString;
+ };
+
+ /**
+ * Expected format (Json)
+ * {
+ * "password": { # "Authorization: Basic <username:password base64 encoded>" style
+ * "username": "<username>",
+ * "password": "<password>"
+ * },
+ * "protect-machine-local-requests": false,
+ * "unprotected-uris": [
+ * "/health/",
+ * "/health/info",
+ * "/health/version"
+ * ]
+ * }
+ */
+ static Configuration ReadConfiguration(CbObjectView Config);
+
+ explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config)
+ : m_PasswordSecurity(Config.PasswordConfig)
+ , m_AuthenticationTypeString(Config.AuthenticationTypeString)
+ {
+ }
+
+ virtual Result FilterRequest(HttpServerRequest& Request) override;
+
+private:
+ PasswordSecurity m_PasswordSecurity;
+ const std::string m_AuthenticationTypeString;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
index 37be9a018..a8fb9c3f5 100644
--- a/src/zenhttp/security/passwordsecurity.cpp
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -13,13 +13,13 @@
namespace zen {
using namespace std::literals;
-PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config) : m_Config(Config)
+PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config)
{
- m_UnprotectedUrlHashes.reserve(m_Config.UnprotectedUris.size());
+ m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size());
for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++)
{
const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index];
- if (auto Result = m_UnprotectedUrlHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
+ if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second)
{
throw std::runtime_error(fmt::format(
"password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')",
@@ -32,35 +32,30 @@ PasswordSecurity::PasswordSecurity(const PasswordSecurityConfiguration& Config)
}
bool
-PasswordSecurity::IsUnprotectedUri(std::string_view Uri) const
+PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const
{
if (!m_Config.UnprotectedUris.empty())
{
- uint32_t UriHash = HashStringDjb2(Uri);
- if (auto It = m_UnprotectedUrlHashes.find(UriHash); It != m_UnprotectedUrlHashes.end())
+ uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri});
+ if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end())
{
- if (m_Config.UnprotectedUris[It->second] == Uri)
+ const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second];
+ if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length())
{
- return true;
+ if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri)
+ {
+ return true;
+ }
}
}
}
return false;
}
-PasswordSecurityConfiguration
-ReadPasswordSecurityConfiguration(CbObjectView ConfigObject)
-{
- return PasswordSecurityConfiguration{
- .Password = std::string(ConfigObject["password"sv].AsString()),
- .ProtectMachineLocalRequests = ConfigObject["protect-machine-local-requests"sv].AsBool(),
- .UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-urls"sv, ConfigObject)};
-}
-
bool
-PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, bool IsMachineLocalRequest)
+PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest)
{
- if (IsUnprotectedUri(Uri))
+ if (IsUnprotectedUri(BaseUri, RelativeUri))
{
return true;
}
@@ -81,119 +76,74 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view Uri, b
#if ZEN_WITH_TESTS
-TEST_CASE("passwordsecurity.readconfig")
-{
- auto ReadConfigJson = [](std::string_view Json) {
- std::string JsonError;
- CbObject Config = LoadCompactBinaryFromJson(Json, JsonError).AsObject();
- REQUIRE(JsonError.empty());
- return Config;
- };
-
- {
- PasswordSecurityConfiguration EmptyConfig = ReadPasswordSecurityConfiguration(CbObject());
- CHECK(EmptyConfig.Password.empty());
- CHECK(!EmptyConfig.ProtectMachineLocalRequests);
- CHECK(EmptyConfig.UnprotectedUris.empty());
- }
-
- {
- const std::string_view SimpleConfigJson =
- "{\n"
- " \"password\": \"1234\"\n"
- "}";
- PasswordSecurityConfiguration SimpleConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(SimpleConfigJson));
- CHECK(SimpleConfig.Password == "1234");
- CHECK(!SimpleConfig.ProtectMachineLocalRequests);
- CHECK(SimpleConfig.UnprotectedUris.empty());
- }
-
- {
- const std::string_view ComplexConfigJson =
- "{\n"
- " \"password\": \"1234\",\n"
- " \"protect-machine-local-requests\": true,\n"
- " \"unprotected-urls\": [\n"
- " \"/health\",\n"
- " \"/health/info\",\n"
- " \"/health/version\"\n"
- " ]\n"
- "}";
- PasswordSecurityConfiguration ComplexConfig = ReadPasswordSecurityConfiguration(ReadConfigJson(ComplexConfigJson));
- CHECK(ComplexConfig.Password == "1234");
- CHECK(ComplexConfig.ProtectMachineLocalRequests);
- CHECK(ComplexConfig.UnprotectedUris == std::vector<std::string>({"/health", "/health/info", "/health/version"}));
- }
-}
-
TEST_CASE("passwordsecurity.allowanything")
{
PasswordSecurity Anything({});
- CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(Anything.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
}
TEST_CASE("passwordsecurity.allowalllocal")
{
PasswordSecurity AllLocal({.Password = "123456"});
- CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
}
TEST_CASE("passwordsecurity.allowonlypassword")
{
PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true});
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}
TEST_CASE("passwordsecurity.allowsomeexternaluris")
{
PasswordSecurity AllLocal(
{.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
- CHECK(AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}
TEST_CASE("passwordsecurity.allowsomelocaluris")
{
PasswordSecurity AllLocal(
{.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})});
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed(""sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ true));
- CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ true));
- CHECK(AllLocal.IsAllowed(""sv, "/free/access", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed(""sv, "/ok", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free/access", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", /*IsMachineLocalRequest*/ false));
- CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true));
+ CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true));
+ CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false));
+ CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}
TEST_CASE("passwordsecurity.conflictingunprotecteduris")
diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp
new file mode 100644
index 000000000..87d8cc275
--- /dev/null
+++ b/src/zenhttp/security/passwordsecurityfilter.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenhttp/security/passwordsecurityfilter.h"
+
+#include <zencore/base64.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/fmtutils.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+PasswordHttpFilter::Configuration
+PasswordHttpFilter::ReadConfiguration(CbObjectView Config)
+{
+ Configuration Result;
+ if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType)
+ {
+ Result.AuthenticationTypeString = "Basic ";
+ std::string_view Username = PasswordType["username"sv].AsString();
+ std::string_view Password = PasswordType["password"sv].AsString();
+ std::string UsernamePassword = fmt::format("{}:{}", Username, Password);
+ Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length())));
+ Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()),
+ uint32_t(UsernamePassword.size()),
+ const_cast<char*>(Result.PasswordConfig.Password.data()));
+ }
+ Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool();
+ Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config);
+ return Result;
+}
+
+IHttpRequestFilter::Result
+PasswordHttpFilter::FilterRequest(HttpServerRequest& Request)
+{
+ std::string_view Password;
+ std::string_view AuthorizationHeader = Request.GetAuthorizationHeader();
+ size_t AuthorizationHeaderLength = AuthorizationHeader.length();
+ if (AuthorizationHeaderLength > m_AuthenticationTypeString.length())
+ {
+ if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0)
+ {
+ Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length());
+ }
+ }
+
+ bool IsAllowed =
+ m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest());
+ if (IsAllowed)
+ {
+ return Result::Accepted;
+ }
+ return Result::Forbidden;
+}
+
+} // namespace zen
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 1f42b05d2..1c0ebef90 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -542,7 +542,8 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
- virtual bool IsLocalMachineRequest() const override;
+ virtual bool IsLocalMachineRequest() const override;
+ virtual std::string_view GetAuthorizationHeader() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -1747,6 +1748,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const
return m_IsLocalMachineRequest;
}
+std::string_view
+HttpAsioServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
IoBuffer
HttpAsioServerRequest::ReadPayload()
{
@@ -1964,8 +1971,8 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
{
return IHttpRequestFilter::Result::Accepted;
}
- IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
- return FilterResult;
+
+ return RequestFilter->FilterRequest(Request);
}
} // namespace zen::asio_http
@@ -2080,6 +2087,7 @@ HttpAsioServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 850d7d6b9..310ac9dc0 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -82,6 +82,7 @@ HttpMultiServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index db360c5fb..9bb7ef3bc 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -57,6 +57,7 @@ HttpNullServer::OnRun(bool IsInteractiveSession)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp
index 93094e21b..be5befcd2 100644
--- a/src/zenhttp/servers/httpparser.cpp
+++ b/src/zenhttp/servers/httpparser.cpp
@@ -19,6 +19,7 @@ static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv);
//////////////////////////////////////////////////////////////////////////
//
@@ -154,6 +155,10 @@ HttpRequestParser::ParseCurrentHeader()
{
m_ContentTypeHeaderIndex = CurrentHeaderIndex;
}
+ else if (HeaderHash == HashAuthorization)
+ {
+ m_AuthorizationHeaderIndex = CurrentHeaderIndex;
+ }
else if (HeaderHash == HashSession)
{
m_SessionId = Oid::TryFromHexString(HeaderValue);
@@ -357,6 +362,7 @@ HttpRequestParser::ResetState()
m_AcceptHeaderIndex = -1;
m_ContentTypeHeaderIndex = -1;
m_RangeHeaderIndex = -1;
+ m_AuthorizationHeaderIndex = -1;
m_Expect100Continue = false;
m_BodyBuffer = {};
m_BodyPosition = 0;
diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h
index 0d2664ec5..ff56ca970 100644
--- a/src/zenhttp/servers/httpparser.h
+++ b/src/zenhttp/servers/httpparser.h
@@ -46,6 +46,8 @@ struct HttpRequestParser
std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); }
+ std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); }
+
private:
struct HeaderRange
{
@@ -83,6 +85,7 @@ private:
int8_t m_AcceptHeaderIndex;
int8_t m_ContentTypeHeaderIndex;
int8_t m_RangeHeaderIndex;
+ int8_t m_AuthorizationHeaderIndex;
HttpVerb m_RequestVerb;
std::atomic_bool m_KeepAlive{false};
bool m_Expect100Continue = false;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 4219dc292..8564826d6 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -147,10 +147,10 @@ public:
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
// As this is plugin transport connection used for specialized connections we assume it is not a machine local connection
- virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
-
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
+ virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -636,6 +636,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest()
{
}
+std::string_view
+HttpPluginServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
Oid
HttpPluginServerRequest::ParseSessionId() const
{
@@ -831,6 +837,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -932,8 +939,7 @@ HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
{
return IHttpRequestFilter::Result::Accepted;
}
- IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
- return FilterResult;
+ return RequestFilter->FilterRequest(Request);
}
//////////////////////////////////////////////////////////////////////////
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 5fed94f1c..14896c803 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -72,6 +72,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In
OutString.Append("unknown");
}
+class HttpSysServerRequest;
+
/**
* @brief Windows implementation of HTTP server based on http.sys
*
@@ -102,7 +104,7 @@ public:
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
- IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
+ IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request);
private:
int InitializeServer(int BasePort);
@@ -319,7 +321,8 @@ public:
virtual Oid ParseSessionId() const override;
virtual uint32_t ParseRequestId() const override;
- virtual bool IsLocalMachineRequest() const;
+ virtual bool IsLocalMachineRequest() const;
+ virtual std::string_view GetAuthorizationHeader() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -1364,6 +1367,7 @@ HttpSysServer::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -1861,6 +1865,14 @@ HttpSysServerRequest::IsLocalMachineRequest() const
}
}
+std::string_view
+HttpSysServerRequest::GetAuthorizationHeader() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization];
+ return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength);
+}
+
IoBuffer
HttpSysServerRequest::ReadPayload()
{
@@ -2270,7 +2282,7 @@ HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
}
IHttpRequestFilter::Result
-HttpSysServer::FilterRequest(HttpServerRequest& Request)
+HttpSysServer::FilterRequest(HttpSysServerRequest& Request)
{
if (!m_HttpRequestFilter.load())
{
@@ -2282,8 +2294,7 @@ HttpSysServer::FilterRequest(HttpServerRequest& Request)
{
return IHttpRequestFilter::Result::Accepted;
}
- IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
- return FilterResult;
+ return RequestFilter->FilterRequest(Request);
}
Ref<HttpServer>